Checkpoint. We're inferring stuff, sorta.
This commit is contained in:
@@ -24,7 +24,7 @@ import Bang.Syntax.Location(Location)
|
||||
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
||||
import Control.Lens(Lens', view, set, lens)
|
||||
import Control.Lens(makeLenses)
|
||||
import Data.List(delete, union)
|
||||
import Data.Set(delete, union)
|
||||
import Text.PrettyPrint.Annotated(Doc, text, (<+>), ($+$), (<>), empty, space)
|
||||
|
||||
data TypeDeclaration = TypeDeclaration
|
||||
|
||||
@@ -35,6 +35,7 @@ import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
||||
import Bang.Utils.Pretty(text')
|
||||
import Control.Lens(view)
|
||||
import Control.Lens.TH(makeLenses)
|
||||
import Data.Set(empty, singleton, fromList, (\\))
|
||||
import Data.Text.Lazy(Text)
|
||||
import Text.PrettyPrint.Annotated(Doc, text, hsep, (<>), (<+>))
|
||||
|
||||
@@ -74,7 +75,7 @@ instance MkConstExp Expression where
|
||||
mkConstExp l v = ConstExp (mkConstExp l v)
|
||||
|
||||
instance CanHaveFreeVars ConstantExpression where
|
||||
freeVariables _ = []
|
||||
freeVariables _ = empty
|
||||
|
||||
ppConstantExpression :: ConstantExpression -> Doc a
|
||||
ppConstantExpression = ppConstantValue . _constValue
|
||||
@@ -100,7 +101,7 @@ instance MkRefExp Expression where
|
||||
mkRefExp l n = RefExp (ReferenceExpression l n)
|
||||
|
||||
instance CanHaveFreeVars ReferenceExpression where
|
||||
freeVariables r = [_refName r]
|
||||
freeVariables r = singleton (_refName r)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
@@ -126,8 +127,8 @@ instance MkLambdaExp Expression where
|
||||
mkLambdaExp l a b = LambdaExp (LambdaExpression l a b)
|
||||
|
||||
instance CanHaveFreeVars LambdaExpression where
|
||||
freeVariables l = filter (\ x -> not (x `elem` (_lambdaArgumentNames l)))
|
||||
(freeVariables (_lambdaBody l))
|
||||
freeVariables l = freeVariables (_lambdaBody l) \\
|
||||
fromList (_lambdaArgumentNames l)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ module Bang.AST.Type
|
||||
, FunctionType
|
||||
, ppFunctionType
|
||||
, mkFunType
|
||||
, ftLocation, ftKind, ftArgumentTypes, ftResultType
|
||||
, ftLocation, ftKind, ftArgumentType, ftResultType
|
||||
-- * type application
|
||||
, TypeApplication
|
||||
, ppTypeApplication
|
||||
@@ -36,9 +36,9 @@ import Bang.Syntax.Location(Location)
|
||||
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
||||
import Bang.Utils.Pretty(text')
|
||||
import Control.Lens.TH(makeLenses)
|
||||
import Data.List(foldl', union)
|
||||
import Data.Set(union, empty, singleton)
|
||||
import Data.Text.Lazy(Text)
|
||||
import Text.PrettyPrint.Annotated(Doc, (<+>), (<>), text, hsep)
|
||||
import Text.PrettyPrint.Annotated(Doc, (<+>), (<>), text)
|
||||
|
||||
data Kind = Star
|
||||
| Unknown
|
||||
@@ -62,7 +62,7 @@ instance Kinded UnitType where
|
||||
kind _ = Star
|
||||
|
||||
instance CanHaveFreeVars UnitType where
|
||||
freeVariables _ = []
|
||||
freeVariables _ = empty
|
||||
|
||||
ppUnitType :: UnitType -> Doc a
|
||||
ppUnitType _ = text "()"
|
||||
@@ -88,7 +88,7 @@ instance MkPrimType Type where
|
||||
mkPrimType l t = TypePrim (PrimitiveType l t)
|
||||
|
||||
instance CanHaveFreeVars PrimitiveType where
|
||||
freeVariables _ = []
|
||||
freeVariables _ = empty
|
||||
|
||||
ppPrimitiveType :: PrimitiveType -> Doc a
|
||||
ppPrimitiveType pt = text "llvm:" <> text' (_ptName pt)
|
||||
@@ -118,20 +118,20 @@ instance MkTypeRef Type where
|
||||
mkTypeRef l k n = TypeRef (ReferenceType l k n)
|
||||
|
||||
instance CanHaveFreeVars ReferenceType where
|
||||
freeVariables r = [_rtName r]
|
||||
freeVariables r = singleton (_rtName r)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
data FunctionType = FunctionType
|
||||
{ _ftLocation :: Location
|
||||
, _ftKind :: Kind
|
||||
, _ftArgumentTypes :: [Type]
|
||||
, _ftResultType :: Type
|
||||
{ _ftLocation :: Location
|
||||
, _ftKind :: Kind
|
||||
, _ftArgumentType :: Type
|
||||
, _ftResultType :: Type
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
class MkFunType a where
|
||||
mkFunType :: Location -> [Type] -> Type -> a
|
||||
mkFunType :: Location -> Type -> Type -> a
|
||||
|
||||
instance MkFunType FunctionType where
|
||||
mkFunType l a r = FunctionType l Star a r
|
||||
@@ -141,16 +141,14 @@ instance MkFunType Type where
|
||||
|
||||
ppFunctionType :: FunctionType -> Doc a
|
||||
ppFunctionType ft =
|
||||
hsep (map ppType (_ftArgumentTypes ft)) <+> text "->" <+>
|
||||
ppType (_ftResultType ft)
|
||||
ppType (_ftArgumentType ft) <+> text "->" <+> ppType (_ftResultType ft)
|
||||
|
||||
instance Kinded FunctionType where
|
||||
kind = _ftKind
|
||||
|
||||
instance CanHaveFreeVars FunctionType where
|
||||
freeVariables ft = foldl' (\ acc x -> acc `union` freeVariables x)
|
||||
(freeVariables (_ftResultType ft))
|
||||
(_ftArgumentTypes ft)
|
||||
freeVariables ft = freeVariables (_ftArgumentType ft) `union`
|
||||
freeVariables (_ftResultType ft)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -162,7 +162,10 @@ err' :: BangError e => e -> Compiler s ()
|
||||
err' e = Compiler (\ st -> runError e False >> return (st, ()))
|
||||
|
||||
runWarning :: BangWarning w => w -> IO ()
|
||||
runWarning = undefined
|
||||
runWarning w = putStrLn (go (ppWarning w))
|
||||
where
|
||||
go (Nothing, doc) = render doc
|
||||
go (Just a, doc) = render (ppLocation a $+$ nest 3 doc)
|
||||
|
||||
runError :: BangError w => w -> Bool -> IO ()
|
||||
runError e die =
|
||||
|
||||
@@ -174,8 +174,8 @@ Type :: { Type }
|
||||
: RawType { $1 }
|
||||
|
||||
RawType :: { Type }
|
||||
: RawType '->' BaseType { mkFunType $2 [$1] $3 }
|
||||
| BaseType { $1 }
|
||||
: RawType '->' BaseType { mkFunType $2 $1 $3 }
|
||||
| BaseType { $1 }
|
||||
|
||||
BaseType :: { Type }
|
||||
: TypeIdent {%
|
||||
|
||||
@@ -13,7 +13,7 @@ import Bang.AST.Declaration(Declaration(..), declName,
|
||||
import Bang.AST.Expression(Expression(..), isEmptyExpression, refName,
|
||||
lambdaArgumentNames, lambdaBody,
|
||||
isEmptyExpression)
|
||||
import Bang.AST.Type(Type(..), rtName, ftArgumentTypes, ftResultType,
|
||||
import Bang.AST.Type(Type(..), rtName, ftArgumentType, ftResultType,
|
||||
taLeftType, taRightType)
|
||||
import Bang.Monad(Compiler, BangError(..), err, err', registerName)
|
||||
import Bang.Syntax.Location(Location, ppLocation)
|
||||
@@ -26,6 +26,7 @@ import Data.Graph(SCC(..))
|
||||
import Data.Graph.SCC(stronglyConnComp)
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Set(toList)
|
||||
import Data.Text.Lazy(uncons)
|
||||
import Text.PrettyPrint.Annotated(text, ($+$), (<+>), nest, quotes)
|
||||
|
||||
@@ -115,19 +116,15 @@ linkNames decls =
|
||||
let t' = set rtName name t
|
||||
return (TypeRef t', nameMap')
|
||||
linkType nameMap (TypeFun t) =
|
||||
do (argTypes, nameMap') <- foldM linkTypes ([], nameMap) (view ftArgumentTypes t)
|
||||
(resType, nameMap'') <- linkType nameMap' (view ftResultType t)
|
||||
return (TypeFun (set ftArgumentTypes argTypes $
|
||||
set ftResultType resType t),
|
||||
do (argType, nameMap') <- linkType nameMap (view ftArgumentType t)
|
||||
(resType, nameMap'') <- linkType nameMap' (view ftResultType t)
|
||||
return (TypeFun (set ftArgumentType argType $
|
||||
set ftResultType resType t),
|
||||
nameMap'')
|
||||
linkType nameMap (TypeApp t) =
|
||||
do (lt, nameMap') <- linkType nameMap (view taLeftType t)
|
||||
(rt, nameMap'') <- linkType nameMap' (view taRightType t)
|
||||
return (TypeApp (set taLeftType lt (set taRightType rt t)), nameMap'')
|
||||
--
|
||||
linkTypes (acc, nameMap) argType =
|
||||
do (argType', nameMap') <- linkType nameMap argType
|
||||
return (acc ++ [argType'], nameMap')
|
||||
--
|
||||
linkExpr _ x | isEmptyExpression x = return x
|
||||
linkExpr _ x@(ConstExp _) = return x
|
||||
@@ -217,4 +214,4 @@ orderDecls decls = map unSCC (stronglyConnComp nodes)
|
||||
unSCC (CyclicSCC xs) = xs
|
||||
--
|
||||
nodes = map tuplify decls
|
||||
tuplify d = (d, view declName d, freeVariables d)
|
||||
tuplify d = (d, view declName d, toList (freeVariables d))
|
||||
|
||||
@@ -1,174 +1,269 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TypeSynonymInstances #-}
|
||||
module Bang.TypeInfer(runTypeInference)
|
||||
where
|
||||
|
||||
import Bang.AST(Module)
|
||||
import Bang.Monad(Compiler)
|
||||
|
||||
runTypeInference :: Module -> Compiler ps Module
|
||||
runTypeInference x = return x
|
||||
|
||||
{- Better version
|
||||
import Bang.Monad(Compiler, BangError(..), err,
|
||||
runPass, getPassState, setPassState,
|
||||
viewPassState, overPassState,
|
||||
registerNewName, genName)
|
||||
import Bang.Syntax.Location(Location, unknownLocation)
|
||||
import Bang.Syntax.ParserMonad(NameDatabase(..))
|
||||
import Bang.Utils.Pretty(BangDoc)
|
||||
import Control.Lens(set, view, over)
|
||||
import Control.Lens.TH(makeLenses)
|
||||
import Data.List(union, nub, concat)
|
||||
import Bang.AST(Module, moduleDeclarations)
|
||||
import Bang.AST.Declaration(Declaration(..), ValueDeclaration,
|
||||
vdName, vdDeclaredType, vdValue,
|
||||
tdName, tdType)
|
||||
import Bang.AST.Expression(Expression(..), ConstantValue(..),
|
||||
lambdaArgumentNames, lambdaBody,
|
||||
constLocation, constValue, refName)
|
||||
import Bang.AST.Name(Name, NameEnvironment(..),
|
||||
nameLocation, nameText, ppName)
|
||||
import Bang.AST.Type(Type(..), ppType, rtName, ftArgumentType,
|
||||
ftResultType, taLeftType, taRightType,
|
||||
mkPrimType, mkFunType, mkTypeRef,
|
||||
Kind(..))
|
||||
import Bang.Monad(Compiler, BangError(..), BangWarning(..),
|
||||
registerNewName, err', err, warn,
|
||||
getPassState, mapPassState, runPass)
|
||||
import Bang.Syntax.Location(Location, fakeLocation)
|
||||
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
||||
import Bang.Utils.Pretty(BangDoc, text')
|
||||
import Control.Lens(view, over)
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Set(Set)
|
||||
import Data.Set(Set, (\\))
|
||||
import qualified Data.Set as Set
|
||||
import Data.Text.Lazy(Text, pack)
|
||||
import Text.PrettyPrint.Annotated(text, (<+>), quotes)
|
||||
import Text.PrettyPrint.Annotated(text, nest, quotes, ($+$), (<+>))
|
||||
|
||||
data InferenceState = InferenceState {
|
||||
_nameDatabase :: NameDatabase
|
||||
}
|
||||
|
||||
makeLenses ''InferenceState
|
||||
|
||||
type Infer a = Compiler InferenceState a
|
||||
|
||||
runInfer :: NameDatabase -> Infer a -> Compiler ps a
|
||||
runInfer ndb action = snd `fmap` runPass initial action
|
||||
where initial = InferenceState ndb
|
||||
runTypeInference :: Module -> Compiler ps Module
|
||||
runTypeInference x =
|
||||
do _ <- runPass emptyEnvironment (mapM_ typeInferDecls (view moduleDeclarations x))
|
||||
return x
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
data InferError = UnboundVariable Location Name
|
||||
deriving (Show)
|
||||
type Infer a = Compiler TypeEnvironment a
|
||||
|
||||
instance BangError InferError where
|
||||
ppError = prettyError
|
||||
getNamesTypeScheme :: Name -> Infer (Maybe Scheme)
|
||||
getNamesTypeScheme n = Map.lookup n `fmap` getPassState
|
||||
|
||||
prettyError :: InferError -> (Maybe Location, BangDoc)
|
||||
prettyError e =
|
||||
case e of
|
||||
UnboundVariable l n ->
|
||||
(Just l, text "Unbound variable '" <+> quotes (ppName n))
|
||||
addToTypeEnvironment :: [Name] -> [Scheme] -> Infer ()
|
||||
addToTypeEnvironment ns schms = mapPassState (add ns schms)
|
||||
where
|
||||
add :: [Name] -> [Scheme] -> TypeEnvironment -> TypeEnvironment
|
||||
add [] [] acc = acc
|
||||
add (n:restns) (s:rschms) acc =
|
||||
Map.insertWithKey errorFn n s (add restns rschms acc)
|
||||
add _ _ _ =
|
||||
error "Wackiness has insued."
|
||||
--
|
||||
errorFn k _ _ = error ("Redefinition of " ++ show k)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
type Substitutions = Map Name Type
|
||||
type Substitution = Map Name Type
|
||||
|
||||
noSubstitutions :: Substitutions
|
||||
noSubstitutions = Map.empty
|
||||
nullSubstitution :: Substitution
|
||||
nullSubstitution = Map.empty
|
||||
|
||||
composeSubstitutions :: Substitutions -> Substitutions -> Substitutions
|
||||
composeSubstitutions :: Substitution -> Substitution -> Substitution
|
||||
composeSubstitutions s1 s2 = Map.map (apply s1) s2 `Map.union` s1
|
||||
|
||||
class Types a where
|
||||
freeTypeVariables :: a -> Set Name
|
||||
apply :: Substitutions -> a -> a
|
||||
class ApplySubst t where
|
||||
apply :: Substitution -> t -> t
|
||||
|
||||
instance Types Type where
|
||||
freeTypeVariables t =
|
||||
case t of
|
||||
TypeUnit _ _ -> Set.empty
|
||||
TypePrim _ _ _ -> Set.empty
|
||||
TypeRef _ _ n -> Set.singleton n
|
||||
TypeLambda _ _ ns e -> Set.unions (map freeTypeVariables ns) `Set.union`
|
||||
freeTypeVariables e
|
||||
TypeApp _ _ a b -> freeTypeVariables a `Set.union` freeTypeVariables b
|
||||
apply substs t =
|
||||
case t of
|
||||
TypeRef _ _ n -> case Map.lookup n substs of
|
||||
Nothing -> t
|
||||
Just t' -> t'
|
||||
TypeLambda l k ns e -> TypeLambda l k (apply substs ns) (apply substs e)
|
||||
TypeApp l k a b -> TypeApp l k (apply substs a) (apply substs b)
|
||||
_ -> t
|
||||
instance ApplySubst Type where
|
||||
apply s (TypeUnit t) = TypeUnit t
|
||||
apply s (TypePrim t) = TypePrim t
|
||||
apply s (TypeRef t) = case Map.lookup (view rtName t) s of
|
||||
Nothing -> TypeRef t
|
||||
Just t' -> t'
|
||||
apply s (TypeFun t) = TypeFun (over ftArgumentType (apply s) $
|
||||
over ftResultType (apply s) t)
|
||||
apply s (TypeApp t) = TypeApp (over taLeftType (apply s) $
|
||||
over taRightType (apply s) t)
|
||||
|
||||
instance Types a => Types [a] where
|
||||
freeTypeVariables l = Set.unions (map freeTypeVariables l)
|
||||
apply s = map (apply s)
|
||||
instance ApplySubst a => ApplySubst [a] where
|
||||
apply s = map (apply s)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
inferModule :: Module -> Infer Module
|
||||
inferModule = undefined
|
||||
data Scheme = Scheme [Name] Type
|
||||
|
||||
instance CanHaveFreeVars Scheme where
|
||||
freeVariables (Scheme ns t) = freeVariables t \\ Set.fromList ns
|
||||
|
||||
instance ApplySubst Scheme where
|
||||
apply s (Scheme vars t) = Scheme vars (apply s t)
|
||||
|
||||
newTypeVar :: Name -> Infer Type
|
||||
newTypeVar n =
|
||||
do let loc = view nameLocation n
|
||||
n' <- registerNewName TypeEnv (view nameText n)
|
||||
return (mkTypeRef loc Unknown n')
|
||||
|
||||
instantiate :: Scheme -> Infer Type
|
||||
instantiate (Scheme vars t) =
|
||||
do refs <- mapM newTypeVar vars
|
||||
let newSubsts = Map.fromList (zip vars refs)
|
||||
return (apply newSubsts t)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
mostGeneralUnifier :: Type -> Type -> Infer Substitution
|
||||
mostGeneralUnifier a b =
|
||||
case (a, b) of
|
||||
(TypeUnit _, TypeUnit _) -> return nullSubstitution
|
||||
(TypePrim _, TypePrim _) -> return nullSubstitution
|
||||
(TypeRef t1, t2) -> varBind (view rtName t1) t2
|
||||
(t2, TypeRef t1) -> varBind (view rtName t1) t2
|
||||
(TypeFun t1, TypeFun t2) -> do let at1 = view ftArgumentType t1
|
||||
at2 = view ftArgumentType t2
|
||||
s1 <- mostGeneralUnifier at1 at2
|
||||
let rt1 = apply s1 (view ftResultType t1)
|
||||
rt2 = apply s1 (view ftResultType t2)
|
||||
s2 <- mostGeneralUnifier rt1 rt2
|
||||
return (s1 `composeSubstitutions` s2)
|
||||
(TypeApp t1, TypeApp t2) -> do let lt1 = view taLeftType t1
|
||||
lt2 = view taLeftType t2
|
||||
s1 <- mostGeneralUnifier lt1 lt2
|
||||
let rt1 = apply s1 (view taRightType t1)
|
||||
rt2 = apply s1 (view taRightType t2)
|
||||
s2 <- mostGeneralUnifier rt1 rt2
|
||||
return (s1 `composeSubstitutions` s2)
|
||||
_ -> do err' (TypesDontUnify a b)
|
||||
return nullSubstitution
|
||||
|
||||
varBind :: Name -> Type -> Infer Substitution
|
||||
varBind u t | TypeRef t' <- t,
|
||||
view rtName t' == u = return nullSubstitution
|
||||
| u `Set.member` freeVariables t = do err' (OccursFail u t)
|
||||
return nullSubstitution
|
||||
| otherwise = return (Map.singleton u t)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
type TypeEnvironment = Map Name Scheme
|
||||
|
||||
emptyEnvironment :: TypeEnvironment
|
||||
emptyEnvironment = Map.empty
|
||||
|
||||
instance ApplySubst TypeEnvironment where
|
||||
apply s tenv = Map.map (apply s) tenv
|
||||
|
||||
instance CanHaveFreeVars TypeEnvironment where
|
||||
freeVariables tenv = freeVariables (Map.elems tenv)
|
||||
|
||||
generalize :: TypeEnvironment -> Type -> Scheme
|
||||
generalize env t = Scheme vars t
|
||||
where vars = Set.toList (freeVariables t \\ freeVariables env)
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
data InferenceError = InternalError
|
||||
| TypesDontUnify Type Type
|
||||
| OccursFail Name Type
|
||||
| UnboundVariable Name
|
||||
|
||||
instance BangError InferenceError where
|
||||
ppError = prettyError
|
||||
|
||||
prettyError :: InferenceError -> (Maybe Location, BangDoc)
|
||||
prettyError e =
|
||||
case e of
|
||||
InternalError ->
|
||||
(Nothing, text "<internal error>")
|
||||
TypesDontUnify t1 t2 ->
|
||||
(Nothing, text "Types don't unify:" $+$
|
||||
(nest 3
|
||||
(text "first type: " <+> ppType t1 $+$
|
||||
text "second type: " <+> ppType t2)))
|
||||
OccursFail n t ->
|
||||
(Just (view nameLocation n),
|
||||
text "Occurs check failed:" $+$
|
||||
(nest 3 (text "Type: " <+> ppType t)))
|
||||
UnboundVariable n ->
|
||||
(Just (view nameLocation n),
|
||||
text "Unbound variable (in type checker?):" <+> ppName n)
|
||||
|
||||
data InferenceWarning = TopLevelWithoutType Name Type
|
||||
| DeclarationMismatch Name Type Type
|
||||
|
||||
instance BangWarning InferenceWarning where
|
||||
ppWarning = prettyWarning
|
||||
|
||||
prettyWarning :: InferenceWarning -> (Maybe Location, BangDoc)
|
||||
prettyWarning w =
|
||||
case w of
|
||||
TopLevelWithoutType n t ->
|
||||
(Just (view nameLocation n),
|
||||
text "Variable" <+> quotes (text' (view nameText n)) <+>
|
||||
text "is defined without a type." $+$
|
||||
text "Inferred type:" $+$ nest 3 (ppType t))
|
||||
DeclarationMismatch n dt it ->
|
||||
(Just (view nameLocation n),
|
||||
text "Mismatch between declared and inferred type of" <+>
|
||||
quotes (text' (view nameText n)) $+$
|
||||
nest 3 (text "declared type:" <+> ppType dt $+$
|
||||
text "inferred type:" <+> ppType it))
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
-- Infer the type of a group of declarations with cyclic dependencies.
|
||||
typeInferDecls :: [Declaration] -> Infer ()
|
||||
typeInferDecls decls =
|
||||
do (names, schemes, decls') <- getInitialSchemes decls
|
||||
addToTypeEnvironment names schemes
|
||||
mapM_ typeInferDecl decls'
|
||||
where
|
||||
getInitialSchemes [] =
|
||||
return ([], [], [])
|
||||
getInitialSchemes ((DeclType td) : rest) =
|
||||
do (rn, rs, rd) <- getInitialSchemes rest
|
||||
let n = view tdName td
|
||||
s = Scheme [] (view tdType td)
|
||||
return (n:rn, s:rs, rd)
|
||||
getInitialSchemes ((DeclVal td) : rest) =
|
||||
do (rn, rs, rd) <- getInitialSchemes rest
|
||||
return (rn, rs, (td : rd))
|
||||
|
||||
typeInferDecl :: ValueDeclaration -> Infer ()
|
||||
typeInferDecl vd =
|
||||
do (subs, t) <- typeInferExpr (view vdValue vd)
|
||||
let t' = apply subs t
|
||||
case view vdDeclaredType vd of
|
||||
Nothing ->
|
||||
warn (TopLevelWithoutType (view vdName vd) t')
|
||||
Just dt ->
|
||||
warn (DeclarationMismatch (view vdName vd) dt t)
|
||||
|
||||
typeInferConst :: Location -> ConstantValue ->
|
||||
Infer (Substitution, Type)
|
||||
typeInferConst l (ConstantInt _ _) =
|
||||
return (nullSubstitution, mkPrimType l "i64")
|
||||
typeInferConst l (ConstantChar _) =
|
||||
return (nullSubstitution, mkPrimType l "i8") -- FIXME
|
||||
typeInferConst l (ConstantString _) =
|
||||
return (nullSubstitution, mkPrimType l "i8*") -- FIXME
|
||||
typeInferConst l (ConstantFloat _) =
|
||||
return (nullSubstitution, mkPrimType l "double")
|
||||
|
||||
typeInferExpr :: Expression -> Infer (Substitution, Type)
|
||||
typeInferExpr expr =
|
||||
case expr of
|
||||
ConstExp e ->
|
||||
typeInferConst (view constLocation e) (view constValue e)
|
||||
RefExp e ->
|
||||
do mscheme <- getNamesTypeScheme (view refName e)
|
||||
case mscheme of
|
||||
Nothing -> err (UnboundVariable (view refName e))
|
||||
Just scheme -> do t <- instantiate scheme
|
||||
return (nullSubstitution, t)
|
||||
LambdaExp e ->
|
||||
do let argNames = view lambdaArgumentNames e
|
||||
tvars <- mapM newTypeVar argNames
|
||||
let tvars' = map (Scheme []) tvars
|
||||
addToTypeEnvironment argNames tvars'
|
||||
(s1, t1) <- typeInferExpr (view lambdaBody e)
|
||||
return (s1, mkFunType' (apply s1 tvars) t1)
|
||||
where
|
||||
mkFunType' [] t = t
|
||||
mkFunType' (x:rest) t = mkFunType fakeLocation x (mkFunType' rest t)
|
||||
|
||||
|
||||
runTypeInference :: NameDatabase -> Module -> Compiler ps Module
|
||||
runTypeInference ndb mod = runInfer ndb (inferModule mod)
|
||||
-}
|
||||
-- data Scheme = Scheme [Name] Type
|
||||
--
|
||||
-- getName :: NameEnvironment -> Text -> Infer Name
|
||||
-- getName env nameText =
|
||||
-- do namedb <- viewPassState nameDatabase
|
||||
-- let key = (env, nameText)
|
||||
-- case Map.lookup key namedb of
|
||||
-- Nothing ->
|
||||
-- do name <- registerNewName env nameText
|
||||
-- overPassState (set nameDatabase (Map.insert key name namedb))
|
||||
-- return name
|
||||
-- Just name ->
|
||||
-- return name
|
||||
--
|
||||
-- runTypeInference :: NameDatabase -> Module -> Compiler ps Module
|
||||
-- runTypeInference nameDB mod =
|
||||
-- snd `fmap` (runPass initialState (inferModule mod))
|
||||
-- where initialState = InferenceState nameDB
|
||||
--
|
||||
-- type Substitutions = Map Name Type
|
||||
--
|
||||
-- nullSubst :: Substitutions
|
||||
-- nullSubst = Map.empty
|
||||
--
|
||||
-- type TypeEnv = Map Name Scheme
|
||||
--
|
||||
-- class Substitutable a where
|
||||
-- apply :: Substitutions -> a -> Type
|
||||
--
|
||||
-- instance Substitutable Type where
|
||||
-- apply subs t =
|
||||
-- case t of
|
||||
-- TypeUnit _ _ -> t
|
||||
-- TypePrim _ _ _ -> t
|
||||
-- TypeRef _ _ n -> case Map.lookup n subs of
|
||||
-- Nothing -> t
|
||||
-- Just t' -> t'
|
||||
-- TypeLambda l k ats bt ->
|
||||
-- TypeLambda l k (map (apply subs) ats) (apply subs bt)
|
||||
-- TypeApp l k a b ->
|
||||
-- TypeApp l k (apply subs a) (apply subs b)
|
||||
-- TypeForAll ns t ->
|
||||
-- TypeForAll ns (apply subs t)
|
||||
--
|
||||
-- instance Substitutable Name where
|
||||
-- apply subs n =
|
||||
-- case Map.lookup n subs of
|
||||
-- Nothing -> TypeRef unknownLocation Star n
|
||||
-- Just t -> t
|
||||
--
|
||||
-- instantiate :: Scheme -> Infer Type
|
||||
-- instantiate =
|
||||
-- do
|
||||
--
|
||||
-- inferExpression :: TypeEnv -> Expression ->
|
||||
-- Infer (Substitutions, Type)
|
||||
-- inferExpression typeEnv expr =
|
||||
-- case expr of
|
||||
-- ConstantExp s cv -> do memName <- getName TypeEnv "Memory"
|
||||
-- return (nullSubst, TypeRef s Star memName)
|
||||
-- ReferenceExp s n -> case Map.lookup n typeEnv of
|
||||
-- Nothing -> err (UnboundVariable s n)
|
||||
-- Just t -> do t' <- instantiate t
|
||||
-- return (nullSubst, t')
|
||||
-- LambdaExp s ns e -> do localTypeNames <- mapM (const (genName TypeEnv)) ns
|
||||
-- let localSchemes = map (Scheme [] . TypeRef s Star) localTypeNames
|
||||
-- localEnv = Map.fromList (zip ns localSchemes)
|
||||
-- typeEnv' = typeEnv `Map.union` localEnv
|
||||
-- (s1, t1) <- inferExpression typeEnv' e
|
||||
-- return (s1, TypeLambda s (Star `KindArrow` Star)
|
||||
-- (map (apply s1) localTypeNames)
|
||||
-- t1)
|
||||
--
|
||||
-- inferModule :: Module -> Infer Module
|
||||
-- inferModule = undefined
|
||||
|
||||
@@ -3,11 +3,18 @@ module Bang.Utils.FreeVars(
|
||||
)
|
||||
where
|
||||
|
||||
import Bang.AST.Name(Name)
|
||||
import Bang.AST.Name(Name)
|
||||
import Data.Set(Set)
|
||||
import qualified Data.Set as Set
|
||||
|
||||
class CanHaveFreeVars a where
|
||||
freeVariables :: a -> [Name]
|
||||
freeVariables :: a -> Set Name
|
||||
|
||||
instance CanHaveFreeVars a => CanHaveFreeVars (Maybe a) where
|
||||
freeVariables (Just x) = freeVariables x
|
||||
freeVariables Nothing = []
|
||||
freeVariables Nothing = Set.empty
|
||||
|
||||
instance CanHaveFreeVars a => CanHaveFreeVars [a] where
|
||||
freeVariables [] = Set.empty
|
||||
freeVariables (x:xs) = freeVariables x `Set.union` freeVariables xs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user