Checkpoint. We're inferring stuff, sorta.
This commit is contained in:
@@ -24,7 +24,7 @@ import Bang.Syntax.Location(Location)
|
|||||||
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
||||||
import Control.Lens(Lens', view, set, lens)
|
import Control.Lens(Lens', view, set, lens)
|
||||||
import Control.Lens(makeLenses)
|
import Control.Lens(makeLenses)
|
||||||
import Data.List(delete, union)
|
import Data.Set(delete, union)
|
||||||
import Text.PrettyPrint.Annotated(Doc, text, (<+>), ($+$), (<>), empty, space)
|
import Text.PrettyPrint.Annotated(Doc, text, (<+>), ($+$), (<>), empty, space)
|
||||||
|
|
||||||
data TypeDeclaration = TypeDeclaration
|
data TypeDeclaration = TypeDeclaration
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
|||||||
import Bang.Utils.Pretty(text')
|
import Bang.Utils.Pretty(text')
|
||||||
import Control.Lens(view)
|
import Control.Lens(view)
|
||||||
import Control.Lens.TH(makeLenses)
|
import Control.Lens.TH(makeLenses)
|
||||||
|
import Data.Set(empty, singleton, fromList, (\\))
|
||||||
import Data.Text.Lazy(Text)
|
import Data.Text.Lazy(Text)
|
||||||
import Text.PrettyPrint.Annotated(Doc, text, hsep, (<>), (<+>))
|
import Text.PrettyPrint.Annotated(Doc, text, hsep, (<>), (<+>))
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ instance MkConstExp Expression where
|
|||||||
mkConstExp l v = ConstExp (mkConstExp l v)
|
mkConstExp l v = ConstExp (mkConstExp l v)
|
||||||
|
|
||||||
instance CanHaveFreeVars ConstantExpression where
|
instance CanHaveFreeVars ConstantExpression where
|
||||||
freeVariables _ = []
|
freeVariables _ = empty
|
||||||
|
|
||||||
ppConstantExpression :: ConstantExpression -> Doc a
|
ppConstantExpression :: ConstantExpression -> Doc a
|
||||||
ppConstantExpression = ppConstantValue . _constValue
|
ppConstantExpression = ppConstantValue . _constValue
|
||||||
@@ -100,7 +101,7 @@ instance MkRefExp Expression where
|
|||||||
mkRefExp l n = RefExp (ReferenceExpression l n)
|
mkRefExp l n = RefExp (ReferenceExpression l n)
|
||||||
|
|
||||||
instance CanHaveFreeVars ReferenceExpression where
|
instance CanHaveFreeVars ReferenceExpression where
|
||||||
freeVariables r = [_refName r]
|
freeVariables r = singleton (_refName r)
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
@@ -126,8 +127,8 @@ instance MkLambdaExp Expression where
|
|||||||
mkLambdaExp l a b = LambdaExp (LambdaExpression l a b)
|
mkLambdaExp l a b = LambdaExp (LambdaExpression l a b)
|
||||||
|
|
||||||
instance CanHaveFreeVars LambdaExpression where
|
instance CanHaveFreeVars LambdaExpression where
|
||||||
freeVariables l = filter (\ x -> not (x `elem` (_lambdaArgumentNames l)))
|
freeVariables l = freeVariables (_lambdaBody l) \\
|
||||||
(freeVariables (_lambdaBody l))
|
fromList (_lambdaArgumentNames l)
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ module Bang.AST.Type
|
|||||||
, FunctionType
|
, FunctionType
|
||||||
, ppFunctionType
|
, ppFunctionType
|
||||||
, mkFunType
|
, mkFunType
|
||||||
, ftLocation, ftKind, ftArgumentTypes, ftResultType
|
, ftLocation, ftKind, ftArgumentType, ftResultType
|
||||||
-- * type application
|
-- * type application
|
||||||
, TypeApplication
|
, TypeApplication
|
||||||
, ppTypeApplication
|
, ppTypeApplication
|
||||||
@@ -36,9 +36,9 @@ import Bang.Syntax.Location(Location)
|
|||||||
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
||||||
import Bang.Utils.Pretty(text')
|
import Bang.Utils.Pretty(text')
|
||||||
import Control.Lens.TH(makeLenses)
|
import Control.Lens.TH(makeLenses)
|
||||||
import Data.List(foldl', union)
|
import Data.Set(union, empty, singleton)
|
||||||
import Data.Text.Lazy(Text)
|
import Data.Text.Lazy(Text)
|
||||||
import Text.PrettyPrint.Annotated(Doc, (<+>), (<>), text, hsep)
|
import Text.PrettyPrint.Annotated(Doc, (<+>), (<>), text)
|
||||||
|
|
||||||
data Kind = Star
|
data Kind = Star
|
||||||
| Unknown
|
| Unknown
|
||||||
@@ -62,7 +62,7 @@ instance Kinded UnitType where
|
|||||||
kind _ = Star
|
kind _ = Star
|
||||||
|
|
||||||
instance CanHaveFreeVars UnitType where
|
instance CanHaveFreeVars UnitType where
|
||||||
freeVariables _ = []
|
freeVariables _ = empty
|
||||||
|
|
||||||
ppUnitType :: UnitType -> Doc a
|
ppUnitType :: UnitType -> Doc a
|
||||||
ppUnitType _ = text "()"
|
ppUnitType _ = text "()"
|
||||||
@@ -88,7 +88,7 @@ instance MkPrimType Type where
|
|||||||
mkPrimType l t = TypePrim (PrimitiveType l t)
|
mkPrimType l t = TypePrim (PrimitiveType l t)
|
||||||
|
|
||||||
instance CanHaveFreeVars PrimitiveType where
|
instance CanHaveFreeVars PrimitiveType where
|
||||||
freeVariables _ = []
|
freeVariables _ = empty
|
||||||
|
|
||||||
ppPrimitiveType :: PrimitiveType -> Doc a
|
ppPrimitiveType :: PrimitiveType -> Doc a
|
||||||
ppPrimitiveType pt = text "llvm:" <> text' (_ptName pt)
|
ppPrimitiveType pt = text "llvm:" <> text' (_ptName pt)
|
||||||
@@ -118,20 +118,20 @@ instance MkTypeRef Type where
|
|||||||
mkTypeRef l k n = TypeRef (ReferenceType l k n)
|
mkTypeRef l k n = TypeRef (ReferenceType l k n)
|
||||||
|
|
||||||
instance CanHaveFreeVars ReferenceType where
|
instance CanHaveFreeVars ReferenceType where
|
||||||
freeVariables r = [_rtName r]
|
freeVariables r = singleton (_rtName r)
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
data FunctionType = FunctionType
|
data FunctionType = FunctionType
|
||||||
{ _ftLocation :: Location
|
{ _ftLocation :: Location
|
||||||
, _ftKind :: Kind
|
, _ftKind :: Kind
|
||||||
, _ftArgumentTypes :: [Type]
|
, _ftArgumentType :: Type
|
||||||
, _ftResultType :: Type
|
, _ftResultType :: Type
|
||||||
}
|
}
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
class MkFunType a where
|
class MkFunType a where
|
||||||
mkFunType :: Location -> [Type] -> Type -> a
|
mkFunType :: Location -> Type -> Type -> a
|
||||||
|
|
||||||
instance MkFunType FunctionType where
|
instance MkFunType FunctionType where
|
||||||
mkFunType l a r = FunctionType l Star a r
|
mkFunType l a r = FunctionType l Star a r
|
||||||
@@ -141,16 +141,14 @@ instance MkFunType Type where
|
|||||||
|
|
||||||
ppFunctionType :: FunctionType -> Doc a
|
ppFunctionType :: FunctionType -> Doc a
|
||||||
ppFunctionType ft =
|
ppFunctionType ft =
|
||||||
hsep (map ppType (_ftArgumentTypes ft)) <+> text "->" <+>
|
ppType (_ftArgumentType ft) <+> text "->" <+> ppType (_ftResultType ft)
|
||||||
ppType (_ftResultType ft)
|
|
||||||
|
|
||||||
instance Kinded FunctionType where
|
instance Kinded FunctionType where
|
||||||
kind = _ftKind
|
kind = _ftKind
|
||||||
|
|
||||||
instance CanHaveFreeVars FunctionType where
|
instance CanHaveFreeVars FunctionType where
|
||||||
freeVariables ft = foldl' (\ acc x -> acc `union` freeVariables x)
|
freeVariables ft = freeVariables (_ftArgumentType ft) `union`
|
||||||
(freeVariables (_ftResultType ft))
|
freeVariables (_ftResultType ft)
|
||||||
(_ftArgumentTypes ft)
|
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,10 @@ err' :: BangError e => e -> Compiler s ()
|
|||||||
err' e = Compiler (\ st -> runError e False >> return (st, ()))
|
err' e = Compiler (\ st -> runError e False >> return (st, ()))
|
||||||
|
|
||||||
runWarning :: BangWarning w => w -> IO ()
|
runWarning :: BangWarning w => w -> IO ()
|
||||||
runWarning = undefined
|
runWarning w = putStrLn (go (ppWarning w))
|
||||||
|
where
|
||||||
|
go (Nothing, doc) = render doc
|
||||||
|
go (Just a, doc) = render (ppLocation a $+$ nest 3 doc)
|
||||||
|
|
||||||
runError :: BangError w => w -> Bool -> IO ()
|
runError :: BangError w => w -> Bool -> IO ()
|
||||||
runError e die =
|
runError e die =
|
||||||
|
|||||||
@@ -174,8 +174,8 @@ Type :: { Type }
|
|||||||
: RawType { $1 }
|
: RawType { $1 }
|
||||||
|
|
||||||
RawType :: { Type }
|
RawType :: { Type }
|
||||||
: RawType '->' BaseType { mkFunType $2 [$1] $3 }
|
: RawType '->' BaseType { mkFunType $2 $1 $3 }
|
||||||
| BaseType { $1 }
|
| BaseType { $1 }
|
||||||
|
|
||||||
BaseType :: { Type }
|
BaseType :: { Type }
|
||||||
: TypeIdent {%
|
: TypeIdent {%
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import Bang.AST.Declaration(Declaration(..), declName,
|
|||||||
import Bang.AST.Expression(Expression(..), isEmptyExpression, refName,
|
import Bang.AST.Expression(Expression(..), isEmptyExpression, refName,
|
||||||
lambdaArgumentNames, lambdaBody,
|
lambdaArgumentNames, lambdaBody,
|
||||||
isEmptyExpression)
|
isEmptyExpression)
|
||||||
import Bang.AST.Type(Type(..), rtName, ftArgumentTypes, ftResultType,
|
import Bang.AST.Type(Type(..), rtName, ftArgumentType, ftResultType,
|
||||||
taLeftType, taRightType)
|
taLeftType, taRightType)
|
||||||
import Bang.Monad(Compiler, BangError(..), err, err', registerName)
|
import Bang.Monad(Compiler, BangError(..), err, err', registerName)
|
||||||
import Bang.Syntax.Location(Location, ppLocation)
|
import Bang.Syntax.Location(Location, ppLocation)
|
||||||
@@ -26,6 +26,7 @@ import Data.Graph(SCC(..))
|
|||||||
import Data.Graph.SCC(stronglyConnComp)
|
import Data.Graph.SCC(stronglyConnComp)
|
||||||
import Data.Map.Strict(Map)
|
import Data.Map.Strict(Map)
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
|
import Data.Set(toList)
|
||||||
import Data.Text.Lazy(uncons)
|
import Data.Text.Lazy(uncons)
|
||||||
import Text.PrettyPrint.Annotated(text, ($+$), (<+>), nest, quotes)
|
import Text.PrettyPrint.Annotated(text, ($+$), (<+>), nest, quotes)
|
||||||
|
|
||||||
@@ -115,20 +116,16 @@ linkNames decls =
|
|||||||
let t' = set rtName name t
|
let t' = set rtName name t
|
||||||
return (TypeRef t', nameMap')
|
return (TypeRef t', nameMap')
|
||||||
linkType nameMap (TypeFun t) =
|
linkType nameMap (TypeFun t) =
|
||||||
do (argTypes, nameMap') <- foldM linkTypes ([], nameMap) (view ftArgumentTypes t)
|
do (argType, nameMap') <- linkType nameMap (view ftArgumentType t)
|
||||||
(resType, nameMap'') <- linkType nameMap' (view ftResultType t)
|
(resType, nameMap'') <- linkType nameMap' (view ftResultType t)
|
||||||
return (TypeFun (set ftArgumentTypes argTypes $
|
return (TypeFun (set ftArgumentType argType $
|
||||||
set ftResultType resType t),
|
set ftResultType resType t),
|
||||||
nameMap'')
|
nameMap'')
|
||||||
linkType nameMap (TypeApp t) =
|
linkType nameMap (TypeApp t) =
|
||||||
do (lt, nameMap') <- linkType nameMap (view taLeftType t)
|
do (lt, nameMap') <- linkType nameMap (view taLeftType t)
|
||||||
(rt, nameMap'') <- linkType nameMap' (view taRightType t)
|
(rt, nameMap'') <- linkType nameMap' (view taRightType t)
|
||||||
return (TypeApp (set taLeftType lt (set taRightType rt t)), nameMap'')
|
return (TypeApp (set taLeftType lt (set taRightType rt t)), nameMap'')
|
||||||
--
|
--
|
||||||
linkTypes (acc, nameMap) argType =
|
|
||||||
do (argType', nameMap') <- linkType nameMap argType
|
|
||||||
return (acc ++ [argType'], nameMap')
|
|
||||||
--
|
|
||||||
linkExpr _ x | isEmptyExpression x = return x
|
linkExpr _ x | isEmptyExpression x = return x
|
||||||
linkExpr _ x@(ConstExp _) = return x
|
linkExpr _ x@(ConstExp _) = return x
|
||||||
linkExpr nameMap (RefExp e) =
|
linkExpr nameMap (RefExp e) =
|
||||||
@@ -217,4 +214,4 @@ orderDecls decls = map unSCC (stronglyConnComp nodes)
|
|||||||
unSCC (CyclicSCC xs) = xs
|
unSCC (CyclicSCC xs) = xs
|
||||||
--
|
--
|
||||||
nodes = map tuplify decls
|
nodes = map tuplify decls
|
||||||
tuplify d = (d, view declName d, freeVariables d)
|
tuplify d = (d, view declName d, toList (freeVariables d))
|
||||||
|
|||||||
@@ -1,174 +1,269 @@
|
|||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE TemplateHaskell #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE TemplateHaskell #-}
|
||||||
|
{-# LANGUAGE TypeSynonymInstances #-}
|
||||||
module Bang.TypeInfer(runTypeInference)
|
module Bang.TypeInfer(runTypeInference)
|
||||||
where
|
where
|
||||||
|
|
||||||
import Bang.AST(Module)
|
import Bang.AST(Module, moduleDeclarations)
|
||||||
import Bang.Monad(Compiler)
|
import Bang.AST.Declaration(Declaration(..), ValueDeclaration,
|
||||||
|
vdName, vdDeclaredType, vdValue,
|
||||||
runTypeInference :: Module -> Compiler ps Module
|
tdName, tdType)
|
||||||
runTypeInference x = return x
|
import Bang.AST.Expression(Expression(..), ConstantValue(..),
|
||||||
|
lambdaArgumentNames, lambdaBody,
|
||||||
{- Better version
|
constLocation, constValue, refName)
|
||||||
import Bang.Monad(Compiler, BangError(..), err,
|
import Bang.AST.Name(Name, NameEnvironment(..),
|
||||||
runPass, getPassState, setPassState,
|
nameLocation, nameText, ppName)
|
||||||
viewPassState, overPassState,
|
import Bang.AST.Type(Type(..), ppType, rtName, ftArgumentType,
|
||||||
registerNewName, genName)
|
ftResultType, taLeftType, taRightType,
|
||||||
import Bang.Syntax.Location(Location, unknownLocation)
|
mkPrimType, mkFunType, mkTypeRef,
|
||||||
import Bang.Syntax.ParserMonad(NameDatabase(..))
|
Kind(..))
|
||||||
import Bang.Utils.Pretty(BangDoc)
|
import Bang.Monad(Compiler, BangError(..), BangWarning(..),
|
||||||
import Control.Lens(set, view, over)
|
registerNewName, err', err, warn,
|
||||||
import Control.Lens.TH(makeLenses)
|
getPassState, mapPassState, runPass)
|
||||||
import Data.List(union, nub, concat)
|
import Bang.Syntax.Location(Location, fakeLocation)
|
||||||
|
import Bang.Utils.FreeVars(CanHaveFreeVars(..))
|
||||||
|
import Bang.Utils.Pretty(BangDoc, text')
|
||||||
|
import Control.Lens(view, over)
|
||||||
import Data.Map.Strict(Map)
|
import Data.Map.Strict(Map)
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import Data.Set(Set)
|
import Data.Set(Set, (\\))
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
import Data.Text.Lazy(Text, pack)
|
import Text.PrettyPrint.Annotated(text, nest, quotes, ($+$), (<+>))
|
||||||
import Text.PrettyPrint.Annotated(text, (<+>), quotes)
|
|
||||||
|
|
||||||
data InferenceState = InferenceState {
|
runTypeInference :: Module -> Compiler ps Module
|
||||||
_nameDatabase :: NameDatabase
|
runTypeInference x =
|
||||||
}
|
do _ <- runPass emptyEnvironment (mapM_ typeInferDecls (view moduleDeclarations x))
|
||||||
|
return x
|
||||||
makeLenses ''InferenceState
|
|
||||||
|
|
||||||
type Infer a = Compiler InferenceState a
|
|
||||||
|
|
||||||
runInfer :: NameDatabase -> Infer a -> Compiler ps a
|
|
||||||
runInfer ndb action = snd `fmap` runPass initial action
|
|
||||||
where initial = InferenceState ndb
|
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
data InferError = UnboundVariable Location Name
|
type Infer a = Compiler TypeEnvironment a
|
||||||
deriving (Show)
|
|
||||||
|
|
||||||
instance BangError InferError where
|
getNamesTypeScheme :: Name -> Infer (Maybe Scheme)
|
||||||
ppError = prettyError
|
getNamesTypeScheme n = Map.lookup n `fmap` getPassState
|
||||||
|
|
||||||
prettyError :: InferError -> (Maybe Location, BangDoc)
|
addToTypeEnvironment :: [Name] -> [Scheme] -> Infer ()
|
||||||
prettyError e =
|
addToTypeEnvironment ns schms = mapPassState (add ns schms)
|
||||||
case e of
|
where
|
||||||
UnboundVariable l n ->
|
add :: [Name] -> [Scheme] -> TypeEnvironment -> TypeEnvironment
|
||||||
(Just l, text "Unbound variable '" <+> quotes (ppName n))
|
add [] [] acc = acc
|
||||||
|
add (n:restns) (s:rschms) acc =
|
||||||
|
Map.insertWithKey errorFn n s (add restns rschms acc)
|
||||||
|
add _ _ _ =
|
||||||
|
error "Wackiness has insued."
|
||||||
|
--
|
||||||
|
errorFn k _ _ = error ("Redefinition of " ++ show k)
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
type Substitutions = Map Name Type
|
type Substitution = Map Name Type
|
||||||
|
|
||||||
noSubstitutions :: Substitutions
|
nullSubstitution :: Substitution
|
||||||
noSubstitutions = Map.empty
|
nullSubstitution = Map.empty
|
||||||
|
|
||||||
composeSubstitutions :: Substitutions -> Substitutions -> Substitutions
|
composeSubstitutions :: Substitution -> Substitution -> Substitution
|
||||||
composeSubstitutions s1 s2 = Map.map (apply s1) s2 `Map.union` s1
|
composeSubstitutions s1 s2 = Map.map (apply s1) s2 `Map.union` s1
|
||||||
|
|
||||||
class Types a where
|
class ApplySubst t where
|
||||||
freeTypeVariables :: a -> Set Name
|
apply :: Substitution -> t -> t
|
||||||
apply :: Substitutions -> a -> a
|
|
||||||
|
|
||||||
instance Types Type where
|
instance ApplySubst Type where
|
||||||
freeTypeVariables t =
|
apply s (TypeUnit t) = TypeUnit t
|
||||||
case t of
|
apply s (TypePrim t) = TypePrim t
|
||||||
TypeUnit _ _ -> Set.empty
|
apply s (TypeRef t) = case Map.lookup (view rtName t) s of
|
||||||
TypePrim _ _ _ -> Set.empty
|
Nothing -> TypeRef t
|
||||||
TypeRef _ _ n -> Set.singleton n
|
Just t' -> t'
|
||||||
TypeLambda _ _ ns e -> Set.unions (map freeTypeVariables ns) `Set.union`
|
apply s (TypeFun t) = TypeFun (over ftArgumentType (apply s) $
|
||||||
freeTypeVariables e
|
over ftResultType (apply s) t)
|
||||||
TypeApp _ _ a b -> freeTypeVariables a `Set.union` freeTypeVariables b
|
apply s (TypeApp t) = TypeApp (over taLeftType (apply s) $
|
||||||
apply substs t =
|
over taRightType (apply s) t)
|
||||||
case t of
|
|
||||||
TypeRef _ _ n -> case Map.lookup n substs of
|
|
||||||
Nothing -> t
|
|
||||||
Just t' -> t'
|
|
||||||
TypeLambda l k ns e -> TypeLambda l k (apply substs ns) (apply substs e)
|
|
||||||
TypeApp l k a b -> TypeApp l k (apply substs a) (apply substs b)
|
|
||||||
_ -> t
|
|
||||||
|
|
||||||
instance Types a => Types [a] where
|
instance ApplySubst a => ApplySubst [a] where
|
||||||
freeTypeVariables l = Set.unions (map freeTypeVariables l)
|
apply s = map (apply s)
|
||||||
apply s = map (apply s)
|
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
inferModule :: Module -> Infer Module
|
data Scheme = Scheme [Name] Type
|
||||||
inferModule = undefined
|
|
||||||
|
instance CanHaveFreeVars Scheme where
|
||||||
|
freeVariables (Scheme ns t) = freeVariables t \\ Set.fromList ns
|
||||||
|
|
||||||
|
instance ApplySubst Scheme where
|
||||||
|
apply s (Scheme vars t) = Scheme vars (apply s t)
|
||||||
|
|
||||||
|
newTypeVar :: Name -> Infer Type
|
||||||
|
newTypeVar n =
|
||||||
|
do let loc = view nameLocation n
|
||||||
|
n' <- registerNewName TypeEnv (view nameText n)
|
||||||
|
return (mkTypeRef loc Unknown n')
|
||||||
|
|
||||||
|
instantiate :: Scheme -> Infer Type
|
||||||
|
instantiate (Scheme vars t) =
|
||||||
|
do refs <- mapM newTypeVar vars
|
||||||
|
let newSubsts = Map.fromList (zip vars refs)
|
||||||
|
return (apply newSubsts t)
|
||||||
|
|
||||||
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
mostGeneralUnifier :: Type -> Type -> Infer Substitution
|
||||||
|
mostGeneralUnifier a b =
|
||||||
|
case (a, b) of
|
||||||
|
(TypeUnit _, TypeUnit _) -> return nullSubstitution
|
||||||
|
(TypePrim _, TypePrim _) -> return nullSubstitution
|
||||||
|
(TypeRef t1, t2) -> varBind (view rtName t1) t2
|
||||||
|
(t2, TypeRef t1) -> varBind (view rtName t1) t2
|
||||||
|
(TypeFun t1, TypeFun t2) -> do let at1 = view ftArgumentType t1
|
||||||
|
at2 = view ftArgumentType t2
|
||||||
|
s1 <- mostGeneralUnifier at1 at2
|
||||||
|
let rt1 = apply s1 (view ftResultType t1)
|
||||||
|
rt2 = apply s1 (view ftResultType t2)
|
||||||
|
s2 <- mostGeneralUnifier rt1 rt2
|
||||||
|
return (s1 `composeSubstitutions` s2)
|
||||||
|
(TypeApp t1, TypeApp t2) -> do let lt1 = view taLeftType t1
|
||||||
|
lt2 = view taLeftType t2
|
||||||
|
s1 <- mostGeneralUnifier lt1 lt2
|
||||||
|
let rt1 = apply s1 (view taRightType t1)
|
||||||
|
rt2 = apply s1 (view taRightType t2)
|
||||||
|
s2 <- mostGeneralUnifier rt1 rt2
|
||||||
|
return (s1 `composeSubstitutions` s2)
|
||||||
|
_ -> do err' (TypesDontUnify a b)
|
||||||
|
return nullSubstitution
|
||||||
|
|
||||||
|
varBind :: Name -> Type -> Infer Substitution
|
||||||
|
varBind u t | TypeRef t' <- t,
|
||||||
|
view rtName t' == u = return nullSubstitution
|
||||||
|
| u `Set.member` freeVariables t = do err' (OccursFail u t)
|
||||||
|
return nullSubstitution
|
||||||
|
| otherwise = return (Map.singleton u t)
|
||||||
|
|
||||||
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type TypeEnvironment = Map Name Scheme
|
||||||
|
|
||||||
|
emptyEnvironment :: TypeEnvironment
|
||||||
|
emptyEnvironment = Map.empty
|
||||||
|
|
||||||
|
instance ApplySubst TypeEnvironment where
|
||||||
|
apply s tenv = Map.map (apply s) tenv
|
||||||
|
|
||||||
|
instance CanHaveFreeVars TypeEnvironment where
|
||||||
|
freeVariables tenv = freeVariables (Map.elems tenv)
|
||||||
|
|
||||||
|
generalize :: TypeEnvironment -> Type -> Scheme
|
||||||
|
generalize env t = Scheme vars t
|
||||||
|
where vars = Set.toList (freeVariables t \\ freeVariables env)
|
||||||
|
|
||||||
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
data InferenceError = InternalError
|
||||||
|
| TypesDontUnify Type Type
|
||||||
|
| OccursFail Name Type
|
||||||
|
| UnboundVariable Name
|
||||||
|
|
||||||
|
instance BangError InferenceError where
|
||||||
|
ppError = prettyError
|
||||||
|
|
||||||
|
prettyError :: InferenceError -> (Maybe Location, BangDoc)
|
||||||
|
prettyError e =
|
||||||
|
case e of
|
||||||
|
InternalError ->
|
||||||
|
(Nothing, text "<internal error>")
|
||||||
|
TypesDontUnify t1 t2 ->
|
||||||
|
(Nothing, text "Types don't unify:" $+$
|
||||||
|
(nest 3
|
||||||
|
(text "first type: " <+> ppType t1 $+$
|
||||||
|
text "second type: " <+> ppType t2)))
|
||||||
|
OccursFail n t ->
|
||||||
|
(Just (view nameLocation n),
|
||||||
|
text "Occurs check failed:" $+$
|
||||||
|
(nest 3 (text "Type: " <+> ppType t)))
|
||||||
|
UnboundVariable n ->
|
||||||
|
(Just (view nameLocation n),
|
||||||
|
text "Unbound variable (in type checker?):" <+> ppName n)
|
||||||
|
|
||||||
|
data InferenceWarning = TopLevelWithoutType Name Type
|
||||||
|
| DeclarationMismatch Name Type Type
|
||||||
|
|
||||||
|
instance BangWarning InferenceWarning where
|
||||||
|
ppWarning = prettyWarning
|
||||||
|
|
||||||
|
prettyWarning :: InferenceWarning -> (Maybe Location, BangDoc)
|
||||||
|
prettyWarning w =
|
||||||
|
case w of
|
||||||
|
TopLevelWithoutType n t ->
|
||||||
|
(Just (view nameLocation n),
|
||||||
|
text "Variable" <+> quotes (text' (view nameText n)) <+>
|
||||||
|
text "is defined without a type." $+$
|
||||||
|
text "Inferred type:" $+$ nest 3 (ppType t))
|
||||||
|
DeclarationMismatch n dt it ->
|
||||||
|
(Just (view nameLocation n),
|
||||||
|
text "Mismatch between declared and inferred type of" <+>
|
||||||
|
quotes (text' (view nameText n)) $+$
|
||||||
|
nest 3 (text "declared type:" <+> ppType dt $+$
|
||||||
|
text "inferred type:" <+> ppType it))
|
||||||
|
|
||||||
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
-- Infer the type of a group of declarations with cyclic dependencies.
|
||||||
|
typeInferDecls :: [Declaration] -> Infer ()
|
||||||
|
typeInferDecls decls =
|
||||||
|
do (names, schemes, decls') <- getInitialSchemes decls
|
||||||
|
addToTypeEnvironment names schemes
|
||||||
|
mapM_ typeInferDecl decls'
|
||||||
|
where
|
||||||
|
getInitialSchemes [] =
|
||||||
|
return ([], [], [])
|
||||||
|
getInitialSchemes ((DeclType td) : rest) =
|
||||||
|
do (rn, rs, rd) <- getInitialSchemes rest
|
||||||
|
let n = view tdName td
|
||||||
|
s = Scheme [] (view tdType td)
|
||||||
|
return (n:rn, s:rs, rd)
|
||||||
|
getInitialSchemes ((DeclVal td) : rest) =
|
||||||
|
do (rn, rs, rd) <- getInitialSchemes rest
|
||||||
|
return (rn, rs, (td : rd))
|
||||||
|
|
||||||
|
typeInferDecl :: ValueDeclaration -> Infer ()
|
||||||
|
typeInferDecl vd =
|
||||||
|
do (subs, t) <- typeInferExpr (view vdValue vd)
|
||||||
|
let t' = apply subs t
|
||||||
|
case view vdDeclaredType vd of
|
||||||
|
Nothing ->
|
||||||
|
warn (TopLevelWithoutType (view vdName vd) t')
|
||||||
|
Just dt ->
|
||||||
|
warn (DeclarationMismatch (view vdName vd) dt t)
|
||||||
|
|
||||||
|
typeInferConst :: Location -> ConstantValue ->
|
||||||
|
Infer (Substitution, Type)
|
||||||
|
typeInferConst l (ConstantInt _ _) =
|
||||||
|
return (nullSubstitution, mkPrimType l "i64")
|
||||||
|
typeInferConst l (ConstantChar _) =
|
||||||
|
return (nullSubstitution, mkPrimType l "i8") -- FIXME
|
||||||
|
typeInferConst l (ConstantString _) =
|
||||||
|
return (nullSubstitution, mkPrimType l "i8*") -- FIXME
|
||||||
|
typeInferConst l (ConstantFloat _) =
|
||||||
|
return (nullSubstitution, mkPrimType l "double")
|
||||||
|
|
||||||
|
typeInferExpr :: Expression -> Infer (Substitution, Type)
|
||||||
|
typeInferExpr expr =
|
||||||
|
case expr of
|
||||||
|
ConstExp e ->
|
||||||
|
typeInferConst (view constLocation e) (view constValue e)
|
||||||
|
RefExp e ->
|
||||||
|
do mscheme <- getNamesTypeScheme (view refName e)
|
||||||
|
case mscheme of
|
||||||
|
Nothing -> err (UnboundVariable (view refName e))
|
||||||
|
Just scheme -> do t <- instantiate scheme
|
||||||
|
return (nullSubstitution, t)
|
||||||
|
LambdaExp e ->
|
||||||
|
do let argNames = view lambdaArgumentNames e
|
||||||
|
tvars <- mapM newTypeVar argNames
|
||||||
|
let tvars' = map (Scheme []) tvars
|
||||||
|
addToTypeEnvironment argNames tvars'
|
||||||
|
(s1, t1) <- typeInferExpr (view lambdaBody e)
|
||||||
|
return (s1, mkFunType' (apply s1 tvars) t1)
|
||||||
|
where
|
||||||
|
mkFunType' [] t = t
|
||||||
|
mkFunType' (x:rest) t = mkFunType fakeLocation x (mkFunType' rest t)
|
||||||
|
|
||||||
|
|
||||||
runTypeInference :: NameDatabase -> Module -> Compiler ps Module
|
|
||||||
runTypeInference ndb mod = runInfer ndb (inferModule mod)
|
|
||||||
-}
|
|
||||||
-- data Scheme = Scheme [Name] Type
|
|
||||||
--
|
|
||||||
-- getName :: NameEnvironment -> Text -> Infer Name
|
|
||||||
-- getName env nameText =
|
|
||||||
-- do namedb <- viewPassState nameDatabase
|
|
||||||
-- let key = (env, nameText)
|
|
||||||
-- case Map.lookup key namedb of
|
|
||||||
-- Nothing ->
|
|
||||||
-- do name <- registerNewName env nameText
|
|
||||||
-- overPassState (set nameDatabase (Map.insert key name namedb))
|
|
||||||
-- return name
|
|
||||||
-- Just name ->
|
|
||||||
-- return name
|
|
||||||
--
|
|
||||||
-- runTypeInference :: NameDatabase -> Module -> Compiler ps Module
|
|
||||||
-- runTypeInference nameDB mod =
|
|
||||||
-- snd `fmap` (runPass initialState (inferModule mod))
|
|
||||||
-- where initialState = InferenceState nameDB
|
|
||||||
--
|
|
||||||
-- type Substitutions = Map Name Type
|
|
||||||
--
|
|
||||||
-- nullSubst :: Substitutions
|
|
||||||
-- nullSubst = Map.empty
|
|
||||||
--
|
|
||||||
-- type TypeEnv = Map Name Scheme
|
|
||||||
--
|
|
||||||
-- class Substitutable a where
|
|
||||||
-- apply :: Substitutions -> a -> Type
|
|
||||||
--
|
|
||||||
-- instance Substitutable Type where
|
|
||||||
-- apply subs t =
|
|
||||||
-- case t of
|
|
||||||
-- TypeUnit _ _ -> t
|
|
||||||
-- TypePrim _ _ _ -> t
|
|
||||||
-- TypeRef _ _ n -> case Map.lookup n subs of
|
|
||||||
-- Nothing -> t
|
|
||||||
-- Just t' -> t'
|
|
||||||
-- TypeLambda l k ats bt ->
|
|
||||||
-- TypeLambda l k (map (apply subs) ats) (apply subs bt)
|
|
||||||
-- TypeApp l k a b ->
|
|
||||||
-- TypeApp l k (apply subs a) (apply subs b)
|
|
||||||
-- TypeForAll ns t ->
|
|
||||||
-- TypeForAll ns (apply subs t)
|
|
||||||
--
|
|
||||||
-- instance Substitutable Name where
|
|
||||||
-- apply subs n =
|
|
||||||
-- case Map.lookup n subs of
|
|
||||||
-- Nothing -> TypeRef unknownLocation Star n
|
|
||||||
-- Just t -> t
|
|
||||||
--
|
|
||||||
-- instantiate :: Scheme -> Infer Type
|
|
||||||
-- instantiate =
|
|
||||||
-- do
|
|
||||||
--
|
|
||||||
-- inferExpression :: TypeEnv -> Expression ->
|
|
||||||
-- Infer (Substitutions, Type)
|
|
||||||
-- inferExpression typeEnv expr =
|
|
||||||
-- case expr of
|
|
||||||
-- ConstantExp s cv -> do memName <- getName TypeEnv "Memory"
|
|
||||||
-- return (nullSubst, TypeRef s Star memName)
|
|
||||||
-- ReferenceExp s n -> case Map.lookup n typeEnv of
|
|
||||||
-- Nothing -> err (UnboundVariable s n)
|
|
||||||
-- Just t -> do t' <- instantiate t
|
|
||||||
-- return (nullSubst, t')
|
|
||||||
-- LambdaExp s ns e -> do localTypeNames <- mapM (const (genName TypeEnv)) ns
|
|
||||||
-- let localSchemes = map (Scheme [] . TypeRef s Star) localTypeNames
|
|
||||||
-- localEnv = Map.fromList (zip ns localSchemes)
|
|
||||||
-- typeEnv' = typeEnv `Map.union` localEnv
|
|
||||||
-- (s1, t1) <- inferExpression typeEnv' e
|
|
||||||
-- return (s1, TypeLambda s (Star `KindArrow` Star)
|
|
||||||
-- (map (apply s1) localTypeNames)
|
|
||||||
-- t1)
|
|
||||||
--
|
|
||||||
-- inferModule :: Module -> Infer Module
|
|
||||||
-- inferModule = undefined
|
|
||||||
|
|||||||
@@ -3,11 +3,18 @@ module Bang.Utils.FreeVars(
|
|||||||
)
|
)
|
||||||
where
|
where
|
||||||
|
|
||||||
import Bang.AST.Name(Name)
|
import Bang.AST.Name(Name)
|
||||||
|
import Data.Set(Set)
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
|
||||||
class CanHaveFreeVars a where
|
class CanHaveFreeVars a where
|
||||||
freeVariables :: a -> [Name]
|
freeVariables :: a -> Set Name
|
||||||
|
|
||||||
instance CanHaveFreeVars a => CanHaveFreeVars (Maybe a) where
|
instance CanHaveFreeVars a => CanHaveFreeVars (Maybe a) where
|
||||||
freeVariables (Just x) = freeVariables x
|
freeVariables (Just x) = freeVariables x
|
||||||
freeVariables Nothing = []
|
freeVariables Nothing = Set.empty
|
||||||
|
|
||||||
|
instance CanHaveFreeVars a => CanHaveFreeVars [a] where
|
||||||
|
freeVariables [] = Set.empty
|
||||||
|
freeVariables (x:xs) = freeVariables x `Set.union` freeVariables xs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user