Checkpoint

This commit is contained in:
2016-06-21 23:13:42 -07:00
parent 156120fbec
commit 3adb7650b4
4 changed files with 183 additions and 50 deletions

View File

@@ -14,6 +14,17 @@ instance Eq Name where
(Name _ _ x _) == (Name _ _ y _) = x == y
(Name _ _ x _) /= (Name _ _ y _) = x /= y
instance Ord Name where
compare (Name _ _ x _) (Name _ _ y _) = compare x y
--
max n1@(Name _ _ x _) n2@(Name _ _ y _) = if x > y then n1 else n2
min n1@(Name _ _ x _) n2@(Name _ _ y _) = if x > y then n2 else n1
--
(Name _ _ x _) < (Name _ _ y _) = x < y
(Name _ _ x _) <= (Name _ _ y _) = x <= y
(Name _ _ x _) >= (Name _ _ y _) = x >= y
(Name _ _ x _) > (Name _ _ y _) = x > y
data Module = Module Name [Declaration]
deriving (Show)
@@ -41,6 +52,14 @@ data Type = TypeUnit Location Kind
| TypeForAll [Name] Type
deriving (Show)
instance Eq Type where
(TypeUnit _ _) == (TypeUnit _ _) = True
(TypePrim _ _ a) == (TypePrim _ _ b) = a == b
(TypeRef _ _ n) == (TypeRef _ _ m) = n == m
(TypeLambda _ _ at et) == (TypeLambda _ _ bt ft) = (at == bt) && (et == ft)
(TypeApp _ _ at bt) == (TypeApp _ _ ct dt) = (at == ct) && (bt == dt)
(TypeForAll ns t) == (TypeForAll ms u) = (ns == ms) && (t == u)
kind :: Type -> Kind
kind (TypeUnit _ k) = k
kind (TypePrim _ k _) = k
@@ -51,4 +70,4 @@ kind (TypeForAll _ t) = kind t
data Kind = Star
| KindArrow Kind Kind
deriving (Show)
deriving (Show, Eq)

View File

@@ -1,77 +1,137 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
module Bang.TypeInfer(typeInfer)
where
import Bang.Syntax.AST
import Data.List(union, nub, concat, intersect)
import Bang.Syntax.AST
import Bang.Syntax.Location(unknownLocation)
import Control.Lens(view, over)
import Control.Lens.TH(makeLenses)
import Data.List(union, nub, concat, intersect)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Text.Lazy(pack)
import MonadLib(StateT, ExceptionT, Id,
StateM(..), ExceptionM(..), RunExceptionM(..),
runStateT, runExceptionT, runId,
get, raise)
type Subst = [(Name, Type)]
-- -----------------------------------------------------------------------------
nullSubst :: Subst
nullSubst = []
type Substitution = Map Name Type
() :: Name -> Type -> Subst
() n t = [(n,t)]
class Types t where
apply :: Substitution -> t -> t
tv :: t -> [Name]
nullSubstitution :: Substitution
nullSubstitution = Map.empty
() :: Name -> Type -> Substitution
() = Map.singleton
infixr 4 @@
(@@) :: Subst -> Subst -> Subst
(@@) s1 s2 = [(u, apply s1 t) | (u,t) <- s2] ++ s1
(@@) :: Substitution -> Substitution -> Substitution
(@@) s1 s2 =
let s2' = Map.map (\ t -> apply s1 t) s1
in Map.union s2' s1
merge :: Monad m => Subst -> Subst -> m Subst
merge s1 s2 | agree = return (s1 ++ s2)
| otherwise = fail "merge failed"
-- -----------------------------------------------------------------------------
data InferenceError = UnificationError Type Type
| OccursCheckFails Name Type
| KindCheckFails Name Type
| MatchFailure Type Type
| UnboundIdentifier Name
| MergeFailure Substitution Substitution
deriving (Show)
data InferenceState = InferenceState {
_istCurrentSubstitution :: Substitution
, _istNextIdentifier :: Word
}
makeLenses ''InferenceState
newtype Infer a = Infer {
unInfer :: StateT InferenceState (ExceptionT InferenceError Id) a
}
deriving (Functor, Applicative, Monad)
instance StateM Infer InferenceState where
get = Infer get
set = Infer . set
instance ExceptionM Infer InferenceError where
raise = Infer . raise
instance RunExceptionM Infer InferenceError where
try m = Infer (try (unInfer m))
-- -----------------------------------------------------------------------------
merge :: Substitution -> Substitution -> Infer Substitution
merge s1 s2 | agree = return (Map.union s1 s2)
| otherwise = raise (MergeFailure s1 s2)
where
names = Map.keys (Map.intersection s1 s2)
agree = all (\ v ->
let refv = TypeReef genLoc Star v
let refv = TypeRef (error "Internal error, TypeInfer") Star v
in apply s1 refv == apply s2 refv)
(map fst s1 `intersect` map fst s2)
names
mostGeneralUnifier :: Monad m => Type -> Type -> m Subst
mostGeneralUnifier :: Type -> Type -> Infer Substitution
mostGeneralUnifier t1 t2 =
case (t1, t2) of
(TypeApp _ _ l r, TypeApp l' r') ->
(TypeApp _ _ l r, TypeApp _ _ l' r') ->
do s1 <- mostGeneralUnifier l l'
s2 <- mostGeneralUnifier (apply s1 r) (apply s1 r')
return (s2 @@ s1)
(TypeRef _ _ u, t) -> varBind u t
(t, TypeRef _ _ u) -> varBind u t
(TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubst
(u@(TypeRef _ _ _), t) -> varBind u t
(t, u@(TypeRef _ _ _)) -> varBind u t
(TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution
(t1, t2) -> raise (UnificationError t1 t2)
varBind :: Monad m => Name -> Type -> m Subst
varBind u t | t == TypeRef _ _ u = return nullSubst
| u `elem` tv t = raise (OccursCheckFails u t)
| kind u /= kind t = raise (KindCheckFails u t)
| otherwise = return (u t)
varBind :: Type -> Type -> Infer Substitution
varBind (TypeRef _ k u) t
| TypeRef _ _ u' <- t, u' == u = return nullSubstitution
| u `elem` tv t = raise (OccursCheckFails u t)
| k /= kind t = raise (KindCheckFails u t)
| otherwise = return (u t)
match :: Monad m => Type -> Type -> m Subst
match :: Type -> Type -> Infer Substitution
match t1 t2 =
case (t1, t2) of
(TypeApp _ _ l r, TypeApp l' r') ->
(TypeApp _ _ l r, TypeApp _ _ l' r') ->
do sl <- match l l'
sr <- match r r'
merge sl sr
(TypeRef _ _ u, t) | kind u == kind t -> return (u t)
(TypePrim tc1, TypePrim tc2) | tc1 == tc2 -> return nullSubst
(TypeRef _ k u, t) | k == kind t -> return (u t)
(TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution
(t1, t2) -> raise (MatchFailure t1 t2)
data Scheme = Forall [Kind] Type
instance Types Scheme where
apply s (Forall ks t) = Forall ks (apply s t)
tv (Forall ks qt) = tv qt
data Assumption = Name :>: Scheme
instance Types Assumption where
apply s (i :>: sc) = i :>: (apply s sc)
tv (i :>: sc) = tv sc
find :: Monad m => Name -> [Assumption] -> m Scheme
find :: Name -> [Assumption] -> Infer Scheme
find i [] = raise (UnboundIdentifier i)
find i ((i' :>: sc) : as) | i == i' = return sc
| otherwise = find i as
class Types t where
apply :: Subst -> t -> t
tv :: t -> [Name]
instance Types Type where
apply s v@(TypeRef _ _ n) = case lookup n s of
apply s v@(TypeRef _ _ n) = case Map.lookup n s of
Just t -> t
Nothing -> v
apply s (TypeApp l k t u) = TypeApp l k (apply s t) (apply s u)
@@ -85,5 +145,59 @@ instance Types [Type] where
apply s = map (apply s)
tv = nub . concat . map tv
typeInfer :: Module -> Either String Module
typeInfer = undefined
getSubstitution :: Infer Substitution
getSubstitution = view istCurrentSubstitution `fmap` get
extendSubstitution :: Substitution -> Infer ()
extendSubstitution s' =
do s <- get
set (over istCurrentSubstitution (s' @@) s)
unify :: Type -> Type -> Infer ()
unify t1 t2 =
do s <- getSubstitution
u <- mostGeneralUnifier (apply s t1) (apply s t2)
extendSubstitution u
gensym :: Kind -> Infer Type
gensym k =
do s <- get
set (over istNextIdentifier (+1) s)
let num = view istNextIdentifier s
str = "gensym:" ++ show num
name = Name unknownLocation TypeEnv num (pack str)
return (TypeRef unknownLocation k name)
data Predicate = IsIn String Type
deriving (Eq)
inferConstant :: ConstantValue -> Infer ([Predicate], Type)
inferConstant c =
do v <- gensym Star
let constraint | ConstantInt _ _ <- c = IsIn "IntLike" v
| ConstantChar _ <- c = IsIn "CharLake" v
| ConstantString _ <- c = IsIn "StringLike" v
| ConstantFloat _ <- c = IsIn "FloatLike" v
return ([constraint], v)
inferExpression :: ClassEnvironment -> [Assumptions] ->
Expression ->
Infer ([Predicate], Type)
inferExpression classEnv assumpts expr =
case expr of
ConstantExp _ cv -> inferConstant cv
ReferenceExp _ n -> do sc <- find n assumpts
(ps :=> t) <- freshInst sc
return (ps, t)
LambdaExp _ n e -> error "FIXME, here"
infer :: Module -> Infer Module
infer = undefined
typeInfer :: Word -> Module -> Either InferenceError Module
typeInfer gensymState mdl =
let inferM = unInfer (infer mdl)
excM = runStateT (InferenceState nullSubstitution gensymState) inferM
idM = runExceptionT excM
resWState = runId idM
in fmap fst resWState