Checkpoint
This commit is contained in:
@@ -14,6 +14,17 @@ instance Eq Name where
|
||||
(Name _ _ x _) == (Name _ _ y _) = x == y
|
||||
(Name _ _ x _) /= (Name _ _ y _) = x /= y
|
||||
|
||||
instance Ord Name where
|
||||
compare (Name _ _ x _) (Name _ _ y _) = compare x y
|
||||
--
|
||||
max n1@(Name _ _ x _) n2@(Name _ _ y _) = if x > y then n1 else n2
|
||||
min n1@(Name _ _ x _) n2@(Name _ _ y _) = if x > y then n2 else n1
|
||||
--
|
||||
(Name _ _ x _) < (Name _ _ y _) = x < y
|
||||
(Name _ _ x _) <= (Name _ _ y _) = x <= y
|
||||
(Name _ _ x _) >= (Name _ _ y _) = x >= y
|
||||
(Name _ _ x _) > (Name _ _ y _) = x > y
|
||||
|
||||
data Module = Module Name [Declaration]
|
||||
deriving (Show)
|
||||
|
||||
@@ -41,6 +52,14 @@ data Type = TypeUnit Location Kind
|
||||
| TypeForAll [Name] Type
|
||||
deriving (Show)
|
||||
|
||||
instance Eq Type where
|
||||
(TypeUnit _ _) == (TypeUnit _ _) = True
|
||||
(TypePrim _ _ a) == (TypePrim _ _ b) = a == b
|
||||
(TypeRef _ _ n) == (TypeRef _ _ m) = n == m
|
||||
(TypeLambda _ _ at et) == (TypeLambda _ _ bt ft) = (at == bt) && (et == ft)
|
||||
(TypeApp _ _ at bt) == (TypeApp _ _ ct dt) = (at == ct) && (bt == dt)
|
||||
(TypeForAll ns t) == (TypeForAll ms u) = (ns == ms) && (t == u)
|
||||
|
||||
kind :: Type -> Kind
|
||||
kind (TypeUnit _ k) = k
|
||||
kind (TypePrim _ k _) = k
|
||||
@@ -51,4 +70,4 @@ kind (TypeForAll _ t) = kind t
|
||||
|
||||
data Kind = Star
|
||||
| KindArrow Kind Kind
|
||||
deriving (Show)
|
||||
deriving (Show, Eq)
|
||||
|
||||
@@ -1,77 +1,137 @@
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE DeriveFunctor #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
module Bang.TypeInfer(typeInfer)
|
||||
where
|
||||
|
||||
import Bang.Syntax.AST
|
||||
import Data.List(union, nub, concat, intersect)
|
||||
import Bang.Syntax.AST
|
||||
import Bang.Syntax.Location(unknownLocation)
|
||||
import Control.Lens(view, over)
|
||||
import Control.Lens.TH(makeLenses)
|
||||
import Data.List(union, nub, concat, intersect)
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Text.Lazy(pack)
|
||||
import MonadLib(StateT, ExceptionT, Id,
|
||||
StateM(..), ExceptionM(..), RunExceptionM(..),
|
||||
runStateT, runExceptionT, runId,
|
||||
get, raise)
|
||||
|
||||
type Subst = [(Name, Type)]
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
nullSubst :: Subst
|
||||
nullSubst = []
|
||||
type Substitution = Map Name Type
|
||||
|
||||
(⟼) :: Name -> Type -> Subst
|
||||
(⟼) n t = [(n,t)]
|
||||
class Types t where
|
||||
apply :: Substitution -> t -> t
|
||||
tv :: t -> [Name]
|
||||
|
||||
nullSubstitution :: Substitution
|
||||
nullSubstitution = Map.empty
|
||||
|
||||
(⟼) :: Name -> Type -> Substitution
|
||||
(⟼) = Map.singleton
|
||||
|
||||
infixr 4 @@
|
||||
(@@) :: Subst -> Subst -> Subst
|
||||
(@@) s1 s2 = [(u, apply s1 t) | (u,t) <- s2] ++ s1
|
||||
(@@) :: Substitution -> Substitution -> Substitution
|
||||
(@@) s1 s2 =
|
||||
let s2' = Map.map (\ t -> apply s1 t) s1
|
||||
in Map.union s2' s1
|
||||
|
||||
merge :: Monad m => Subst -> Subst -> m Subst
|
||||
merge s1 s2 | agree = return (s1 ++ s2)
|
||||
| otherwise = fail "merge failed"
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
data InferenceError = UnificationError Type Type
|
||||
| OccursCheckFails Name Type
|
||||
| KindCheckFails Name Type
|
||||
| MatchFailure Type Type
|
||||
| UnboundIdentifier Name
|
||||
| MergeFailure Substitution Substitution
|
||||
deriving (Show)
|
||||
|
||||
data InferenceState = InferenceState {
|
||||
_istCurrentSubstitution :: Substitution
|
||||
, _istNextIdentifier :: Word
|
||||
}
|
||||
|
||||
makeLenses ''InferenceState
|
||||
|
||||
newtype Infer a = Infer {
|
||||
unInfer :: StateT InferenceState (ExceptionT InferenceError Id) a
|
||||
}
|
||||
deriving (Functor, Applicative, Monad)
|
||||
|
||||
instance StateM Infer InferenceState where
|
||||
get = Infer get
|
||||
set = Infer . set
|
||||
|
||||
instance ExceptionM Infer InferenceError where
|
||||
raise = Infer . raise
|
||||
|
||||
instance RunExceptionM Infer InferenceError where
|
||||
try m = Infer (try (unInfer m))
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
merge :: Substitution -> Substitution -> Infer Substitution
|
||||
merge s1 s2 | agree = return (Map.union s1 s2)
|
||||
| otherwise = raise (MergeFailure s1 s2)
|
||||
where
|
||||
names = Map.keys (Map.intersection s1 s2)
|
||||
agree = all (\ v ->
|
||||
let refv = TypeReef genLoc Star v
|
||||
let refv = TypeRef (error "Internal error, TypeInfer") Star v
|
||||
in apply s1 refv == apply s2 refv)
|
||||
(map fst s1 `intersect` map fst s2)
|
||||
names
|
||||
|
||||
mostGeneralUnifier :: Monad m => Type -> Type -> m Subst
|
||||
mostGeneralUnifier :: Type -> Type -> Infer Substitution
|
||||
mostGeneralUnifier t1 t2 =
|
||||
case (t1, t2) of
|
||||
(TypeApp _ _ l r, TypeApp l' r') ->
|
||||
(TypeApp _ _ l r, TypeApp _ _ l' r') ->
|
||||
do s1 <- mostGeneralUnifier l l'
|
||||
s2 <- mostGeneralUnifier (apply s1 r) (apply s1 r')
|
||||
return (s2 @@ s1)
|
||||
(TypeRef _ _ u, t) -> varBind u t
|
||||
(t, TypeRef _ _ u) -> varBind u t
|
||||
(TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubst
|
||||
(u@(TypeRef _ _ _), t) -> varBind u t
|
||||
(t, u@(TypeRef _ _ _)) -> varBind u t
|
||||
(TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution
|
||||
(t1, t2) -> raise (UnificationError t1 t2)
|
||||
|
||||
varBind :: Monad m => Name -> Type -> m Subst
|
||||
varBind u t | t == TypeRef _ _ u = return nullSubst
|
||||
| u `elem` tv t = raise (OccursCheckFails u t)
|
||||
| kind u /= kind t = raise (KindCheckFails u t)
|
||||
| otherwise = return (u ⟼ t)
|
||||
varBind :: Type -> Type -> Infer Substitution
|
||||
varBind (TypeRef _ k u) t
|
||||
| TypeRef _ _ u' <- t, u' == u = return nullSubstitution
|
||||
| u `elem` tv t = raise (OccursCheckFails u t)
|
||||
| k /= kind t = raise (KindCheckFails u t)
|
||||
| otherwise = return (u ⟼ t)
|
||||
|
||||
match :: Monad m => Type -> Type -> m Subst
|
||||
match :: Type -> Type -> Infer Substitution
|
||||
match t1 t2 =
|
||||
case (t1, t2) of
|
||||
(TypeApp _ _ l r, TypeApp l' r') ->
|
||||
(TypeApp _ _ l r, TypeApp _ _ l' r') ->
|
||||
do sl <- match l l'
|
||||
sr <- match r r'
|
||||
merge sl sr
|
||||
(TypeRef _ _ u, t) | kind u == kind t -> return (u ⟼ t)
|
||||
(TypePrim tc1, TypePrim tc2) | tc1 == tc2 -> return nullSubst
|
||||
(TypeRef _ k u, t) | k == kind t -> return (u ⟼ t)
|
||||
(TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution
|
||||
(t1, t2) -> raise (MatchFailure t1 t2)
|
||||
|
||||
data Scheme = Forall [Kind] Type
|
||||
|
||||
instance Types Scheme where
|
||||
apply s (Forall ks t) = Forall ks (apply s t)
|
||||
tv (Forall ks qt) = tv qt
|
||||
|
||||
data Assumption = Name :>: Scheme
|
||||
|
||||
instance Types Assumption where
|
||||
apply s (i :>: sc) = i :>: (apply s sc)
|
||||
tv (i :>: sc) = tv sc
|
||||
|
||||
find :: Monad m => Name -> [Assumption] -> m Scheme
|
||||
find :: Name -> [Assumption] -> Infer Scheme
|
||||
find i [] = raise (UnboundIdentifier i)
|
||||
find i ((i' :>: sc) : as) | i == i' = return sc
|
||||
| otherwise = find i as
|
||||
|
||||
class Types t where
|
||||
apply :: Subst -> t -> t
|
||||
tv :: t -> [Name]
|
||||
|
||||
instance Types Type where
|
||||
apply s v@(TypeRef _ _ n) = case lookup n s of
|
||||
apply s v@(TypeRef _ _ n) = case Map.lookup n s of
|
||||
Just t -> t
|
||||
Nothing -> v
|
||||
apply s (TypeApp l k t u) = TypeApp l k (apply s t) (apply s u)
|
||||
@@ -85,5 +145,59 @@ instance Types [Type] where
|
||||
apply s = map (apply s)
|
||||
tv = nub . concat . map tv
|
||||
|
||||
typeInfer :: Module -> Either String Module
|
||||
typeInfer = undefined
|
||||
getSubstitution :: Infer Substitution
|
||||
getSubstitution = view istCurrentSubstitution `fmap` get
|
||||
|
||||
extendSubstitution :: Substitution -> Infer ()
|
||||
extendSubstitution s' =
|
||||
do s <- get
|
||||
set (over istCurrentSubstitution (s' @@) s)
|
||||
|
||||
unify :: Type -> Type -> Infer ()
|
||||
unify t1 t2 =
|
||||
do s <- getSubstitution
|
||||
u <- mostGeneralUnifier (apply s t1) (apply s t2)
|
||||
extendSubstitution u
|
||||
|
||||
gensym :: Kind -> Infer Type
|
||||
gensym k =
|
||||
do s <- get
|
||||
set (over istNextIdentifier (+1) s)
|
||||
let num = view istNextIdentifier s
|
||||
str = "gensym:" ++ show num
|
||||
name = Name unknownLocation TypeEnv num (pack str)
|
||||
return (TypeRef unknownLocation k name)
|
||||
|
||||
data Predicate = IsIn String Type
|
||||
deriving (Eq)
|
||||
|
||||
inferConstant :: ConstantValue -> Infer ([Predicate], Type)
|
||||
inferConstant c =
|
||||
do v <- gensym Star
|
||||
let constraint | ConstantInt _ _ <- c = IsIn "IntLike" v
|
||||
| ConstantChar _ <- c = IsIn "CharLake" v
|
||||
| ConstantString _ <- c = IsIn "StringLike" v
|
||||
| ConstantFloat _ <- c = IsIn "FloatLike" v
|
||||
return ([constraint], v)
|
||||
|
||||
inferExpression :: ClassEnvironment -> [Assumptions] ->
|
||||
Expression ->
|
||||
Infer ([Predicate], Type)
|
||||
inferExpression classEnv assumpts expr =
|
||||
case expr of
|
||||
ConstantExp _ cv -> inferConstant cv
|
||||
ReferenceExp _ n -> do sc <- find n assumpts
|
||||
(ps :=> t) <- freshInst sc
|
||||
return (ps, t)
|
||||
LambdaExp _ n e -> error "FIXME, here"
|
||||
|
||||
infer :: Module -> Infer Module
|
||||
infer = undefined
|
||||
|
||||
typeInfer :: Word -> Module -> Either InferenceError Module
|
||||
typeInfer gensymState mdl =
|
||||
let inferM = unInfer (infer mdl)
|
||||
excM = runStateT (InferenceState nullSubstitution gensymState) inferM
|
||||
idM = runExceptionT excM
|
||||
resWState = runId idM
|
||||
in fmap fst resWState
|
||||
|
||||
@@ -47,6 +47,6 @@ withParsed action path body =
|
||||
|
||||
withInferred :: (Module -> IO ()) -> Module -> IO ()
|
||||
withInferred action mdl =
|
||||
case typeInfer mdl of
|
||||
Left err -> exit err
|
||||
case typeInfer 0 mdl of
|
||||
Left err -> exit (show err)
|
||||
Right mdl' -> action mdl'
|
||||
|
||||
Reference in New Issue
Block a user