{-
Compositional explanation of types and algorithmic debugging of type errors
author: Olaf Chitil
version: 21.02.2013

Part of the code is derived from Mark Jones' Typing Haskell in Haskell

`Typing Haskell in Haskell' is Copyright (c) Mark P Jones,
and the Oregon Graduate Institute of Science and Technology,
1999-2000, All rights reserved, and is distributed as
free software under the following license.

Redistribution and use in source and binary forms, with or
without modification, are permitted provided that the following
conditions are met:

- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

- Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.

- Neither name of the copyright holders nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND THE
CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR THE
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-}

module Main where

import Data.List(nub, (\\), intersect, union, partition, intersperse)
import Data.Maybe(catMaybes,fromJust,isNothing)
import Data.Char(isAlpha,digitToInt)
import Control.Monad(MonadPlus(..),msum,when)

-- Empty default declaration: no type defaulting happens in this module, so
-- every ambiguous numeric type must be resolved by explicit annotations.
default ()

-----------------------------------------------------------------------------
-- Conversion to String:
-----------------------------------------------------------------------------

-- Upper bound on the width of the identifier column in environment
-- listings (passed as the identifier-length limit to 'ppEnv' / 'ppEnv2').
maxIdLength :: Int
maxIdLength = 15

-- Width of one column when several typings are printed side by side
-- (see 'fixWidth').
subWidth :: Int
subWidth = 38

-- | Flatten a multi-line string: every newline becomes a single space and
-- the spaces that indent the following line are dropped.
singleLine :: String -> String
singleLine text = case text of
  []          -> []
  '\n' : rest -> ' ' : singleLine (dropWhile (== ' ') rest)
  ch : rest   -> ch : singleLine rest

-- | Flatten a string with 'singleLine' and fit it into a column of exactly
-- 'subWidth' characters. Shorter strings are padded with spaces. Strings of
-- 'subWidth' or more characters are cut and end in " ...  ". Either way the
-- column ends in at least one space, which separates it from the next column.
fixWidth :: String -> String
fixWidth os =
  if width < subWidth
    then flat ++ replicate (subWidth - width) ' '
    else take (subWidth - 6) flat ++ " ...  "
  where
  flat  = singleLine os
  width = length flat

-- | Keep the string when the condition holds, otherwise produce nothing.
test :: Bool -> String -> String
test b s
  | b         = s
  | otherwise = ""

-- | Wrap the string in parentheses when the flag is set.
pparens :: Bool -> String -> String
pparens wrap x = if wrap then "(" ++ x ++ ")" else x


-- | A "Context:" line followed by n spaces and the predicates; empty when
-- there are no predicates.
ppContext :: Int -> [Pred] -> String
ppContext n preds
  | null preds = ""
  | otherwise  = concat ["\nContext:", replicate n ' ', ppPreds preds]


-- | Render one debugging step: the focused typing, followed by the typings
-- of its children after "because" when children are requested and exist.
-- With neither fragments nor children the compact 'ppSTyping' form is used.
ppStep :: Bool {- with fragment -} -> Bool {- with children -}
       -> Typing -> [Typing] -> String
ppStep f c tyg tygs = focus ++ test showKids ("\n  because" ++ ppTypings f tygs)
  where
  focus
    | f || c    = ppTyping f tyg
    | otherwise = ppSTyping tyg
  showKids = c && not (null tygs)

-- | Render the typings of a step's children side by side. Each typing
-- occupies one column of width 'subWidth' (via 'fixWidth'). Environments are
-- merged into shared rows by 'ppMonoEnvs' / 'ppPolyEnvs'.
-- The clause is selected by the constructor of the FIRST typing. The local
-- accessors only handle that constructor, so a list that mixes constructors
-- fails at runtime. The two-element list [binding group, expression] (the
-- children of a let-expression) has its own layout. Any other list (e.g. the
-- empty list) ends in 'error'.
-- The Bool selects whether program fragments are shown.
ppTypings :: Bool {- with fragment -} -> [Typing] -> String
ppTypings f tygs@(TyExpr _ _ : _) =
  -- NOTE(review): the fragment flag f is not consulted in this clause; the
  -- expressions are always printed.
  test qual 
    ("\nContext:    " ++ concatMap (fixWidth . ppPreds . preds) tygs) ++
  "\nExpression: " ++ concatMap (fixWidth . ppExpr 0 0 . exp) tygs ++
  "\nType:       " ++ concatMap (fixWidth . ppType . ty) tygs ++
  ppMonoEnvs 5 (map env tygs)
  where
  exp (TyExpr e _) = e
  preds (TyExpr _ (p :=> _)) = p
  ty (TyExpr _ (_ :=> (_,t))) = t
  env (TyExpr _ (_ :=> (e,_))) = e
  -- the Context row is shown only if some typing has a non-empty context
  qual = not . null . concatMap preds $ tygs
ppTypings f tygs@(TyAlt _ _ : _) =
  test f ("\nEquation:  " ++ concatMap (fixWidth . ppAlt 6 . alt) tygs) ++
  test qual ("\nContext:   " ++ concatMap (fixWidth . ppPreds . preds) tygs) ++
  ppMonoEnvs 4 (map env tygs)
  where
  alt (TyAlt a _) = a
  preds (TyAlt _ (p :=> _)) = p
  env (TyAlt _ (_ :=> e)) = e  
  qual = not . null . concatMap preds $ tygs
ppTypings f tygs@(TyDef _ _ : _) =
  test f ("\nDefinition: " ++ concatMap (fixWidth . ppDef 6 . def) tygs) ++
  test qual 
    ("\nContext:    " ++ concatMap (fixWidth . ppPreds . preds) tygs) ++
  ppMonoEnvs 5 (map env tygs)
  where
  def (TyDef a _) = a
  preds (TyDef _ (p :=> _)) = p
  env (TyDef _ (_ :=> e)) = e  
  qual = not . null . concatMap preds $ tygs
ppTypings f [TyBindGroup b (p1 :=> (e1,p)), TyExpr e (p2 :=> (e2,t))] =
  -- for let-expression children
  test f ("\nDef./Expr.:   " ++ fixWidth (ppBindGroup 0 b) ++
    fixWidth (ppExpr 0 0 e)) ++
  test qual ("\nContext:      " ++ concatMap (fixWidth . ppPreds) [p1,p2]) ++
  -- an empty first column, so the type lines up under the expression column
  "\nType:         " ++ take subWidth (repeat ' ') ++ fixWidth (ppType t) ++
  ppPolyEnvs 8 [p] ++
  ppMonoEnvs 8 [e1,e2]
  where
  qual = not . null $ (p1 ++ p2)
ppTypings f tygs@(TyBindGroup _ _ : _) =
  test f ("\nDefinitions: " ++ concatMap (fixWidth . ppBindGroup 6 . bg) tygs) ++
  test qual 
    ("\nContext:     " ++ concatMap (fixWidth . ppPreds . preds) tygs) ++
  ppPolyEnvs 8 (map penv tygs) ++
  ppMonoEnvs 8 (map env tygs)
  where
  bg (TyBindGroup b _) = b
  -- something wrong
  -- (a later typing in the list is not a binding group)
  bg x = error ("bg: " ++ show x)
  preds (TyBindGroup _ (p :=> _)) = p
  env (TyBindGroup _ (_ :=> (e,_))) = e
  penv (TyBindGroup _ (_ :=> (_,p))) = p  
  qual = not . null . concatMap preds $ tygs
ppTypings _ _ = error "ppTypings" 

-- | One line per (identifier, text) pair. Each line is indented by `ind`
-- spaces, and the identifier is padded or cut to width `l`. The third
-- argument is accepted for symmetry with 'ppEnv' but is not used.
ppEnv2 :: Int -> Int -> Int -> [(String,String)] -> String
ppEnv2 l ind _idl xs =
  concat [ '\n' : indent ++ fixLength l name ++ "  " ++ txt | (name, txt) <- xs ]
  where
  indent = replicate ind ' '

-- | Render several monomorphic environments side by side under a "with"
-- heading. There is one row per identifier occurring in any environment and
-- one 'fixWidth' column per environment (blank where the identifier is
-- absent). The first argument is the width of the identifier column.
ppMonoEnvs :: Int -> [MonoEnv] -> String
ppMonoEnvs l envs
  | all null envs = ""
  -- "\nwith " replaces the newline and 5-space indent of the first row
  | otherwise     = "\nwith " ++ drop 6 (ppEnv2 l 5 maxIdLength rows)
  where
  ids = foldr union [] (map dom envs)
  rows = [ (i, concatMap (fixWidth . typeIn i) envs) | i <- ids ]
  typeIn i env = maybe "" ppType (find i env)

-- | Render several polymorphic environments side by side under a
-- "Defining" heading. There is one row per identifier and one 'fixWidth'
-- column per environment. Entries are expected to be 'TyPolyVar' trees;
-- any other entry fails the pattern match.
ppPolyEnvs :: Int -> [PolyEnv] -> String
ppPolyEnvs l envs
  | all null envs = ""
  | otherwise =
      -- "\nDefining " replaces the newline and 9-space indent of the first row
      "\nDefining " ++ drop 10 (ppEnv2 l 9 maxIdLength rows)
  where
  ids = foldr union [] (map dom envs)
  rows = [ (i, concatMap (fixWidth . schemeIn i) envs) | i <- ids ]
  schemeIn i env =
    case find i env of
      Nothing                          -> ""
      Just (Tree (TyPolyVar polyTy) _) -> ppPolyType polyTy

-- | Short-form context: the predicates followed by " =>", or nothing.
ppSContextS :: [Pred] -> String
ppSContextS preds
  | null preds = ""
  | otherwise  = ppPreds preds ++ " =>"

-- | Compact rendering of a typing, used when neither fragments nor children
-- are shown. Expressions print as "expr :: type" followed by their
-- environment. Equations, definitions and binding groups print only their
-- context and environments. Every other typing falls back to 'ppTyping'.
ppSTyping :: Typing -> String
ppSTyping tyg = case tyg of
  TyExpr expr (preds :=> (monoEnv, ty)) ->
    let e = ppExpr 0 0 expr
        l = max 10 (length e)
    in ppSContextS preds ++ "\n" ++ fixLength l e ++ " :: " ++ ppType ty
         ++ ppSMonoEnv l monoEnv
  TyAlt _ (preds :=> monoEnv) ->
    ppSContextS preds ++ ppSMonoEnv 10 monoEnv
  TyDef _ (preds :=> monoEnv) ->
    ppSContextS preds ++ ppSMonoEnv 10 monoEnv
  TyBindGroup _ (preds :=> (monoEnv, polyEnv)) ->
    ppSContextS preds ++ ppSPolyEnv 10 polyEnv ++ ppSMonoEnv 10 monoEnv
  _ -> ppTyping False tyg

-- | Full rendering of a single typing. The Bool says whether the program
-- fragment (equation, definition, binding group, program definitions) is
-- printed. Untypable fragments print their error message and the fragment.
ppTyping :: Bool {- with fragment -} -> Typing -> String
ppTyping f (TyExpr expr (preds :=> (monoEnv,ty))) = 
  -- expressions are always printed, regardless of f
  ppContext 4 preds ++
  "\nExpression: " ++ ppExpr 6 0 expr ++
  "\nType:       " ++ ppType ty ++
  ppMonoEnv monoEnv
ppTyping f (TyAlt alt (preds :=> monoEnv)) = 
  test f ("\nEquation: " ++ ppAlt 5 alt) ++ 
  ppContext 2 preds ++
  ppMonoEnv monoEnv
ppTyping f (TyDef def (preds :=> monoEnv)) = 
  test f ("\nDefinition: " ++ ppDef 6 def) ++ 
  ppContext 4 preds ++
  ppMonoEnv monoEnv
ppTyping f (TyBindGroup bg (preds :=> (monoEnv,polyEnv))) = 
  test f ("\nDefinitions: " ++ ppBindGroup 7 bg) ++ 
  ppContext 5 preds ++
  ppPolyEnv polyEnv ++ ppMonoEnv monoEnv
ppTyping f (TyProgram prog polyEnv) =
  -- the program's definitions are listed only when fragments are shown
  "\nWhole Program" ++ test f (ppPolyEnv polyEnv)
ppTyping f (TyPolyVar (preds :=> (monoEnv,t))) =
  "\nPolyVar: " ++ 
  ppContext 1 preds ++
  "\nType:    " ++ ppType t ++ 
  ppMonoEnv monoEnv
ppTyping f (TyUExpr msg expr) = 
  "\nError: " ++ msg ++ ("\nin expression: " ++ ppExpr 8 0 expr) 
ppTyping f (TyUAlt msg alt) = 
  "\nError: " ++ msg ++ ("\nin equation: " ++ ppAlt 7 alt) 
ppTyping f (TyUDef msg def) = 
  "\nError: " ++ msg ++ ("\nin definition: " ++ ppDef 8 def)
ppTyping f (TyUBindGroup msg bg) = 
  "\nError: " ++ msg ++ ("\nin definitions: " ++ ppBindGroup 8 bg) 
ppTyping f (TyUProgram msg prog) =
  "\nError: " ++ msg ++ ("\nin whole program.")

-- | Pretty-print a kind. The Bool says whether the kind appears in a
-- position where a function kind must be parenthesised (the argument of
-- another arrow).
ppKind :: Bool -> Kind -> String
ppKind _ Star = "*"
ppKind b (Kfun k1 k2) =
  -- The argument of an arrow always needs parentheses when it is itself an
  -- arrow, independent of the outer context. The result never does
  -- (arrows associate to the right).
  pparens b (ppKind True k1 ++ "->" ++ ppKind False k2)

-- | Comma-separated list of predicates (empty string for no predicates).
ppPreds :: [Pred] -> String
ppPreds = concat . intersperse ", " . map ppPred

-- | A single predicate, e.g. "Eq a"; the type is printed as an argument.
ppPred :: Pred -> String
ppPred (IsIn classId ty) = unwords [classId, ppFType 2 ty]

-- | A type printed at top level (no surrounding parentheses).
ppType :: Type -> String
ppType t = ppFType 0 t

-- | Pretty-print a type in a context of the given fixity. The function
-- arrow is the only operator type constructor. Lists and pairs use their
-- special syntax, and generic variables (TGen) print as their index.
ppFType :: Fixity -> Type -> String
ppFType fx ty = case ty of
  TVar (Tyvar name _) -> name
  TCon (Tycon name _) -> name
  TGen n -> show n
  TAp (TCon (Tycon "[]" _)) elemTy -> "[" ++ ppFType 0 elemTy ++ "]"
  TAp (TAp (TCon (Tycon "(->)" _)) argTy) resTy ->
    ppEParens fx True (ppFType 1 argTy ++ "->" ++ ppFType 0 resTy)
  TAp (TAp (TCon (Tycon "(,)" _)) fstTy) sndTy ->
    "(" ++ ppType fstTy ++ "," ++ ppType sndTy ++ ")"
  TAp fun arg ->
    ppEParens fx False (ppFType 1 fun ++ " " ++ ppFType 2 arg)

-- | Render a polymorphic type: optional context "C => ", the type, and, if
-- the monomorphic environment is non-empty, " | x :: t, y :: u".
ppPolyType :: Qual (MonoEnv,Type) -> String
ppPolyType (preds :=> (monoEnv,t)) = context ++ ppType t ++ assumptions
  where
  context = test (not (null preds)) (ppPreds preds ++ " => ")
  assumptions
    | null monoEnv = ""
    | otherwise    = " | " ++ concat (intersperse ", " (map assumption monoEnv))
  assumption (i :>: ty) = i ++ " :: " ++ ppType ty

-- | Compact listing "name :: polytype" of a polymorphic environment whose
-- entries are 'TyPolyVar' trees; the name column has width l.
ppSPolyEnv :: Int -> PolyEnv -> String
ppSPolyEnv l env = ppSEnv l [ entry a | a <- env ]
  where
  entry (i :>: Tree (TyPolyVar polyTy) _) = (i, ppPolyType polyTy)

-- | Listing of a polymorphic environment under a "Defining" heading. Each
-- entry must be a 'TyPolyVar' tree.
-- An empty environment renders as nothing, consistent with 'ppMonoEnv' and
-- 'ppPolyEnvs' (previously a dangling "\nDefining " was produced).
ppPolyEnv :: PolyEnv -> String
ppPolyEnv env
  | null env  = ""
  -- "\nDefining " replaces the newline and 9-space indent of the first row
  | otherwise = "\nDefining " ++ drop 10 (ppEnv 9 maxIdLength (map pair env))
  where
  pair (i :>: Tree (TyPolyVar polyTy) _) = 
    (i,ppPolyType polyTy)
  pair (i :>: _) =
    error ("ppPolyEnv: entry for " ++ i ++ " is not a polymorphic variable typing")


-- | Compact listing "name :: type" of a monomorphic environment; the name
-- column has width l.
ppSMonoEnv :: Int -> MonoEnv -> String
ppSMonoEnv l env = ppSEnv l (map (\(i :>: t) -> (i, ppType t)) env)

-- | Listing of a monomorphic environment under a "with" heading; empty
-- environments render as nothing.
ppMonoEnv :: MonoEnv -> String
ppMonoEnv env
  | null env  = ""
  -- "\nwith " replaces the newline and 5-space indent of the first row
  | otherwise = "\nwith " ++ drop 6 (ppEnv 5 maxIdLength rows)
  where
  rows = map (\(i :>: t) -> (i, ppType t)) env


-- | One "name :: text" line per pair, names padded/cut to width l.
ppSEnv :: Int -> [(String,String)] -> String
ppSEnv l xs = concat [ "\n" ++ fixLength l i ++ " :: " ++ t | (i, t) <- xs ]

-- | One indented "name  text" line per pair. The name column is as wide as
-- the longest name, but at most idl characters.
ppEnv :: Int -> Int -> [(String,String)] -> String
ppEnv ind idl xs =
  concat [ '\n' : pad ++ fixLength width i ++ "  " ++ t | (i, t) <- xs ]
  where
  pad   = replicate ind ' '
  width = min idl (maximum (0 : map (length . fst) xs))


-- | Force a string to exactly l characters: pad with spaces on the right or
-- truncate (a non-positive l yields the empty string).
fixLength :: Int -> String -> String
fixLength l xs = take l (xs ++ repeat ' ')

type Fixity = Int -- 0 no parens, 1 parens around op, 2 always parens

-- | Given the fixity of the context and whether the text is an operator
-- application, surround it by parentheses when required: operator
-- applications need them in any context of fixity > 0, everything else
-- only in fixity > 1.
ppEParens :: Fixity -> Bool -> String -> String
ppEParens f isOp xs
  | needsParens = "(" ++ xs ++ ")"
  | otherwise   = xs
  where
  needsParens = if isOp then f > 0 else f > 1

-- | Render a literal in Haskell source syntax. Character literals are
-- quoted (and escaped) with 'show', so 'a' is not mistaken for the
-- variable a.
ppLiteral :: Literal -> String
ppLiteral (LitInt i) = show i
ppLiteral (LitChar c) = show c

-- Give indentation level and fixity of context
-- | Pretty-print an expression. The Int is the indentation level used for
-- nested let-bindings. The Fixity says how much parenthesisation the
-- context requires (see 'ppEParens').
-- A name counts as an operator exactly when it is printed in parentheses on
-- its own: a variable not starting with a letter or '_', or a constructor
-- not starting with a letter, '[' or '('. Only such names are printed infix
-- when applied to two arguments.
ppExpr :: Int -> Fixity -> Expr -> String
ppExpr _ _ (Var i) | not (isAlpha (head i)) && head i /= '_' = '(' : i ++ ")"
ppExpr _ _ (Var i) = i
ppExpr _ _ (Lit l) = ppLiteral l
ppExpr i f (Let bg e) = 
  ppEParens f False $ "let " ++ ppBindGroup (i+1) bg ++ 
    newIndent (i+1) ++ "in " ++ ppExpr (i+1) 0 e
ppExpr _ _ (Const i) 
  | not (isAlpha (head i)) && head i /= '[' && head i /= '(' = 
    '(' : i ++ ")"
ppExpr _ _ (Const i) = i
ppExpr i _ (Ap (Ap (Const "(,)") e1) e2) =
  "(" ++ ppExpr i 0 e1 ++ ", " ++ ppExpr i 0 e2 ++ ")"
ppExpr i f (Ap (Ap (Const op) e1) e2) 
  -- same operator test as for a lone constructor, so e.g. a partially
  -- applied "(,,)" is printed prefix rather than infix
  | not (isAlpha (head op)) && head op /= '[' && head op /= '(' = 
    ppEParens f True $ ppExpr i 1 e1 ++ ' ' : op ++ ' ' : ppExpr i 1 e2
ppExpr i f (Ap (Ap (Var op) e1) e2) 
  -- same operator test as for a lone variable, so "_f a b" stays prefix
  | not (isAlpha (head op)) && head op /= '_' = 
    ppEParens f True $ ppExpr i 1 e1 ++ ' ' : op ++ ' ' : ppExpr i 1 e2
ppExpr i f (Ap e1 e2) = 
  ppEParens f False $ ppExpr i 1 e1 ++ ' ' : ppExpr i 2 e2

-- | An equation "lhs = rhs".
ppAlt :: Int -> Alt -> String
ppAlt i (lhs, rhs) = concat [ppExpr i 0 lhs, " = ", ppExpr i 0 rhs]

-- | A whole program: its binding groups separated by newlines.
ppProgram :: Program -> String
ppProgram prog = concat (intersperse "\n" [ ppBindGroup 0 bg | bg <- prog ])

-- | A definition: each equation on a fresh line at indentation level i.
ppDef :: Int -> Def -> String
ppDef i alts = concat [ newIndent i ++ ppAlt i alt | alt <- alts ]

-- | A binding group: its definitions separated by newlines.
ppBindGroup :: Int -> BindGroup -> String
ppBindGroup i (BG defs) = concat . intersperse "\n" $ map (ppDef i) defs

-- | A newline followed by two spaces per indentation level.
newIndent :: Int -> String 
newIndent depth = "\n" ++ concat (replicate depth "  ")

-----------------------------------------------------------------------------
-- Tree:
-----------------------------------------------------------------------------

-- | A rose tree: a node element and its ordered children.
data Tree a = Tree a [Tree a]

instance Show a => Show (Tree a) where
  -- show only root element
  show (Tree x _) = show x

instance Types a => Types (Tree a) where
  -- work only on root element
  -- (children are left untouched by apply and ignored by tv)
  apply s (Tree e ts) = Tree (apply s e) ts
  tv (Tree s ts) = tv s

-- | Zipper context of a focused subtree. 'Empty' is the root context.
-- @TreeCont x lts cont rts@ records the parent element x, the left siblings
-- lts (in left-to-right order, so the nearest one is last), the parent's own
-- context cont and the right siblings rts.
data TreeCont a = Empty | TreeCont a [Tree a] (TreeCont a) [Tree a]

-- | The element at the root of a tree.
treeElem :: Tree a -> a
treeElem (Tree e _) = e

-- dirty trick:
-- Trees compare by their root element only, and all contexts compare equal.
-- Presumably this lets positions (pairs of tree and context) be compared by
-- the focused node alone -- NOTE(review): confirm where this is relied on.

instance Eq a => Eq (Tree a) where
  Tree x _ == Tree y _ = x == y

instance Eq (TreeCont a) where
  _ == _ = True 

-- | A position in a tree: the focused subtree and its context.
type TreePos a = (Tree a,TreeCont a)

-- | All child positions of a position, left to right. Each child's context
-- records the parent element, its left and right siblings and the parent's
-- context.
children :: TreePos a -> [TreePos a]
children (Tree x ts, cont) =
  [ (t, TreeCont x before cont after)
  | n <- [0 .. length ts - 1]
  , let (before, t : after) = splitAt n ts
  ]

-- parent and sibling TreePoss
-- | For a position, return (left sibling positions, parent position, right
-- sibling positions). Each sibling is returned as a full position within the
-- same parent; the left siblings are listed left to right. At the root
-- (context 'Empty') there are no siblings and the parent is Nothing.
parentSibs :: TreePos a -> ([TreePos a],Maybe (TreePos a),[TreePos a])
parentSibs (t,Empty) = ([],Nothing,[])
parentSibs (t,TreeCont x lts cont rts) =
  (map (build . addRts) . splits $ lts
  ,Just (Tree x (lts++t:rts),cont)
  ,map (build . addLts) . splits $ rts)
  where
  -- every way to pick one element out of a list: (before, element, after)
  splits :: [a] -> [([a],a,[a])]
  splits [] = []
  splits (z:zs) = ([],z,zs) : map (\(us,v,ws)->(z:us,v,ws)) (splits zs)
  -- a left sibling also has the focus and the original right siblings to its right
  addRts (ys,v,ws) = (ys,v,ws++t:rts)
  -- a right sibling also has the original left siblings and the focus to its left
  addLts (ys,v,ws) = (lts++t:ys,v,ws)
  build (ys,v,ws) = (v,TreeCont x ys cont ws)


-- | Print the typing at a position together with the typings of its
-- children (ordinary or poly-var children, depending on the third flag).
-- Free type variables are renamed to a, b, c, ...; the focus is renamed
-- first, and each child then takes fresh names from the remaining supply.
printStep :: Bool {- with fragment -} -> Bool {- with children -}
          -> Bool {- only show poly vars -}
          -> TreePos Typing -> IO ()
printStep f c v tp@(Tree e _, _) =
  putStrLn (ppStep f c focus kids)
  where
  (unused, focus) = rename2 (abcNames, e)
  kids = renameAll unused (map (treeElem . fst) (getChildren v tp))
  -- thread the supply of unused names through the children, left to right
  renameAll _ [] = []
  renameAll names (x:xs) =
    let (names', x') = rename2 (names, x)
    in x' : renameAll names' xs
 

-- | Start an interactive debugging session at the root of a derivation.
-- Fragments and children are shown, the view starts at the abstract
-- (poly-var) level, and both oracle and history start empty.
walkTree :: Tree Typing -> IO () 
walkTree tree = walk showFragment showChildren onlyPolyVars [] [] start
  where
  showFragment = True
  showChildren = True
  onlyPolyVars = True
  start = (tree, Empty)

-- | One interactive step of the debugger: print the current position, read a
-- command and recurse. Flags: f shows program fragments, c shows children,
-- v restricts the view to poly-var nodes (abstract level). The oracle holds
-- the y/n answers given so far. The history is a stack of (oracle, position)
-- pairs pushed before each move; it is used by 'b' (undo) and 's' (start).
walk :: Bool {- with fragment -} -> Bool {- with children -}
     -> Bool {- only show poly vars -} -> Oracle
     -> [(Oracle,TreePos Typing)] -> TreePos Typing -> IO ()
walk f c v oracle history tp@(t@(Tree e _),_) = do
  putStr "\n"
  when errorLocated 
    (putStr (if v then "\nERROR LOCATED! Wrong definition of:" 
                  else "\nERROR LOCATED! Wrong program fragment:"))
  printStep f c v tp
  -- An error located at the abstract level is refined automatically: switch
  -- to the detailed view and continue the search below this node.
  if errorLocated && v
    then do
      putStr "\nSwitch to detailed level of program fragments."
      walk f c False oracle ((oracle,tp):history) (algNo False oracle tp)
    else do
      putStr (if f || c then "> " 
                        else "Is(are) intended type(s) an instance? (y/n) ")
      choose
  where
  errorLocated = sourceOfError v oracle tp 
  children = getChildren v tp
  choose = do
    l <- getLine
    -- any input that is not exactly one character is treated as '?' (help)
    let i = if length l == 1 then head l else '?'
    case i of
      '?' -> do
               putStrLn $
                 "\nManual navigation:" ++
                 "\nu - up (parent)\nd - down (leftmost child)" ++
                 "\nl - left\nr - right" ++
                 "\nb - back (undo)\ns - start (root node)" ++
                 "\n\nAlgorithmic debugging:" ++
                 "\ny - intended type is an instance" ++
                 "\nn - intended type is not an instance" ++
                 "\nY - not sure about instance, continue as if it is" ++
                 "\nN - not sure about instance, continue as if it is not" ++
                 "\na - amnesia; forget all y/n answers" ++
                 "\n\nToggles:\nf - (don't) show program fragment" ++
                 "\nc - (don't) show children" ++
                 "\nv - detailed (low level) - abstract (high level)" ++
                 "\n\nq - quit"
               walk f c v oracle history tp
      -- undo: restore the most recently pushed oracle and position
      'b' -> if null history 
               then walk f c v oracle history tp
               else walk f c v oracle' history' tp'
             where
             (oracle',tp'):history' = history
      -- start: restore the oldest history entry (the position of the first
      -- move) and clear the history
      's' -> if null history 
               then walk f c v oracle history tp
               else walk f c v oracle' [] tp'
             where
             (oracle',tp') = last history
      'f' -> do
               putStrLn 
                 (if f then "\nDon't show program fragment." 
                       else "\nShow program fragment.")
               walk (not f) c v oracle history tp
      'c' -> do
               putStrLn 
                 (if c then "\nDon't show children." else "\nShow children.")
               walk f (not c) v oracle history tp              
      'v' -> do
               putStrLn 
                 (if v then "\nShow all typings." 
                       else "\nShow only typings of polymorphic variables.")
               walk f c (not v) oracle history tp              
      'q' -> return ()
      'u' -> ifPossible $ goUp v tp
      'd' -> ifPossible $ goDown v tp
      'l' -> ifPossible $ goLeft v tp
      'r' -> ifPossible $ goRight v tp
      -- y/n record the answer for this typing in the oracle before moving on
      'y' -> let oracle' = addToOracle (e,True) oracle 
             in walk f c v oracle' ((oracle',tp):history) (algYes v oracle' tp)
      'n' -> let oracle' = addToOracle (e,False) oracle 
             in walk f c v oracle' ((oracle',tp):history) (algNo v oracle' tp)
      -- Y/N move like y/n but record nothing
      'Y' -> walk f c v oracle ((oracle,tp):history) (algYes v oracle tp)
      'N' -> walk f c v oracle ((oracle,tp):history) (algNo v oracle tp)
      'a' -> do
               putStrLn "\nAmnesia: forget all y/n answers."
               walk f c v [] history tp
      -- a digit 1..9 jumps to that child; n is only forced after the
      -- range check has succeeded, so digitToInt never sees a non-digit
      _ -> let n = digitToInt i in
           if i >= '1' && i <= '9' && n <= length children
             then walk f c v oracle ((oracle,tp):history) (children!!(n-1)) 
             else choose
  ifPossible :: Maybe (TreePos Typing) -> IO ()
  ifPossible (Just tp') = walk f c v oracle ((oracle,tp):history) tp'
  ifPossible Nothing   = do
    putStr "Navigation in this direction impossible.\n> "
    choose

-- | Children of a position: poly-var children in the abstract view, the
-- ordinary tree children otherwise.
getChildren :: Bool {- only polyvars -} -> TreePos Typing -> [TreePos Typing]
getChildren onlyPoly = if onlyPoly then polyChildren else children

-- | Move to the nearest left sibling, if there is one.
goLeft :: Bool {- only polyvars -} -> TreePos Typing -> Maybe (TreePos Typing)
goLeft onlyPoly pos
  | onlyPoly  = polyLeft pos
  | otherwise = case pos of
      (t, TreeCont ec ls@(_:_) up rs) ->
        Just (last ls, TreeCont ec (init ls) up (t : rs))
      _ -> Nothing

-- | Move to the nearest right sibling, if there is one.
goRight :: Bool {- only polyvars -} -> TreePos Typing -> Maybe (TreePos Typing)
goRight onlyPoly pos
  | onlyPoly  = polyRight pos
  | otherwise = case pos of
      (t, TreeCont ec ls up (r : rs)) -> Just (r, TreeCont ec (ls ++ [t]) up rs)
      _ -> Nothing

-- | Move to the leftmost child, if there is one.
goDown :: Bool {- only polyvars -} -> TreePos Typing -> Maybe (TreePos Typing)
goDown onlyPoly pos
  | onlyPoly  = polyDown pos
  | otherwise = case pos of
      (Tree e (first : rest), ctx) -> Just (first, TreeCont e [] ctx rest)
      _ -> Nothing

-- | Move to the parent, if there is one.
goUp :: Bool {- only polyvars -} -> TreePos Typing -> Maybe (TreePos Typing)
goUp onlyPoly pos
  | onlyPoly  = polyUp pos
  | otherwise = case pos of
      (t, TreeCont ec ls up rs) -> Just (Tree ec (ls ++ t : rs), up)
      _ -> Nothing
     

-- | The leftmost child whose typing has not been confirmed (oracle answer
-- other than Just True), if any.
goNotTrueChild :: Bool {- only polyvars -} 
               -> Oracle -> TreePos Typing -> Maybe (TreePos Typing)
goNotTrueChild v oracle tp =
  case filter notConfirmed (getChildren v tp) of
    (first : _) -> Just first
    []          -> Nothing
  where
  notConfirmed child = lookupOracle child oracle /= Just True


-- | After a "no" answer: descend into the first unconfirmed child and keep
-- going; if there is none, stay here (this node is then the error source).
algNo :: Bool {- only polyvars -} 
      -> Oracle -> TreePos Typing -> TreePos Typing
algNo v oracle tp = maybe tp (algNext v oracle) (goNotTrueChild v oracle tp)

-- | After a "yes" answer: move to the right sibling, or else to the parent,
-- and continue with 'algNext' from there.
-- If neither move is possible (e.g. at the root, or at the last top-level
-- definition in the abstract view, where 'polyUp' finds no parent), stay at
-- the current position. Previously this case crashed the debugger.
algYes :: Bool {- only polyvars -} 
       -> Oracle -> TreePos Typing -> TreePos Typing
algYes v oracle tp = case goRight v tp of
                       Just tp' -> algNext v oracle tp'
                       Nothing -> case goUp v tp of
                                    Just tp' -> algNext v oracle tp'
                                    Nothing -> tp

-- | Continue the algorithmic search from a position: stop at nodes whose
-- answer is unknown, and otherwise act as if the stored answer had been
-- given again.
algNext :: Bool {- only polyvars -} 
        -> Oracle -> TreePos Typing -> TreePos Typing
algNext v oracle tp = case lookupOracle tp oracle of
  Just answer -> (if answer then algYes else algNo) v oracle tp
  Nothing     -> tp

-- | A node is the source of the error when its own typing is known to be
-- wrong while all of its children are confirmed correct.
sourceOfError :: Bool {- only polyvars -} -> Oracle -> TreePos Typing -> Bool
sourceOfError v oracle tp
  | lookupOracle tp oracle /= Just False = False
  | otherwise = isNothing (goNotTrueChild v oracle tp)


-----------------------------------------------------------------------------
-- Oracle:
-----------------------------------------------------------------------------

type Oracle = [(Typing,Bool)]  -- answers to y/n questions

-- | Record an answer; newer answers come first and so shadow older ones in
-- 'lookup'.
addToOracle :: (Typing,Bool) -> Oracle -> Oracle
addToOracle answer@(_, _) oracle = answer : oracle

-- | The known answer for a position. A variable leaf is always correct, an
-- untypable fragment is always wrong, and everything else is looked up
-- among the recorded answers.
lookupOracle :: TreePos Typing -> Oracle -> Maybe Bool
lookupOracle tp oracle = case tp of
  (Tree (TyExpr (Var _) _) [], _) -> Just True  -- simple var
  (Tree tyg _, _)
    | untypable tyg -> Just False
    | otherwise     -> lookup tyg oracle

-----------------------------------------------------------------------------
-- Polymorphic variable children:
-----------------------------------------------------------------------------

-- | The position itself if it is a poly var (a variable node with children),
-- otherwise its nearest poly-var descendants.
polySelfChildren :: TreePos Typing -> [TreePos Typing]
polySelfChildren tp = go [tp]

-- | The nearest poly-var descendants of a position.
polyChildren :: TreePos Typing  -> [TreePos Typing]
polyChildren tp = go (children tp)

-- recursively search for poly vars
-- | Collect, left to right, the poly-var positions (variable nodes that
-- have children) in the given forest. The search does not descend below a
-- poly var it has found.
-- Each subtree is searched once. The previous form
-- @go (children tp ++ go tps)@ produced the same list but re-scanned the
-- already-found results at every non-poly-var level.
go :: [TreePos Typing] -> [TreePos Typing]
go [] = []
go (tp@(Tree (TyExpr (Var _) _) (_:_),_) : tps) = tp : go tps
go (tp : tps) = go (children tp) ++ go tps


-- | Abstract-view counterpart of 'parentSibs'. It climbs past ancestors that
-- are neither poly-var nodes nor untypable, gathering their siblings as well:
-- outer left siblings come first and outer right siblings last. The parent
-- returned is the first ancestor that is a poly var or untypable; otherwise
-- it is Nothing once the root is passed. The siblings are raw positions;
-- callers search them for poly vars themselves.
polyParentSibs :: TreePos Typing 
               -> ([TreePos Typing],Maybe (TreePos Typing),[TreePos Typing])
polyParentSibs tp =
  case mtp of
    Nothing -> ps
    Just (Tree (TyExpr (Var _) _) (_:_),_) -> ps
    Just (Tree tyg _,_) | untypable tyg -> ps
    Just tp' -> let (ltps',mtp',rtps') = polyParentSibs tp'
                in (ltps'++ltps,mtp',rtps++rtps')
  where
  ps@(ltps,mtp,rtps) = parentSibs tp
  
-- | Parent in the abstract view (see 'polyParentSibs').
polyUp :: TreePos Typing -> Maybe (TreePos Typing)
polyUp tp = case polyParentSibs tp of
  (_, parent, _) -> parent

-- | Leftmost poly-var child, if any.
polyDown :: TreePos Typing -> Maybe (TreePos Typing)
polyDown tp = case polyChildren tp of
  (first : _) -> Just first
  []          -> Nothing

-- | Nearest poly var to the left in the abstract view, if any.
polyLeft :: TreePos Typing -> Maybe (TreePos Typing)
polyLeft tp = case reverse (concatMap polySelfChildren lefts) of
  (nearest : _) -> Just nearest
  []            -> Nothing
  where
  (lefts, _, _) = polyParentSibs tp

-- | Nearest poly var to the right in the abstract view, if any.
polyRight :: TreePos Typing -> Maybe (TreePos Typing)
polyRight tp = case concatMap polySelfChildren rights of
  (nearest : _) -> Just nearest
  []            -> Nothing
  where
  (_, _, rights) = polyParentSibs tp


-----------------------------------------------------------------------------
-- Derivation:
-----------------------------------------------------------------------------

-- | The typing recorded at a node of a derivation. Each typable fragment
-- carries a qualified result (context :=> environments/type). The TyU*
-- constructors mark an untypable fragment together with an error message.
data Typing = TyExpr Expr (Qual (MonoEnv,Type))
            | TyAlt Alt (Qual MonoEnv)
            | TyDef Def (Qual MonoEnv)
            | TyBindGroup BindGroup (Qual (MonoEnv, PolyEnv))
            | TyPolyVar (Qual (MonoEnv,Type))
            | TyProgram Program PolyEnv
            | TyUExpr String Expr -- String is error message
            | TyUAlt String Alt
            | TyUDef String Def
            | TyUBindGroup String BindGroup
            | TyUProgram String Program
  deriving Show

-- | Equality used when looking up oracle answers. Apart from expressions
-- and untypable fragments, the program fragment is ignored: only the typing
-- results are compared. An equation typing and a definition typing with the
-- same result compare equal. An answer given once is therefore reused for
-- every typing that compares equal to it.
instance Eq Typing where
  TyExpr e1 q1 == TyExpr e2 q2 = e1 == e2 && q1 == q2
  TyAlt _ m1 == TyAlt _ m2 = m1 == m2
  TyDef _ m1 == TyDef _ m2 = m1 == m2
  TyAlt _ m1 == TyDef _ m2 = m1 == m2
  TyDef _ m1 == TyAlt _ m2 = m1 == m2
  TyBindGroup _ q1 == TyBindGroup _ q2 = q1 == q2
  TyPolyVar q1 == TyPolyVar q2 = q1 == q2
  TyProgram _ e1 == TyProgram _ e2 = e1 == e2
  TyUExpr s1 e1 == TyUExpr s2 e2 = s1 == s2 && e1 == e2
  TyUAlt s1 e1 == TyUAlt s2 e2 = s1 == s2 && e1 == e2
  TyUDef s1 e1 == TyUDef s2 e2 = s1 == s2 && e1 == e2
  TyUBindGroup s1 e1 == TyUBindGroup s2 e2 = s1 == s2 && e1 == e2
  TyUProgram s1 e1 == TyUProgram s2 e2 = s1 == s2 && e1 == e2
  _ == _ = False

-- | Substitution and free type variables act on the typing results only.
-- Untypable fragments contain no types and are left unchanged.
instance Types Typing where
  apply s tyg = case tyg of
    TyExpr e q       -> TyExpr e (apply s q)
    TyAlt a q        -> TyAlt a (apply s q)
    TyDef d q        -> TyDef d (apply s q)
    TyBindGroup bg q -> TyBindGroup bg (apply s q)
    TyPolyVar q      -> TyPolyVar (apply s q)
    TyProgram p env  -> TyProgram p (apply s env)
    _                -> tyg

  tv tyg = case tyg of
    TyExpr _ q       -> tv q
    TyAlt _ q        -> tv q
    TyDef _ q        -> tv q
    TyBindGroup _ q  -> tv q
    TyPolyVar q      -> tv q
    TyProgram _ env  -> tv env
    _                -> []



-- Rename all free type variables to a,b,c,...
-- ('rename2' with the whole 'abcNames' supply, discarding the unused names)
rename :: Types t => t -> t
rename ty = snd (rename2 (abcNames, ty))
      
-- names for type variables a,b,c ...
-- | Infinite supply of distinct type-variable names: "a" .. "z", then
-- "a1" .. "z1", "a2" .. "z2", and so on. Enumerating ['a'..] directly would
-- continue past 'z' into punctuation such as "{" and "|".
abcNames :: [String]
abcNames =
  [ [c] | c <- letters ] ++ [ c : show n | n <- [(1 :: Int) ..], c <- letters ]
  where
  letters = ['a'..'z']

-- Rename all free type variables to given names, returning unused ones
-- (the i-th free variable gets the i-th name and keeps its kind)
rename2 :: Types t => ([String],t) -> ([String],t)
rename2 (names, ty) = (drop (length vars) names, apply subst ty)
  where
  vars  = tv ty
  subst = [ (v, TVar (Tyvar n (kind v))) | (v, n) <- zip vars names ]
      
 
-- | True exactly for the TyU* constructors (fragments that failed to type).
untypable :: Typing -> Bool
untypable tyg = case tyg of
  TyUExpr{}      -> True
  TyUAlt{}       -> True
  TyUDef{}       -> True
  TyUBindGroup{} -> True
  TyUProgram{}   -> True
  _              -> False

-- | A derivation: a tree of typings, one per program fragment.
type Derivation = Tree Typing


-----------------------------------------------------------------------------
-- Error:		Error Monad
-----------------------------------------------------------------------------

-- | Result of a computation that may fail with a message.
data Error a = Correct a | Wrong String

-- | Sequencing stops at the first 'Wrong'; 'fail' produces 'Wrong'.
-- NOTE(review): no Functor/Applicative instances for Error appear in this
-- part of the file. GHC >= 7.10 requires them for a Monad instance, and
-- GHC >= 8.8 no longer allows 'fail' in Monad. Confirm the targeted
-- compiler before relying on this.
instance Monad Error where
  return = Correct
  (Correct a) >>= f = f a
  (Wrong s) >>= f = Wrong s
  fail = Wrong

-- | mplus returns the first 'Correct' alternative (left-biased choice).
instance MonadPlus Error where
  mzero = Wrong "mzero"
  (Wrong _) `mplus` y = y
  x `mplus` _ = x

-- | Lift an 'Error' result into the TI monad, reporting a failure message
-- via 'errorMsg'.
error2TI :: Error a -> TI String a
error2TI result = case result of
  Correct x -> return x
  Wrong msg -> errorMsg msg

-----------------------------------------------------------------------------
-- Id:		Identifiers
-----------------------------------------------------------------------------

type Id  = String

-- | The n-th generated identifier: "v0", "v1", ...
enumId  :: Int -> Id
enumId n = 'v' : show n

-----------------------------------------------------------------------------
-- Kind:		Kinds
-----------------------------------------------------------------------------

-- | Kinds: the kind of types (*) and kinds of type constructors (k1 -> k2).
data Kind  = Star | Kfun Kind Kind
             deriving (Eq,Show)


-----------------------------------------------------------------------------
-- Type:		Types
-----------------------------------------------------------------------------

-- | Types: variables, constructors, applications, and generic (quantified)
-- variables identified by index.
data Type  = TVar Tyvar
           | TCon Tycon
           | TAp  Type Type
           | TGen Int
             deriving (Eq,Show)


-- | A type variable: name and kind.
data Tyvar = Tyvar Id Kind
             deriving (Eq,Show)

-- | A type constructor: name and kind.
data Tycon = Tycon Id Kind
             deriving (Eq,Show)

-- | Frequently used base types (all of kind *).
tUnit, tChar, tInt, tInteger, tFloat, tDouble :: Type
tUnit    = TCon (Tycon "()" Star)
tChar    = TCon (Tycon "Char" Star)
tInt     = TCon (Tycon "Int" Star)
tInteger = TCon (Tycon "Integer" Star)
tFloat   = TCon (Tycon "Float" Star)
tDouble  = TCon (Tycon "Double" Star)

-- | The list, function and pair type constructors.
tList, tArrow, tTuple2 :: Type
tList    = TCon (Tycon "[]" (Kfun Star Star))
tArrow   = TCon (Tycon "(->)" (Kfun Star (Kfun Star Star)))
tTuple2  = TCon (Tycon "(,)" (Kfun Star (Kfun Star Star)))

tString    :: Type
tString     = list tChar

-- | Function type a -> b.
infixr      4 `fn`
fn         :: Type -> Type -> Type
a `fn` b    = TAp (TAp tArrow a) b

-- | List type [t].
list       :: Type -> Type
list t      = TAp tList t

-- | Pair type (a, b).
pair       :: Type -> Type -> Type
pair a b    = TAp (TAp tTuple2 a) b


-- | Things that have a kind.
class HasKind t where
  kind :: t -> Kind
instance HasKind Tyvar where
  kind (Tyvar v k) = k
instance HasKind Tycon where
  kind (Tycon v k) = k
instance HasKind Type where
  kind (TCon tc) = kind tc
  kind (TVar u)  = kind u
  -- an application has the result kind of its head's function kind;
  -- a head of kind * and 'TGen' are not handled (pattern-match failure)
  kind (TAp t _) = case (kind t) of
                     (Kfun _ k) -> k

-----------------------------------------------------------------------------
-- Subst:	Substitutions
-----------------------------------------------------------------------------

-- | A substitution: an association list from type variables to types.
type Subst  = [(Tyvar, Type)]

-- | The identity substitution.
nullSubst  :: Subst
nullSubst   = []

-- | Singleton substitution mapping one variable to a type.
(+->)      :: Tyvar -> Type -> Subst
(+->) u t   = [(u, t)]

-- | Values containing types: substitutions can be applied to them and their
-- free type variables can be listed.
class Types t where
  apply :: Subst -> t -> t
  tv    :: t -> [Tyvar]

instance Types Type where
  apply s ty = case ty of
    TVar u  -> maybe (TVar u) id (lookup u s)
    TAp l r -> TAp (apply s l) (apply s r)
    _       -> ty

  tv ty = case ty of
    TVar u  -> [u]
    TAp l r -> tv l `union` tv r
    _       -> []

instance Types a => Types [a] where
  apply s xs = [ apply s x | x <- xs ]
  tv         = nub . concatMap tv

instance (Types a, Types b) => Types (a,b) where
  apply s (x,y) = (apply s x, apply s y)
  tv (x,y)      = nub (tv x ++ tv y)


-- | Composition: apply (s1 @@ s2) = apply s1 . apply s2.
infixr 4 @@
(@@)       :: Subst -> Subst -> Subst
s1 @@ s2    = map (\(u, t) -> (u, apply s1 t)) s2 ++ s1


-- | Parallel union of two substitutions; fails if they disagree on a
-- variable bound by both.
merge      :: Monad m => Subst -> Subst -> m Subst
merge s1 s2
  | agree     = return (s1 ++ s2)
  | otherwise = fail "merge fails"
  where
  shared = map fst s1 `intersect` map fst s2
  agree  = and [ apply s1 (TVar v) == apply s2 (TVar v) | v <- shared ]


-----------------------------------------------------------------------------
-- Unify:	Unification
-----------------------------------------------------------------------------

-- | Most general unifier of two types, failing with a message in the monad.
mgu     :: Monad m => Type -> Type -> m Subst
-- | Bind a variable to a type, with occurs check and kind check.
varBind :: Monad m => Tyvar -> Type -> m Subst

mgu (TAp l r) (TAp l' r') = do s1 <- mgu l l'
                               s2 <- mgu (apply s1 r)
                                         (apply s1 r')
                               return (s2 @@ s1)
mgu (TVar u) t        = varBind u t
mgu t (TVar u)        = varBind u t
mgu (TCon tc1) (TCon tc2) | tc1==tc2 = return nullSubst
-- distinct generic variables are treated as an internal error, not a failure
mgu (TGen x) (TGen y) = if x /= y then error "TGen" else return nullSubst
mgu t1 t2             = fail "different type constructors cannot be unified"

varBind u t 
  | t == TVar u      = return nullSubst
  | u `elem` tv t    = fail "unification would lead to infinite type"
  | kind u == kind t = return (u +-> t)
  | otherwise        = fail "kinds of types do not agree"


-- | One-way matching: a substitution binding only variables of the first
-- type such that it becomes the second. Substitutions from the two sides of
-- an application must agree ('merge').
match :: Monad m => Type -> Type -> m Subst
match (TAp l r) (TAp l' r') = do sl <- match l l'
                                 sr <- match r r'
                                 merge sl sr
match (TVar u)   t | kind u == kind t = return (u +-> t)
match (TCon tc1) (TCon tc2)
         | tc1==tc2         = return nullSubst
match t1 t2                 = fail "types do not match"


-- | Merge several monomorphic environments into one that maps each
-- identifier (in first-occurrence order) to the list of all types it has
-- in the given environments.
multiAssumptionEnv :: [MonoEnv] -> Env [Type]
multiAssumptionEnv ass =
  [ i :>: catMaybes (map (find i) ass) | i <- nub (concatMap dom ass) ]
 

-- | Unify several qualified monomorphic environments. The contexts are
-- concatenated, the unifier is applied to them, and duplicates are removed.
unifyMonoEnvs :: [Qual MonoEnv] -> Error (Qual MonoEnv)
unifyMonoEnvs qEnvs = do
  (env, s) <- unifyMonoEnvs' [ e | _ :=> e <- qEnvs ]
  let context = nub (apply s (concat [ ps | ps :=> _ <- qEnvs ]))
  return (context :=> env)
  

-- | Unify several monomorphic environments: all types of the same
-- identifier are unified; returns the unified environment and the unifier.
unifyMonoEnvs' :: [MonoEnv] -> Error (MonoEnv,Subst)
unifyMonoEnvs' = unifyMultiAssumptionEnv . multiAssumptionEnv


-- | Unify the list of types of each identifier in turn. The substitution
-- from earlier identifiers is applied to the remaining ones, and the
-- composed unifier is returned.
unifyMultiAssumptionEnv :: Env [Type] -> Error (MonoEnv,Subst)    
unifyMultiAssumptionEnv entries = case entries of
  [] -> return ([], nullSubst)
  (i :>: scs) : rest -> do
    (sc, s) <- unifyTypes scs
    (rest', s') <- unifyMultiAssumptionEnv (apply s rest)
    return (apply s' (i :>: sc) : rest', s' @@ s)


-- | Unify several typings. The result types are unified first, then the
-- environments (under that substitution). Contexts are combined under both
-- unifiers, with duplicates removed.
unifyTypings :: [Qual (MonoEnv,Type)] -> Error (Qual (MonoEnv,Type))
unifyTypings qEnvTys = do
  (t, s) <- unifyTypes [ ty | _ :=> (_, ty) <- qEnvTys ]
  (monoEnv, s2) <- unifyMonoEnvs' (apply s [ env | _ :=> (env, _) <- qEnvTys ])
  let context = nub (apply (s2 @@ s) (concat [ ps | ps :=> _ <- qEnvTys ]))
  return (context :=> (monoEnv, apply s2 t))


-- | Unify two types, returning the unified type and the unifier.
unify2Types :: Type -> Type -> Error (Type,Subst)
unify2Types t1 t2 = mgu t1 t2 >>= \u -> return (apply u t1, u)

-- | Unify a non-empty list of types.
unifyTypes :: [Type] -> Error (Type,Subst)
unifyTypes ts = unifySeq unify2Types ts

-- | Fold a binary unification over a non-empty list from the left. Each
-- intermediate substitution is applied to the remaining elements, and the
-- composed substitution is returned.
unifySeq :: Types a => (a -> a -> Error (a,Subst))
                    -> [a] -> Error (a,Subst)
unifySeq unify xs = case xs of
  []  -> error "unifySeq of empty list"
  [x] -> return (x, nullSubst)
  x : y : rest -> do
    (z, s) <- unify x y
    (z', s') <- unifySeq unify (z : map (apply s) rest)
    return (z', s' @@ s)

-----------------------------------------------------------------------------
-- Pred:		Predicates
-----------------------------------------------------------------------------

-- | A value qualified by a list of class predicates, written @ps :=> t@.
data Qual t
  = [Pred] :=> t
  deriving (Eq, Show)

-- | A class predicate: @IsIn c t@ states that type @t@ is an instance of
-- the class named @c@.
data Pred
  = IsIn Id Type
  deriving (Eq, Show)

-- | A list of predicates.
type Context = [Pred]

-- | A substitution acts on both the predicates and the qualified value.
-- The free type variables are the union of both.
instance Types t => Types (Qual t) where
  apply s (preds :=> x) = (:=>) (apply s preds) (apply s x)
  tv (preds :=> x)      = union (tv preds) (tv x)

-- | Only the type argument of a predicate contains type variables.
-- The class name is left unchanged.
instance Types Pred where
  apply s (IsIn cls ty) = IsIn cls (apply s ty)
  tv (IsIn _ ty)        = tv ty

-- | Most general unifier and one-way matching of two predicates.  Both
-- call the monad's 'fail' when the class names differ.
mguPred, matchPred :: Monad m => Pred -> Pred -> m Subst
mguPred   = lift mgu
matchPred = lift match

-- | Lift a binary operation on types to two predicates of the same class.
lift :: Monad m => (Type -> Type -> m a) -> Pred -> Pred -> m a
lift onTypes (IsIn c1 t1) (IsIn c2 t2) =
  if c1 /= c2
    then fail "classes differ"
    else onTypes t1 t2

-- | Information about a class: the names of its direct superclasses and
-- its instances.
type Class    = ([Id], [Inst])  -- superclasses and instances
-- | An instance: a head predicate qualified by the predicates it requires.
type Inst     = Qual Pred

-----------------------------------------------------------------------------

-- | Class environment.
--   'classes' looks up a class by name; for undefined classes it yields a
--   'Wrong' result (see 'initialEnv').  'defaults' lists the candidate
--   types tried, in order, when resolving an ambiguous type variable (see
--   'candidates' and 'withDefaults').
data ClassEnv = ClassEnv { classes  :: Id -> Error Class,
                           defaults :: [Type] }

-- | Names of the direct superclasses of class @i@.
-- Calls 'error' naming the class if @i@ is not defined in @ce@.  The
-- original case had no 'Wrong' branch, so an undefined class gave an
-- uninformative pattern-match failure.
super     :: ClassEnv -> Id -> [Id]
super ce i = case classes ce i of
  Correct (is, _) -> is
  Wrong _         -> error ("super: class not defined: " ++ i)

-- | Instances declared for class @i@.
-- Calls 'error' naming the class if @i@ is not defined in @ce@.
insts     :: ClassEnv -> Id -> [Inst]
insts ce i = case classes ce i of
  Correct (_, its) -> its
  Wrong _          -> error ("insts: class not defined: " ++ i)

-- | 'True' exactly when the computation succeeded ('Correct').
correct :: Error a -> Bool
correct result = case result of
  Correct _ -> True
  Wrong _   -> False

-- | A class environment in which class @i@ is bound to @c@.  Lookups of
-- every other class go to the original environment.
modify       :: ClassEnv -> Id -> Class -> ClassEnv
modify ce i c = ce { classes = lookupClass }
  where
  lookupClass j
    | j == i    = return c
    | otherwise = classes ce j

-- | The empty class environment.  Every class lookup fails with
-- "class not defined", and the default types are Integer then Double.
initialEnv :: ClassEnv
initialEnv = ClassEnv
  { classes  = const (fail "class not defined")
  , defaults = [tInteger, tDouble]
  }

-- | An update of a class environment that may fail.
type EnvTransformer = ClassEnv -> Error ClassEnv

infixr 5 <:>
-- | Sequential composition of environment transformers: run the left one,
-- then feed its result to the right one.  How a failure of the left one
-- propagates is determined by the Error monad's '>>='.
(<:>) :: EnvTransformer -> EnvTransformer -> EnvTransformer
(f <:> g) ce = f ce >>= g

-- | Add class @i@ with direct superclasses @is@ and no instances.
-- Fails if @i@ is already defined or if any superclass is not defined.
addClass                              :: Id -> [Id] -> EnvTransformer
addClass i is ce =
  if correct (classes ce i)
    then fail "class already defined"
    else if all (correct . classes ce) is
      then return (modify ce i (is, []))
      else fail "superclass not defined"

-- | All standard classes used by the examples.
addPreludeClasses :: EnvTransformer
addPreludeClasses  = addCoreClasses <:> addNumClasses

-- | Standard non-numeric classes.  Each class is added after its
-- superclasses, because 'addClass' rejects undefined superclasses.
addCoreClasses ::   EnvTransformer
addCoreClasses = foldr1 (<:>)
  [ addClass "Eq"      []
  , addClass "Ord"     ["Eq"]
  , addClass "Show"    []
  , addClass "Read"    []
  , addClass "Bounded" []
  , addClass "Enum"    []
  , addClass "Functor" []
  , addClass "Monad"   []
  ]

-- | Standard numeric classes (requires the core classes to be present).
addNumClasses  ::   EnvTransformer
addNumClasses = foldr1 (<:>)
  [ addClass "Num"        ["Eq", "Show"]
  , addClass "Real"       ["Num", "Ord"]
  , addClass "Fractional" ["Num"]
  , addClass "Integral"   ["Real", "Enum"]
  , addClass "RealFrac"   ["Real", "Fractional"]
  , addClass "Floating"   ["Fractional"]
  , addClass "RealFloat"  ["RealFrac", "Floating"]
  ]

-- | Add the instance @ps :=> p@ to the class named in @p@.
-- Fails if that class is not defined, or if @p@ unifies with the head of
-- an existing instance of the same class.
addInst                        :: [Pred] -> Pred -> EnvTransformer
addInst ps p@(IsIn i _) ce
  | not classKnown        = fail "no class for instance"
  | any (overlap p) heads = fail "overlapping instance"
  | otherwise             = return (modify ce i (super ce i, newInst : existing))
  where
  classKnown = correct (classes ce i)
  existing   = insts ce i
  heads      = [ h | _ :=> h <- existing ]
  newInst    = ps :=> p

-- | Two predicates overlap when they have a unifier.
overlap       :: Pred -> Pred -> Bool
overlap p = correct . mguPred p

-- | The class environment used by the examples: the standard classes plus
-- instances Ord (), Ord Char, Ord Int, (Ord a, Ord b) => Ord (pair a b)
-- and Num Int.
exampleInsts ::  EnvTransformer
exampleInsts = foldr1 (<:>)
  [ addPreludeClasses
  , addInst [] (IsIn "Ord" tUnit)
  , addInst [] (IsIn "Ord" tChar)
  , addInst [] (IsIn "Ord" tInt)
  , addInst [IsIn "Ord" a, IsIn "Ord" b] (IsIn "Ord" (pair a b))
  , addInst [] (IsIn "Num" tInt)
  ]
  where
  a = TVar (Tyvar "a" Star)
  b = TVar (Tyvar "b" Star)

-----------------------------------------------------------------------------

-- | @p@ followed by every predicate it implies through the superclass
-- hierarchy, depth first.  Duplicates are not removed.
bySuper :: ClassEnv -> Pred -> [Pred]
bySuper ce p@(IsIn i t)
 = p : concatMap (\sup -> bySuper ce (IsIn sup t)) (super ce i)

-- | Try every instance of @p@'s class.  For an instance whose head matches
-- @p@, the result is its context instantiated by the matching
-- substitution.  The attempts are combined with 'msum' (with no instances
-- this gives 'mzero').
byInst                   :: MonadPlus m => ClassEnv -> Pred -> m [Pred]
byInst ce p@(IsIn i _) = msum (map tryInst (insts ce i))
 where
 tryInst (ps :=> h) = matchPred h p >>= \u -> return (map (apply u) ps)

-- | Do the predicates @ps@ entail @p@?  Either @p@ follows from some element
-- of @ps@ through superclasses, or an instance matches @p@ and every
-- predicate of its instantiated context is itself entailed by @ps@.
entail        :: ClassEnv -> [Pred] -> Pred -> Bool
entail ce ps p = bySuperclass || byInstance
  where
  bySuperclass = any (elem p . bySuper ce) ps
  byInstance   = case byInst ce p of
                   Correct qs -> all (entail ce ps) qs
                   Wrong _    -> False

-----------------------------------------------------------------------------

-- | Is the predicate in head-normal form, i.e. is its type a type variable,
-- possibly applied to arguments?
inHnf :: Pred -> Bool
inHnf (IsIn _ ty) = headIsVar ty
  where
  headIsVar t = case t of
    TVar _  -> True
    TCon _  -> False
    TAp f _ -> headIsVar f

-- toHnfs :: Monad m => ClassEnv -> [Pred] -> m [Pred]
-- Bring every predicate to head-normal form and collect the results.
toHnfs ce ps = mapM (toHnf ce) ps >>= return . concat

-- toHnf :: Monad m => ClassEnv -> Pred -> m [Pred]
-- A predicate already in HNF is kept.  Otherwise the context of a matching
-- instance is reduced recursively; with no matching instance, 'fail'.
toHnf ce p
  | inHnf p   = return [p]
  | otherwise = maybe (fail "context reduction") (toHnfs ce) (byInst ce p)

-- | Drop predicates that are entailed by the others.  Predicates are
-- examined left to right against the kept ones plus those not yet
-- examined.  Kept predicates come out in reverse order of examination.
simplify :: ClassEnv -> [Pred] -> [Pred]
simplify ce = go []
 where
 go kept []     = kept
 go kept (p:ps)
   | redundant  = go kept ps
   | otherwise  = go (p : kept) ps
   where redundant = entail ce (kept ++ ps) p

-- | Context reduction: bring the predicates to head-normal form (errors of
-- 'toHnfs' are converted by 'error2TI'), then 'simplify' them.
reduce      :: [Pred] -> TI String [Pred]
reduce ps = do
  ce   <- getClassEnv
  hnfs <- error2TI (toHnfs ce ps)
  return $ simplify ce hnfs

-- | Entailment using only the superclass hierarchy (instances are ignored).
scEntail        :: ClassEnv -> [Pred] -> Pred -> Bool
scEntail ce ps p = any (elem p . bySuper ce) ps


-----------------------------------------------------------------------------
-- Environments:
-----------------------------------------------------------------------------

-- | An assumption @i :>: t@ associates identifier @i@ with @t@.
data Assump t = Id :>: t deriving (Eq,Show)

-- | An environment: a list of assumptions; 'find' returns the first match.
type Env t = [Assump t]

-- | Environment assigning monomorphic types to identifiers.
type MonoEnv = Env Type

-- | Environment assigning a derivation tree to each polymorphic identifier;
-- the root of the tree carries the identifier's typing (see 'tiExpr').
type PolyEnv = Env Derivation


-- | A substitution acts on the assumed value; the identifier is unchanged.
instance Types t => Types (Assump t) where
  apply s (name :>: x) = name :>: apply s x
  tv (_ :>: x)         = tv x

-- | Identifiers bound in an environment, in order, duplicates kept.
dom :: Env t -> [Id]
dom = map (\(name :>: _) -> name)

-- | Remove every assumption whose identifier is in the given list.
without :: Env t -> [Id] -> Env t
without assumps ids = filter (\(name :>: _) -> name `notElem` ids) assumps

-- | Value of the first assumption for identifier @i@, if any.
find     :: Id -> Env t -> Maybe t
find i assumps = headMaybe [ x | name :>: x <- assumps, name == i ]

-- | Total version of 'head': 'Nothing' for the empty list.
headMaybe :: [a] -> Maybe a
headMaybe xs = case xs of
  []    -> Nothing
  x : _ -> Just x

-----------------------------------------------------------------------------
-- TIMonad:	Type inference monad
-----------------------------------------------------------------------------

-- changed for Hat
-- | Type-inference monad.  A computation reads the class environment and
-- threads a counter (used by 'newTVar' for fresh names).  It stops with
-- @Left@ carrying an error value of type @a@.
data TI a b = TI (ClassEnv -> Int -> Either a (Int, b))

instance Monad (TI a) where
  return x = TI (\_ n -> Right (n, x))
  TI m >>= k = TI step
    where
    step ce n = case m ce n of
      Left err      -> Left err
      Right (n', x) -> let TI next = k x in next ce n'

-- | Run an inference computation with the given class environment and a
-- counter starting at 0.  Both success and failure yield a derivation.
runTI      :: ClassEnv -> TI Derivation Derivation -> Derivation
runTI ce (TI c) = either id snd (c ce 0)


-- | A fresh type variable of kind @k@, named with 'enumId' of the current
-- counter; the counter is then incremented.
newTVar    :: Kind -> TI a Type
newTVar k   = TI fresh
  where
  fresh _ n = Right (n + 1, TVar (Tyvar (enumId n) k))

-- | The class environment the computation runs in.
getClassEnv :: TI a ClassEnv
getClassEnv = TI (\ce n -> Right (n, ce))

-- | Abort with a string error.
errorMsg :: String -> TI String a
errorMsg msg = TI (\_ _ -> Left msg)

-- | Abort with an error derivation.
errorDer :: Derivation -> TI Derivation a
errorDer d = TI (\_ _ -> Left d)

-- | Map the error value of a computation.  Successful results, including
-- the counter, pass through unchanged.
updateError :: (b -> c) -> TI b a -> TI c a
updateError f (TI g) = TI (\ce n -> either (Left . f) Right (g ce n))

-----------------------------------------------------------------------------
-- TIMain:	Type Inference Algorithm
-----------------------------------------------------------------------------
-----------------------------------------------------------------------------
-- Lit:		Literals
-----------------------------------------------------------------------------

-- | Literals: integers and characters (typed by 'tiLit').
data Literal = LitInt  Integer
             | LitChar Char
  deriving (Eq, Show)

-- | Typing of a literal.  A character has type Char.  An integer gets a
-- fresh type variable constrained by Num.  Both have an empty monomorphic
-- environment and no subderivations.
tiLit :: Literal -> TI Derivation Derivation
tiLit l = case l of
  LitChar _ -> leaf ([] :=> ([], tChar))
  LitInt _  -> do
    a <- newTVar Star
    leaf ([IsIn "Num" a] :=> ([], a))
  where
  leaf typing = return (Tree (TyExpr (Lit l) typing) [])


-----------------------------------------------------------------------------

-- | Expressions.  Left-hand sides of equations (patterns) are also
-- represented as expressions (see 'vars' and 'defVars').
data Expr = Var   Id              -- variable
          | Lit   Literal         -- literal
          | Const Id              -- data constructor
          | Ap    Expr Expr       -- application
          | Let   BindGroup Expr  -- local binding group
  deriving (Eq, Show)

-- `Ap` is left-associative; no precedence is given, so it defaults to 9.
infixl `Ap`


-- | Replace every free type variable of a typing with a fresh one of the
-- same kind.  Used to instantiate typings taken from the polymorphic
-- environment (see 'tiExpr').
substNewTyVars :: Typing -> TI a Typing
substNewTyVars typing = do
  let oldVars = tv typing
  newVars <- mapM (newTVar . kind) oldVars
  return (apply (zip oldVars newVars) typing)
      
  

-- | Infer a derivation for an expression in a polymorphic environment.
--
-- * @Var@: if the variable is bound in @polyEnv@, its typing is instantiated
--   with fresh type variables and the binding's subderivations are kept.
--   Otherwise a fresh type variable is recorded for it in the monomorphic
--   environment.
-- * @Const@: a data constructor must be bound in @polyEnv@ (otherwise
--   'error'); its typing is instantiated and there are no subderivations.
-- * @Lit@: see 'tiLit'.
-- * @Ap@: the function's typing is unified with the argument's typing whose
--   type is turned into @ty2 -> t@ for a fresh @t@; the result type is the
--   result part of the unified arrow type.
-- * @Let@: the binding group is typed first; its defined variables extend
--   the polymorphic environment for the body; the group's used monomorphic
--   environment is unified with the body's typing.
--
-- A unification failure aborts with a TyUExpr error derivation containing
-- the subderivations.
tiExpr :: PolyEnv -> Expr -> TI Derivation Derivation

tiExpr polyEnv e@(Var i)
 = case find i polyEnv of
     Just (Tree typing trees) -> do
       TyPolyVar predMonoEnvTy <- substNewTyVars typing
       return $ Tree (TyExpr e predMonoEnvTy) trees
     Nothing -> do
       t <- newTVar Star
       return $ Tree (TyExpr e ([] :=> ([i :>: t],t))) []

tiExpr polyEnv e@(Const i)
 = case find i polyEnv of
     Just (Tree typing _) -> do
       TyPolyVar predMonoEnvTy <- substNewTyVars typing
       return $ Tree (TyExpr e predMonoEnvTy) []
     Nothing -> error ("undefined data constructor " ++ i)

tiExpr _ (Lit l)
 = tiLit l

tiExpr polyEnv e@(Ap e1 e2) = do
  d1@(Tree (TyExpr _ qty1) _) <- tiExpr polyEnv e1
  d2@(Tree (TyExpr _ (pred2 :=> (monoEnv2,ty2))) _) <- tiExpr polyEnv e2
  t <- newTVar Star
  case unifyTypings [qty1,pred2 :=> (monoEnv2,ty2 `fn` t)] of
    -- The unified type is an arrow (built with `fn` above); only its result
    -- part is needed.  The arrow constructor is matched with a wildcard:
    -- naming it `tArrow` would bind a fresh variable shadowing the global
    -- 'tArrow' rather than compare against it.
    Correct (preds :=> (monoEnv,TAp (TAp _ _) tyRes)) ->
      return $ Tree (TyExpr e (preds :=> (monoEnv,tyRes))) [d1,d2]
    Wrong msg -> errorDer $ Tree (TyUExpr msg e) [d1,d2]

tiExpr polyEnv e@(Let bg e1) = do
  d1@(Tree (TyBindGroup _ (predBg :=> (usedEnvBg,defEnvBg))) _)
    <- tiBindGroup polyEnv bg
  d2@(Tree (TyExpr _ qtyE1@(_ :=> (_,tyE1))) _)
    <- tiExpr (defEnvBg ++ polyEnv) e1
  case unifyTypings [predBg :=> (usedEnvBg,tyE1),qtyE1] of
    Correct qEnvTy ->
      return $ Tree (TyExpr e qEnvTy) [d1,d2]
    Wrong msg -> errorDer $ Tree (TyUExpr msg e) [d1,d2]

-----------------------------------------------------------------------------

-- | An alternative (equation): left-hand side and right-hand side.
type Alt = (Expr, Expr)

-- | Type one equation.  The variables bound by the left-hand side pattern
-- (all its variables except the defined ones, see 'defVars') are removed
-- from the polymorphic environment.  Both sides are then typed and their
-- typings unified.  In the result, the pattern-bound variables are removed
-- from the qualified monomorphic environment.
tiAlt :: PolyEnv -> Alt -> TI Derivation Derivation
tiAlt polyEnv alt@(lhs,rhs) = do
  dLhs@(Tree (TyExpr _ qLhs) _) <- tiExpr innerEnv lhs
  dRhs@(Tree (TyExpr _ qRhs) _) <- tiExpr innerEnv rhs
  let subs = [dLhs, dRhs]
  case unifyTypings [qLhs, qRhs] of
    Wrong msg -> errorDer (Tree (TyUAlt msg alt) subs)
    Correct (preds :=> (usedEnv, _)) ->
      return (Tree (TyAlt alt (preds :=> (usedEnv `without` patVars))) subs)
  where
  patVars  = vars lhs \\ defVars lhs
  innerEnv = polyEnv `without` patVars



-- | A definition: its list of equations.
type Def = [Alt]

-- | Type every equation of a definition and unify the qualified
-- monomorphic environments they produce.
tiDef :: PolyEnv -> Def -> TI Derivation Derivation
tiDef polyEnv alts = do
  ds <- mapM (tiAlt polyEnv) alts
  case unifyMonoEnvs (map altQEnv ds) of
    Correct (preds :=> usedEnv) ->
      return $ Tree (TyDef alts (preds :=> usedEnv)) ds
    Wrong msg -> errorDer $ Tree (TyUDef msg alts) ds
  where
  altQEnv (Tree (TyAlt _ qEnv) _) = qEnv



-- | All variables of a pattern, left to right.
-- Calls 'error' on a @Let@, which cannot appear in a pattern.
vars :: Expr {-Pat-} -> [Id]
vars expr = case expr of
  Var i    -> [i]
  Lit _    -> []
  Const _  -> []
  Ap e1 e2 -> vars e1 ++ vars e2
  Let _ _  -> error "vars: pattern"
 

-- | Things from which the identifiers defined by a binding can be read off.
class Define a where
  defVars :: a -> [Id]

-- | For a left-hand side: @f p1 .. pn@ defines @f@.  An application headed
-- by a constructor is a pattern binding; it defines what its components
-- define.
instance Define Expr {- lhs of function def -} where
  defVars expr = case expr of
    Var f -> [f]
    Ap f g
      | isConstAp f -> defVars f ++ defVars g
      | otherwise   -> defVars f
    Const _ -> []
    _       -> error "defVars: not a pattern"

-- | Is the expression a constructor, possibly applied to arguments?
isConstAp :: Expr -> Bool
isConstAp expr = case expr of
  Const _ -> True
  Ap f _  -> isConstAp f
  _       -> False

-- | An alternative defines what its left-hand side defines.
instance Define a => Define (a,b) where
  defVars = defVars . fst

-- | A list defines everything its elements define.
instance Define a => Define [a] where
  defVars xs = concat (map defVars xs)


-----------------------------------------------------------------------------
-- Defaulting:
-----------------------------------------------------------------------------

-- | Reduce the predicates, then split them into deferred ones (all type
-- variables in @fs@) and retained ones.  Retained predicates removed by
-- defaulting with respect to @fs ++ gs@ are dropped; a defaulting failure
-- is converted by 'error2TI'.
split :: [Tyvar] -> [Tyvar] -> [Pred] -> TI String ([Pred], [Pred])
split fs gs ps = do
  ce      <- getClassEnv
  reduced <- reduce ps
  let onlyFixed = all (`elem` fs) . tv
      (deferred, retained) = partition onlyFixed reduced
  defaulted <- error2TI (defaultedPreds ce (fs ++ gs) retained)
  return (deferred, retained \\ defaulted)

-- | An ambiguous type variable with the predicates that mention it.
type Ambiguity       = (Tyvar, [Pred])

-- | Every type variable of @ps@ not in @vs@, paired with the predicates of
-- @ps@ that mention it.
ambiguities         :: ClassEnv -> [Tyvar] -> [Pred] -> [Ambiguity]
ambiguities _ vs ps = map withPreds (tv ps \\ vs)
  where
  withPreds v = (v, [ p | p <- ps, v `elem` tv p ])

-- | Numeric standard classes.
numClasses :: [Id]
numClasses = [ "Num", "Integral", "Floating", "Fractional"
             , "Real", "RealFloat", "RealFrac" ]

-- | Standard classes; defaulting only applies when all classes involved are
-- among these (see 'candidates').
stdClasses :: [Id]
stdClasses = [ "Eq", "Ord", "Show", "Read", "Bounded", "Enum", "Ix"
             , "Functor", "Monad", "MonadPlus" ] ++ numClasses

-- | Default types for an ambiguity.  Defaulting is possible only if every
-- predicate constrains the variable itself, at least one class is numeric
-- and all classes are standard.  The candidates are the default types of
-- the class environment, in order, that satisfy all these classes.
candidates           :: ClassEnv -> Ambiguity -> [Type]
candidates ce (v, qs)
  | defaultable = [ t' | t' <- defaults ce
                       , all (entail ce []) [ IsIn i t' | i <- classNames ] ]
  | otherwise   = []
  where
  classNames  = [ i | IsIn i _ <- qs ]
  argTypes    = [ t | IsIn _ t <- qs ]
  defaultable = all (TVar v ==) argTypes
             && any (`elem` numClasses) classNames
             && all (`elem` stdClasses) classNames

-- | Compute the ambiguities of @ps@ with respect to @vs@ and their default
-- candidates.  Fails if some ambiguous variable has no candidate.
-- Otherwise passes the ambiguities and the first candidate of each to @f@.
withDefaults :: Monad m => ([Ambiguity] -> [Type] -> a)
                  -> ClassEnv -> [Tyvar] -> [Pred] -> m a
withDefaults f ce vs ps =
  if any null candidateLists
    then fail "cannot resolve ambiguity"
    else return (f ambs (map head candidateLists))
  where
  ambs           = ambiguities ce vs ps
  candidateLists = map (candidates ce) ambs

-- | The predicates mentioning an ambiguous type variable (those removed by
-- defaulting).
defaultedPreds :: Monad m => ClassEnv -> [Tyvar] -> [Pred] -> m [Pred]
defaultedPreds = withDefaults (\ambs _ -> concatMap snd ambs)

-- | Apply the defaulting substitution for the ambiguous type variables of
-- @ps@ (those not in @vs@) to the environment.
-- 'withDefaults' runs in the Error monad and its result is converted with
-- 'error2TI', exactly as 'split' does.  The TI monad defines no 'fail', so
-- running 'withDefaults' directly in TI would turn "cannot resolve
-- ambiguity" into an exception instead of a recoverable error that callers
-- (e.g. 'updateError' in 'tiProgram') can turn into an error derivation.
defaultSubst   :: [Tyvar] -> [Pred] -> PolyEnv -> TI String PolyEnv
defaultSubst vs ps polyEnv = do
  ce <- getClassEnv
  s <- error2TI $ withDefaults (\vps ts -> zip (map fst vps) ts) ce vs ps
  return $ apply s polyEnv

-----------------------------------------------------------------------------
-- BindGroup
-----------------------------------------------------------------------------

-- | A binding group: definitions typed together by 'tiBindGroup'; within
-- the group the defined identifiers are not taken from the polymorphic
-- environment.
newtype BindGroup = BG [Def] deriving (Eq, Show)


-- | Type a binding group.
--
-- The identifiers defined in the group are removed from the polymorphic
-- environment, so uses inside the group are typed monomorphically.  Every
-- definition is typed and the resulting monomorphic environments are
-- unified.  The unified environment is partitioned into the assumptions
-- for the defined identifiers (@defsEnv@) and for the other identifiers
-- used (@usedEnv@).  The predicates are then passed to 'split' with the
-- type variables of @usedEnv@ and the type variables common to all defined
-- identifiers.  For a 'restricted' group all predicates are deferred;
-- otherwise the retained ones qualify the typings of the defined
-- identifiers.
tiBindGroup :: PolyEnv -> BindGroup -> TI Derivation Derivation
tiBindGroup polyEnv bg@(BG defs) = do
  let boundVars = nub . concatMap defVars $ defs
  derivations <- mapM (tiDef (polyEnv `without` boundVars)) defs
  let qEnvs = map getQEnv derivations
  case unifyMonoEnvs qEnvs of
    Wrong msg -> errorDer $ Tree (TyUBindGroup msg bg) derivations
    Correct (defsPreds :=> occEnv) -> do
      let (defsEnv,usedEnv) =
            partition (\(x :>: _) -> x `elem` boundVars) occEnv
          usedEnvTyVars = tv usedEnv
          defsTyVars = [tv ty | (_ :>: ty) <- defsEnv]
      (ds,rs) <- updateError (\s -> (Tree (TyUBindGroup s bg) derivations))
                  (split usedEnvTyVars (intersectAll defsTyVars) defsPreds)
      let (deferredPreds,retainedPreds)
             | restricted bg = (ds++rs,[])
             | otherwise     = (ds,rs)

      return $
        Tree (TyBindGroup bg (deferredPreds :=>
               (usedEnv
               ,[ f :>: Tree (TyPolyVar
                               (retainedPreds :=>
                                 (usedEnv,ty)))
                          derivations
                | f :>: ty <- defsEnv ])))
          derivations
  where
  getQEnv (Tree (TyDef _ qEnv) _) = qEnv
  -- Type variables shared by all lists.  Yields [] when there are no lists
  -- (e.g. an empty binding group), where 'foldr1' would crash.
  intersectAll []   = []
  intersectAll tvss = foldr1 intersect tvss
  
-- | A binding group is restricted if some definition's first equation has a
-- bare variable as its left-hand side (a pattern binding without
-- arguments).  A definition without equations does not count as simple,
-- instead of crashing on 'head'.
restricted   :: BindGroup -> Bool
restricted (BG defs) = any simple defs
  where
  simple ((Var _, _) : _) = True
  simple _                = False


reduceMonoEnv :: [Tyvar] -> MonoEnv -> MonoEnv
-- Remove unnecessary variables from a monomorphic environment.
-- Not used to keep simple let-rule.
-- Keeps exactly the assumptions whose type mentions at least one of the
-- given type variables.
reduceMonoEnv ids monoEnv = [ a | a <- monoEnv, any (`elem` tv a) ids ]

-- | Type a sequence of binding groups from left to right.  The polymorphic
-- environment defined by each group is put in front of the environment for
-- the following groups.  Returns all derivations in order.
tiSeq :: (PolyEnv -> term -> TI a Derivation)
      -> PolyEnv -> [term] -> TI a [Derivation]
tiSeq _ _ [] = return []
tiSeq ti polyEnv (bs:bss) = do
  d@(Tree (TyBindGroup _ (_ :=> (_, defined))) _) <- ti polyEnv bs
  ds <- tiSeq ti (defined ++ polyEnv) bss
  return (d : ds)

-----------------------------------------------------------------------------
-- TIProg:	Type Inference for Whole Programs
-----------------------------------------------------------------------------

-- | A program: a sequence of binding groups.
type Program = [BindGroup]

-- | Type a whole program in the class environment built by 'exampleInsts'.
-- After all binding groups are typed, the remaining predicates are reduced
-- and defaulted; a failure of either step yields a TyUProgram error
-- derivation.  Calls 'error' if a binding group still uses monomorphic
-- variables that are not defined.
tiProgram :: PolyEnv -> Program -> Derivation
tiProgram polyEnv bgs = runTI classEnv $ do
  derivations <- tiSeq tiBindGroup polyEnv bgs
  let (predss, defEnvs) = unzip (map getPredsDefEnv derivations)
      programError s = Tree (TyUProgram s bgs) derivations
  remainingPreds <- updateError programError (reduce (concat predss))
  defEnv' <- updateError programError
               (defaultSubst [] remainingPreds (concat defEnvs))
  return $ Tree (TyProgram bgs defEnv') derivations
  where
  classEnv = case exampleInsts initialEnv of Correct ce -> ce
  getPredsDefEnv (Tree (TyBindGroup _ (preds :=> (monoEnv,defEnv))) _)
    | null monoEnv = (preds, defEnv)
    | otherwise    = error ("Undefined monomorphic variables: " ++ show monoEnv)


-----------------------------------------------------------------------------
-- Main:
-----------------------------------------------------------------------------

-- | Interactive menu for one program: free navigation through the
-- derivation graph, algorithmic debugging, or quit.  Repeats until "q".
inter :: Program -> IO ()
inter prog = do
  mapM_ putStr
    [ "\n\n1 Free navigation through the graph"
    , "\n2 Algorithmic debugging"
    , "\nq Quit"
    , "\nSelect (1 or 2 or q): "
    ]
  choice <- getLine
  case choice of
    "1" -> do
      showHelp
      walk True True False [] [] (tiProgram env prog, Empty)
      again
    "2" -> do
      showHelp
      let tree = tiProgram env prog
      printStep False False True (tree, Empty)
      walk False False True [] [] (algNo True [] (tree, Empty))
      again
    "q" -> return ()
    _   -> again
  where
  again    = inter prog
  showHelp = putStrLn "\nPress ? for help"

-- | Entry point: list the example programs, let the user pick one by
-- number and start the interactive menu for it.  Only a single digit
-- between 1 and the number of programs is accepted; any other input
-- (including multi-character input such as "12", which was previously read
-- as its first digit) shows the prompt again.
main :: IO ()
main = do
  putStr "\n\nTypeIlluminator Version 13.09.01\nwritten by Olaf Chitil\n\n"
  putStrLn . concatMap (\(n,p) -> "\n\n" ++ show n ++ ppProgram p) .
    zip [(1::Int)..] $ progs
  putStr ("\nPlease choose an example program (1-" ++ show numProgs ++ ") ")
  c <- getLine
  case c of
    [d] | d >= '1' && d <= '9' && digitToInt d <= numProgs
      -> inter (progs !! (digitToInt d - 1))
    _ -> main
  where
  numProgs = length progs
  progs = [progStart,prog2,prog3,progPerms,progMkFieldList
          ,progClass1,progClass2,progLarge]

-----------------------------------------------------------------------------
-- Tests:
-----------------------------------------------------------------------------

-- Type variables of kind * used in the types of the example environment.
tyVarA, tyVarB, tyVarC :: Type
tyVarA = TVar (Tyvar "a" Star)
tyVarB = TVar (Tyvar "b" Star)
tyVarC = TVar (Tyvar "c" Star)

-- Additional type constructors used by the examples.
tIO, tBool, tExp, tScope, tName, tSrcSpan, tModuleConsts :: Type
tIO           = TCon (Tycon "IO" (Kfun Star Star))
tBool         = TCon (Tycon "Bool" Star)
tExp          = TCon (Tycon "Exp" (Kfun Star Star))
tScope        = TCon (Tycon "Scope" Star)
tName         = TCon (Tycon "Name" (Kfun Star Star))
tSrcSpan      = TCon (Tycon "SrcSpan" Star)
tModuleConsts = TCon (Tycon "ModuleConsts" (Kfun Star Star))

-- Data constructors used in the example programs.
consNil, consCons, consTuple2, consGlobal :: Expr
consNil    = Const "[]"
consCons   = Const ":"
consTuple2 = Const "(,)"
consGlobal = Const "Global"

global :: Type -> Tree Typing
-- Construct typing tree for polymorphic variable for which no definition
-- is available; hence ignored on abstract level and by algorithmic debugger.
global = globalPreds []

-- Like 'global', but the type is qualified by the given predicates.
globalPreds :: [Pred] -> Type -> Tree Typing
globalPreds preds t = Tree (TyPolyVar (preds :=> ([],t))) []

bogus :: Type -> Tree Typing
-- Construct typing tree for polymorphic variable,
-- pretending there is some definition.
-- Only purpose to see this polymorphic variable on abstract level and
-- in algorithmic debugger.
bogus t = Tree (TyPolyVar typing) [Tree (TyExpr (Var "undef") typing) []]
  where
  typing = [] :=> ([],t)


-- | Polymorphic environment of predefined identifiers for the examples.
-- Each identifier appears once: 'find' only ever returns the first entry,
-- so the former repeated entries for "[]" and ":" (with identical types)
-- were dead duplicates and have been removed.
env :: PolyEnv
env =
  ["[]" :>: global (TAp tList tyVarA)
  ,":" :>: global (tyVarA `fn` (TAp tList tyVarA) `fn`  (TAp tList tyVarA))
  ,"False" :>: global tBool
  ,"True" :>: global tBool
  ,"()" :>: global tUnit
  ,"(,)" :>: global (tyVarA `fn` tyVarB `fn` (TAp (TAp tTuple2 tyVarA) tyVarB))
  ,"fst" :>: global ((TAp (TAp tTuple2 tyVarA) tyVarB) `fn` tyVarA)
  ,"snd" :>: global ((TAp (TAp tTuple2 tyVarA) tyVarB) `fn` tyVarB)
  ,"." :>: global ((tyVarA `fn` tyVarB) `fn` (tyVarC `fn` tyVarA) `fn`
                     (tyVarC `fn` tyVarB))
  ,"$" :>: global ((tyVarA `fn` tyVarB) `fn` tyVarA `fn` tyVarB)
  ,"flip" :>: global ((tyVarA `fn` tyVarB `fn` tyVarC) `fn`
                       tyVarB `fn` tyVarA `fn` tyVarC)
  ,"print" :>: globalPreds [IsIn "Show" tyVarA] (tyVarA `fn` TAp tIO tUnit)
  ,"putStr" :>: global (tString `fn` TAp tIO tUnit)
  ,"show" :>: globalPreds [IsIn "Show" tyVarA] (tyVarA `fn` tString)
  ,"div" :>: globalPreds [IsIn "Integral" tyVarA]
               (tyVarA `fn` tyVarA `fn` tyVarA)
  ,"map" :>: global ((tyVarA `fn` tyVarB) `fn` TAp tList tyVarA `fn` TAp tList tyVarB)
  ,"++" :>: global (TAp tList tyVarA `fn` TAp tList tyVarA `fn` TAp tList tyVarA)
  ,"toUpper" :>: global (tChar `fn` tChar)
  ,"+" :>: globalPreds [IsIn "Num" tyVarA] (tyVarA `fn` tyVarA `fn` tyVarA)
  ,"zeroInt" :>: global tInt
  ,"concat" :>: global (TAp tList (TAp tList tyVarA) `fn` TAp tList tyVarA)
  ,"concatMap" :>: global ((tyVarA `fn` TAp tList tyVarB) `fn` TAp tList tyVarA `fn` TAp tList tyVarB)
  ,"Global" :>: global tScope
  ,"mkExpList" :>: bogus (TAp tList (TAp tExp tSrcSpan) `fn` TAp tExp tSrcSpan)
  ,"name2Var" :>: bogus (TAp tName tyVarA `fn` TAp tExp tyVarA)
  ,"nameTraceInfoVar" :>: bogus (tScope `fn` TAp tName tyVarA `fn` TAp tName tyVarA)
  ,"getModuleConsts" :>: bogus (TAp tModuleConsts tyVarA `fn` TAp (TAp tTuple2 (TAp tName tyVarA)) (TAp tList (TAp (TAp tTuple2 (TAp tName tyVarA)) (TAp tList (TAp tName tyVarA)))))
  ]


-- | @(#) f x = f x@.
defHash :: Def
defHash = [(Var "#" `Ap` Var "f" `Ap` Var "x", Var "f" `Ap` Var "x")]

-- | @copy x = x@.
defCopy :: Def
defCopy = [(Var "copy" `Ap` Var "x", Var "x")]

-- | @foo = (#) (copyN 27) ((+) 1) zeroInt@, where @copyN 1 = copy@ and
-- @copyN n = (#) (copyN (n-1)) copy@.
defFoo :: Def
defFoo = [(Var "foo", body)]
  where
  body = Var "#" `Ap` copyN 27 `Ap` (Var "+" `Ap` Lit (LitInt 1)) `Ap` Var "zeroInt"
  copyN :: Int -> Expr
  copyN n
    | n == 1    = Var "copy"
    | otherwise = Var "#" `Ap` copyN (n-1) `Ap` Var "copy"

-- | Example with a deeply nested expression.
progLarge :: Program
progLarge = [BG [defHash], BG [defCopy], BG [defFoo]]


-- | @start xs ys = ((.) (map toUpper) (++)) xs ys@.
defStart :: Def
defStart =
  [ ( Var "start" `Ap` Var "xs" `Ap` Var "ys"
    , Var "." `Ap` (Var "map" `Ap` Var "toUpper") `Ap` Var "++"
        `Ap` Var "xs" `Ap` Var "ys" ) ]

progStart :: Program
progStart = [BG [defStart]]


-- | @reverse [] = []; reverse (x:xs) = reverse xs ++ x@.
-- NOTE(review): @x@ is an element, not a list, so this is ill-typed —
-- presumably an intentional example for the type-error debugger.
defReverse :: Def
defReverse =
  [ (Var "reverse" `Ap` consNil, consNil)
  , ( Var "reverse" `Ap` (consCons `Ap` Var "x" `Ap` Var "xs")
    , Var "++" `Ap` (Var "reverse" `Ap` Var "xs") `Ap` Var "x" ) ]

-- | @last xs = head (reverse xs)@.
defLast :: Def
defLast = [(Var "last" `Ap` Var "xs", Var "head" `Ap` (Var "reverse" `Ap` Var "xs"))]

-- | @init xs = reverse (tail (reverse xs))@.
defInit :: Def
defInit =
  [ ( Var "init" `Ap` Var "xs"
    , Var "reverse" `Ap` (Var "tail" `Ap` (Var "reverse" `Ap` Var "xs")) ) ]

-- | @rotateR xs = last xs : init xs@.
defRotateR :: Def
defRotateR =
  [ ( Var "rotateR" `Ap` Var "xs"
    , consCons `Ap` (Var "last" `Ap` Var "xs") `Ap` (Var "init" `Ap` Var "xs") ) ]

-- | @head (x:xs) = x@.
defHead :: Def
defHead = [(Var "head" `Ap` (consCons `Ap` Var "x" `Ap` Var "xs"), Var "x")]

-- | @tail (x:xs) = xs@.
defTail :: Def
defTail = [(Var "tail" `Ap` (consCons `Ap` Var "x" `Ap` Var "xs"), Var "xs")]

-- | List functions, one binding group each, in dependency order.
prog2 :: Program
prog2 = map (\d -> BG [d]) [defHead, defTail, defReverse, defLast, defInit, defRotateR]

-- | Just 'defReverse'.
prog3 :: Program
prog3 = [BG [defReverse]]


-- | @class1 = ((.) print div) 42@.
defClass1 :: Def
defClass1 =
  [ ( Var "class1"
    , Var "." `Ap` Var "print" `Ap` Var "div" `Ap` Lit (LitInt 42) ) ]

progClass1 :: Program
progClass1 = [BG [defClass1]]

-- | @class2 = ((.) ((.) putStr show) (div 42)) 2@.
defClass2 :: Def
defClass2 =
  [ ( Var "class2"
    , Var "." `Ap` (Var "." `Ap` Var "putStr" `Ap` Var "show")
        `Ap` (Var "div" `Ap` Lit (LitInt 42))
        `Ap` Lit (LitInt 2) ) ]

progClass2 :: Program
progClass2 = [BG [defClass2]]

-- | Local helper of 'defPerms'; @x@ is free here (bound by the enclosing
-- equation of perms):
--   addX []     = [[x]]
--   addX (y:ys) = (x:y:ys) : map ((++) y) (addX ys)
-- NOTE(review): @(++) y@ uses the element @y@ as a list; presumably a
-- deliberate type error for the debugger — confirm.
defAddX :: Def
defAddX = [(Var "addX" `Ap` consNil
           ,consCons `Ap` (consCons `Ap` Var "x" `Ap` consNil) `Ap` consNil)
          ,(Var "addX" `Ap` (consCons `Ap` Var "y" `Ap` Var "ys")
           ,consCons `Ap` (consCons `Ap` Var "x" `Ap` (consCons `Ap` Var "y" `Ap` Var "ys")) `Ap` (Var "map" `Ap` (Var "++" `Ap` Var "y") `Ap` (Var "addX" `Ap` Var "ys")))]

-- | perms []     = [[]]
--   perms (x:xs) = let addX ... in concat (map addX (perms xs))
defPerms :: Def
defPerms = [(Var "perms" `Ap` consNil
            ,consCons `Ap` consNil `Ap` consNil)
           ,(Var "perms" `Ap` (consCons `Ap` Var "x" `Ap` Var "xs")
            ,Let (BG [defAddX]) (Var "concat" `Ap` (Var "map" `Ap` Var "addX" `Ap` (Var "perms" `Ap` Var "xs"))))]

-- | Permutations example.
progPerms :: Program
progPerms = [BG [defPerms]]

-- | mkFieldList consts =
--     let (_, conss) = getModuleConsts consts
--     in  (flip (:) [] . mkExpList . map (name2Var . nameTraceInfoVar Global)
--            . concatMap snd) $ conss
-- Uses identifiers that 'env' provides as 'bogus' entries.  Note that
-- "_" is treated as an ordinary variable here.
defMkFieldList :: Def
defMkFieldList = [(Var "mkFieldList" `Ap` Var "consts"
                  ,Let (BG [[(consTuple2 `Ap` Var "_" `Ap` Var "conss"
                           ,Var "getModuleConsts" `Ap` Var "consts")]])
                       (Var "$" `Ap`
                        (Var "." `Ap` (Var "flip" `Ap` consCons `Ap` consNil) 
                         `Ap` (Var "." `Ap` Var "mkExpList"
                         `Ap` (Var "." `Ap` 
                                (Var "map" `Ap` (Var "." `Ap` Var "name2Var" 
                                `Ap` (Var "nameTraceInfoVar" `Ap` consGlobal)))
                         `Ap` (Var "concatMap" `Ap` Var "snd"))))
                                `Ap` Var "conss"
                        ))]

-- | Program consisting of 'defMkFieldList' only.
progMkFieldList :: Program
progMkFieldList = [BG [defMkFieldList]]


-----------------------------------------------------------------------------



