diff --git a/bang.cabal b/bang.cabal index b277228..f40a984 100644 --- a/bang.cabal +++ b/bang.cabal @@ -15,16 +15,16 @@ cabal-version: >= 1.10 executable bang main-is: Main.hs build-depends: - array, - base, - bytestring, - containers, - lens, - llvm-pretty, - monadLib, - optparse-applicative, - pretty, - text + array >= 0.5.1.1 && < 0.9, + base >= 4.7 && < 5.0, + bytestring >= 0.10.6 && < 0.13, + containers >= 0.5.4 && < 0.8, + lens >= 4.14 && < 4.16, + llvm-pretty >= 0.4.0.1 && < 0.8, + monadLib >= 3.7.3 && < 3.9, + optparse-applicative >= 0.12.1.0 && < 0.15, + pretty >= 1.1.3.3 && < 1.4, + text >= 1.2.2.1 && < 1.5 hs-source-dirs: src build-tools: alex, happy ghc-options: -Wall diff --git a/src/Bang/Syntax/AST.hs b/src/Bang/Syntax/AST.hs index 4379785..bc88eb3 100644 --- a/src/Bang/Syntax/AST.hs +++ b/src/Bang/Syntax/AST.hs @@ -14,6 +14,17 @@ instance Eq Name where (Name _ _ x _) == (Name _ _ y _) = x == y (Name _ _ x _) /= (Name _ _ y _) = x /= y +instance Ord Name where + compare (Name _ _ x _) (Name _ _ y _) = compare x y + -- + max n1@(Name _ _ x _) n2@(Name _ _ y _) = if x > y then n1 else n2 + min n1@(Name _ _ x _) n2@(Name _ _ y _) = if x > y then n2 else n1 + -- + (Name _ _ x _) < (Name _ _ y _) = x < y + (Name _ _ x _) <= (Name _ _ y _) = x <= y + (Name _ _ x _) >= (Name _ _ y _) = x >= y + (Name _ _ x _) > (Name _ _ y _) = x > y + data Module = Module Name [Declaration] deriving (Show) @@ -41,6 +52,14 @@ data Type = TypeUnit Location Kind | TypeForAll [Name] Type deriving (Show) +instance Eq Type where + (TypeUnit _ _) == (TypeUnit _ _) = True + (TypePrim _ _ a) == (TypePrim _ _ b) = a == b + (TypeRef _ _ n) == (TypeRef _ _ m) = n == m + (TypeLambda _ _ at et) == (TypeLambda _ _ bt ft) = (at == bt) && (et == ft) + (TypeApp _ _ at bt) == (TypeApp _ _ ct dt) = (at == ct) && (bt == dt) + (TypeForAll ns t) == (TypeForAll ms u) = (ns == ms) && (t == u) + kind :: Type -> Kind kind (TypeUnit _ k) = k kind (TypePrim _ k _) = k @@ -51,4 +70,4 @@ kind (TypeForAll _ t) = kind t data Kind = Star | KindArrow Kind Kind - deriving (Show) + deriving (Show, Eq) diff --git a/src/Bang/TypeInfer.hs b/src/Bang/TypeInfer.hs index 5f368c8..4ab7e5a 100644 --- a/src/Bang/TypeInfer.hs +++ b/src/Bang/TypeInfer.hs @@ -1,77 +1,137 @@ -{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TemplateHaskell #-} module Bang.TypeInfer(typeInfer) where -import Bang.Syntax.AST -import Data.List(union, nub, concat, intersect) +import Bang.Syntax.AST +import Bang.Syntax.Location(unknownLocation) +import Control.Lens(view, over) +import Control.Lens.TH(makeLenses) +import Data.List(union, nub, concat, intersect) +import Data.Map.Strict(Map) +import qualified Data.Map.Strict as Map +import Data.Text.Lazy(pack) +import MonadLib(StateT, ExceptionT, Id, + StateM(..), ExceptionM(..), RunExceptionM(..), + runStateT, runExceptionT, runId, + get, raise) -type Subst = [(Name, Type)] +-- ----------------------------------------------------------------------------- -nullSubst :: Subst -nullSubst = [] +type Substitution = Map Name Type -(⟼) :: Name -> Type -> Subst -(⟼) n t = [(n,t)] +class Types t where + apply :: Substitution -> t -> t + tv :: t -> [Name] + +nullSubstitution :: Substitution +nullSubstitution = Map.empty + +(⟼) :: Name -> Type -> Substitution +(⟼) = Map.singleton infixr 4 @@ -(@@) :: Subst -> Subst -> Subst -(@@) s1 s2 = [(u, apply s1 t) | (u,t) <- s2] ++ s1 +(@@) :: Substitution -> Substitution -> Substitution +(@@) s1 s2 = + let s2' = Map.map (\ t -> apply s1 t) s1 + in Map.union s2' s1 -merge :: Monad m => Subst -> Subst -> m Subst -merge s1 s2 | agree = return (s1 ++ s2) - | otherwise = fail "merge failed" +-- ----------------------------------------------------------------------------- + +data InferenceError = UnificationError Type Type + | OccursCheckFails Name Type + | KindCheckFails Name Type + | MatchFailure Type Type + | UnboundIdentifier Name + | MergeFailure Substitution Substitution + deriving (Show) + +data InferenceState = InferenceState { + _istCurrentSubstitution :: Substitution + , _istNextIdentifier :: Word + } + +makeLenses ''InferenceState + +newtype Infer a = Infer { + unInfer :: StateT InferenceState (ExceptionT InferenceError Id) a + } + deriving (Functor, Applicative, Monad) + +instance StateM Infer InferenceState where + get = Infer get + set = Infer . set + +instance ExceptionM Infer InferenceError where + raise = Infer . raise + +instance RunExceptionM Infer InferenceError where + try m = Infer (try (unInfer m)) + +-- ----------------------------------------------------------------------------- + +merge :: Substitution -> Substitution -> Infer Substitution +merge s1 s2 | agree = return (Map.union s1 s2) + | otherwise = raise (MergeFailure s1 s2) where + names = Map.keys (Map.intersection s1 s2) agree = all (\ v -> - let refv = TypeReef genLoc Star v + let refv = TypeRef (error "Internal error, TypeInfer") Star v in apply s1 refv == apply s2 refv) - (map fst s1 `intersect` map fst s2) + names -mostGeneralUnifier :: Monad m => Type -> Type -> m Subst +mostGeneralUnifier :: Type -> Type -> Infer Substitution mostGeneralUnifier t1 t2 = case (t1, t2) of - (TypeApp _ _ l r, TypeApp l' r') -> + (TypeApp _ _ l r, TypeApp _ _ l' r') -> do s1 <- mostGeneralUnifier l l' s2 <- mostGeneralUnifier (apply s1 r) (apply s1 r') return (s2 @@ s1) - (TypeRef _ _ u, t) -> varBind u t - (t, TypeRef _ _ u) -> varBind u t - (TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubst + (u@(TypeRef _ _ _), t) -> varBind u t + (t, u@(TypeRef _ _ _)) -> varBind u t + (TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution (t1, t2) -> raise (UnificationError t1 t2) -varBind :: Monad m => Name -> Type -> m Subst -varBind u t | t == TypeRef _ _ u = return nullSubst - | u `elem` tv t = raise (OccursCheckFails u t) - | kind u /= kind t = raise (KindCheckFails u t) - | otherwise = return (u ⟼ t) +varBind :: Type -> Type -> Infer Substitution +varBind (TypeRef _ k u) t + | TypeRef _ _ u' <- t, u' == u = return nullSubstitution + | u `elem` tv t = raise (OccursCheckFails u t) + | k /= kind t = raise (KindCheckFails u t) + | otherwise = return (u ⟼ t) -match :: Monad m => Type -> Type -> m Subst +match :: Type -> Type -> Infer Substitution match t1 t2 = case (t1, t2) of - (TypeApp _ _ l r, TypeApp l' r') -> + (TypeApp _ _ l r, TypeApp _ _ l' r') -> do sl <- match l l' sr <- match r r' merge sl sr - (TypeRef _ _ u, t) | kind u == kind t -> return (u ⟼ t) - (TypePrim tc1, TypePrim tc2) | tc1 == tc2 -> return nullSubst + (TypeRef _ k u, t) | k == kind t -> return (u ⟼ t) + (TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution (t1, t2) -> raise (MatchFailure t1 t2) +data Scheme = Forall [Kind] Type + +instance Types Scheme where + apply s (Forall ks t) = Forall ks (apply s t) + tv (Forall ks qt) = tv qt + data Assumption = Name :>: Scheme instance Types Assumption where apply s (i :>: sc) = i :>: (apply s sc) tv (i :>: sc) = tv sc -find :: Monad m => Name -> [Assumption] -> m Scheme +find :: Name -> [Assumption] -> Infer Scheme find i [] = raise (UnboundIdentifier i) find i ((i' :>: sc) : as) | i == i' = return sc | otherwise = find i as -class Types t where - apply :: Subst -> t -> t - tv :: t -> [Name] - instance Types Type where - apply s v@(TypeRef _ _ n) = case lookup n s of + apply s v@(TypeRef _ _ n) = case Map.lookup n s of Just t -> t Nothing -> v apply s (TypeApp l k t u) = TypeApp l k (apply s t) (apply s u) @@ -85,5 +145,59 @@ instance Types [Type] where apply s = map (apply s) tv = nub . concat . map tv -typeInfer :: Module -> Either String Module -typeInfer = undefined +getSubstitution :: Infer Substitution +getSubstitution = view istCurrentSubstitution `fmap` get + +extendSubstitution :: Substitution -> Infer () +extendSubstitution s' = + do s <- get + set (over istCurrentSubstitution (s' @@) s) + +unify :: Type -> Type -> Infer () +unify t1 t2 = + do s <- getSubstitution + u <- mostGeneralUnifier (apply s t1) (apply s t2) + extendSubstitution u + +gensym :: Kind -> Infer Type +gensym k = + do s <- get + set (over istNextIdentifier (+1) s) + let num = view istNextIdentifier s + str = "gensym:" ++ show num + name = Name unknownLocation TypeEnv num (pack str) + return (TypeRef unknownLocation k name) + +data Predicate = IsIn String Type + deriving (Eq) + +inferConstant :: ConstantValue -> Infer ([Predicate], Type) +inferConstant c = + do v <- gensym Star + let constraint | ConstantInt _ _ <- c = IsIn "IntLike" v + | ConstantChar _ <- c = IsIn "CharLake" v + | ConstantString _ <- c = IsIn "StringLike" v + | ConstantFloat _ <- c = IsIn "FloatLike" v + return ([constraint], v) + +inferExpression :: ClassEnvironment -> [Assumptions] -> + Expression -> + Infer ([Predicate], Type) +inferExpression classEnv assumpts expr = + case expr of + ConstantExp _ cv -> inferConstant cv + ReferenceExp _ n -> do sc <- find n assumpts + (ps :=> t) <- freshInst sc + return (ps, t) + LambdaExp _ n e -> error "FIXME, here" + +infer :: Module -> Infer Module +infer = undefined + +typeInfer :: Word -> Module -> Either InferenceError Module +typeInfer gensymState mdl = + let inferM = unInfer (infer mdl) + excM = runStateT (InferenceState nullSubstitution gensymState) inferM + idM = runExceptionT excM + resWState = runId idM + in fmap fst resWState diff --git a/src/Main.hs b/src/Main.hs index f753d5f..e661540 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -47,6 +47,6 @@ withParsed action path body = withInferred :: (Module -> IO ()) -> Module -> IO () withInferred action mdl = - case typeInfer mdl of - Left err -> exit err + case typeInfer 0 mdl of + Left err -> exit (show err) Right mdl' -> action mdl'