diff --git a/src/Bang/AST/Declaration.hs b/src/Bang/AST/Declaration.hs index 3a750e9..2b65a4b 100644 --- a/src/Bang/AST/Declaration.hs +++ b/src/Bang/AST/Declaration.hs @@ -24,7 +24,7 @@ import Bang.Syntax.Location(Location) import Bang.Utils.FreeVars(CanHaveFreeVars(..)) import Control.Lens(Lens', view, set, lens) import Control.Lens(makeLenses) -import Data.List(delete, union) +import Data.Set(delete, union) import Text.PrettyPrint.Annotated(Doc, text, (<+>), ($+$), (<>), empty, space) data TypeDeclaration = TypeDeclaration diff --git a/src/Bang/AST/Expression.hs b/src/Bang/AST/Expression.hs index fd2dc58..423e934 100644 --- a/src/Bang/AST/Expression.hs +++ b/src/Bang/AST/Expression.hs @@ -35,6 +35,7 @@ import Bang.Utils.FreeVars(CanHaveFreeVars(..)) import Bang.Utils.Pretty(text') import Control.Lens(view) import Control.Lens.TH(makeLenses) +import Data.Set(empty, singleton, fromList, (\\)) import Data.Text.Lazy(Text) import Text.PrettyPrint.Annotated(Doc, text, hsep, (<>), (<+>)) @@ -74,7 +75,7 @@ instance MkConstExp Expression where mkConstExp l v = ConstExp (mkConstExp l v) instance CanHaveFreeVars ConstantExpression where - freeVariables _ = [] + freeVariables _ = empty ppConstantExpression :: ConstantExpression -> Doc a ppConstantExpression = ppConstantValue . _constValue @@ -100,7 +101,7 @@ instance MkRefExp Expression where mkRefExp l n = RefExp (ReferenceExpression l n) instance CanHaveFreeVars ReferenceExpression where - freeVariables r = [_refName r] + freeVariables r = singleton (_refName r) -- ----------------------------------------------------------------------------- @@ -126,8 +127,8 @@ instance MkLambdaExp Expression where mkLambdaExp l a b = LambdaExp (LambdaExpression l a b) instance CanHaveFreeVars LambdaExpression where - freeVariables l = filter (\ x -> not (x `elem` (_lambdaArgumentNames l))) - (freeVariables (_lambdaBody l)) + freeVariables l = freeVariables (_lambdaBody l) \\ + fromList (_lambdaArgumentNames l) -- ----------------------------------------------------------------------------- diff --git a/src/Bang/AST/Type.hs b/src/Bang/AST/Type.hs index 3ec331f..df560c8 100644 --- a/src/Bang/AST/Type.hs +++ b/src/Bang/AST/Type.hs @@ -22,7 +22,7 @@ module Bang.AST.Type , FunctionType , ppFunctionType , mkFunType - , ftLocation, ftKind, ftArgumentTypes, ftResultType + , ftLocation, ftKind, ftArgumentType, ftResultType -- * type application , TypeApplication , ppTypeApplication @@ -36,9 +36,9 @@ import Bang.Syntax.Location(Location) import Bang.Utils.FreeVars(CanHaveFreeVars(..)) import Bang.Utils.Pretty(text') import Control.Lens.TH(makeLenses) -import Data.List(foldl', union) +import Data.Set(union, empty, singleton) import Data.Text.Lazy(Text) -import Text.PrettyPrint.Annotated(Doc, (<+>), (<>), text, hsep) +import Text.PrettyPrint.Annotated(Doc, (<+>), (<>), text) data Kind = Star | Unknown @@ -62,7 +62,7 @@ instance Kinded UnitType where kind _ = Star instance CanHaveFreeVars UnitType where - freeVariables _ = [] + freeVariables _ = empty ppUnitType :: UnitType -> Doc a ppUnitType _ = text "()" @@ -88,7 +88,7 @@ instance MkPrimType Type where mkPrimType l t = TypePrim (PrimitiveType l t) instance CanHaveFreeVars PrimitiveType where - freeVariables _ = [] + freeVariables _ = empty ppPrimitiveType :: PrimitiveType -> Doc a ppPrimitiveType pt = text "llvm:" <> text' (_ptName pt) @@ -118,20 +118,20 @@ instance MkTypeRef Type where mkTypeRef l k n = TypeRef (ReferenceType l k n) instance CanHaveFreeVars ReferenceType where - freeVariables r = [_rtName r] + freeVariables r = singleton (_rtName r) -- ----------------------------------------------------------------------------- data FunctionType = FunctionType - { _ftLocation :: Location - , _ftKind :: Kind - , _ftArgumentTypes :: [Type] - , _ftResultType :: Type + { _ftLocation :: Location + , _ftKind :: Kind + , _ftArgumentType :: Type + , _ftResultType :: Type } deriving (Show) class MkFunType a where - mkFunType :: Location -> [Type] -> Type -> a + mkFunType :: Location -> Type -> Type -> a instance MkFunType FunctionType where mkFunType l a r = FunctionType l Star a r @@ -141,16 +141,14 @@ instance MkFunType Type where ppFunctionType :: FunctionType -> Doc a ppFunctionType ft = - hsep (map ppType (_ftArgumentTypes ft)) <+> text "->" <+> - ppType (_ftResultType ft) + ppType (_ftArgumentType ft) <+> text "->" <+> ppType (_ftResultType ft) instance Kinded FunctionType where kind = _ftKind instance CanHaveFreeVars FunctionType where - freeVariables ft = foldl' (\ acc x -> acc `union` freeVariables x) - (freeVariables (_ftResultType ft)) - (_ftArgumentTypes ft) + freeVariables ft = freeVariables (_ftArgumentType ft) `union` + freeVariables (_ftResultType ft) -- ----------------------------------------------------------------------------- diff --git a/src/Bang/Monad.hs b/src/Bang/Monad.hs index a08675e..5410a97 100644 --- a/src/Bang/Monad.hs +++ b/src/Bang/Monad.hs @@ -162,7 +162,10 @@ err' :: BangError e => e -> Compiler s () err' e = Compiler (\ st -> runError e False >> return (st, ())) runWarning :: BangWarning w => w -> IO () -runWarning = undefined +runWarning w = putStrLn (go (ppWarning w)) + where + go (Nothing, doc) = render doc + go (Just a, doc) = render (ppLocation a $+$ nest 3 doc) runError :: BangError w => w -> Bool -> IO () runError e die = diff --git a/src/Bang/Syntax/Parser.y b/src/Bang/Syntax/Parser.y index 2bc4127..1760457 100644 --- a/src/Bang/Syntax/Parser.y +++ b/src/Bang/Syntax/Parser.y @@ -174,8 +174,8 @@ Type :: { Type } : RawType { $1 } RawType :: { Type } - : RawType '->' BaseType { mkFunType $2 [$1] $3 } - | BaseType { $1 } + : RawType '->' BaseType { mkFunType $2 $1 $3 } + | BaseType { $1 } BaseType :: { Type } : TypeIdent {% diff --git a/src/Bang/Syntax/PostProcess.hs b/src/Bang/Syntax/PostProcess.hs index aa6909a..b088ad6 100644 --- a/src/Bang/Syntax/PostProcess.hs +++ b/src/Bang/Syntax/PostProcess.hs @@ -13,7 +13,7 @@ import Bang.AST.Declaration(Declaration(..), declName, import Bang.AST.Expression(Expression(..), isEmptyExpression, refName, lambdaArgumentNames, lambdaBody, isEmptyExpression) -import Bang.AST.Type(Type(..), rtName, ftArgumentTypes, ftResultType, +import Bang.AST.Type(Type(..), rtName, ftArgumentType, ftResultType, taLeftType, taRightType) import Bang.Monad(Compiler, BangError(..), err, err', registerName) import Bang.Syntax.Location(Location, ppLocation) @@ -26,6 +26,7 @@ import Data.Graph(SCC(..)) import Data.Graph.SCC(stronglyConnComp) import Data.Map.Strict(Map) import qualified Data.Map.Strict as Map +import Data.Set(toList) import Data.Text.Lazy(uncons) import Text.PrettyPrint.Annotated(text, ($+$), (<+>), nest, quotes) @@ -115,19 +116,15 @@ linkNames decls = let t' = set rtName name t return (TypeRef t', nameMap') linkType nameMap (TypeFun t) = - do (argTypes, nameMap') <- foldM linkTypes ([], nameMap) (view ftArgumentTypes t) - (resType, nameMap'') <- linkType nameMap' (view ftResultType t) - return (TypeFun (set ftArgumentTypes argTypes $ - set ftResultType resType t), + do (argType, nameMap') <- linkType nameMap (view ftArgumentType t) + (resType, nameMap'') <- linkType nameMap' (view ftResultType t) + return (TypeFun (set ftArgumentType argType $ + set ftResultType resType t), nameMap'') linkType nameMap (TypeApp t) = do (lt, nameMap') <- linkType nameMap (view taLeftType t) (rt, nameMap'') <- linkType nameMap' (view taRightType t) return (TypeApp (set taLeftType lt (set taRightType rt t)), nameMap'') - -- - linkTypes (acc, nameMap) argType = - do (argType', nameMap') <- linkType nameMap argType - return (acc ++ [argType'], nameMap') -- linkExpr _ x | isEmptyExpression x = return x linkExpr _ x@(ConstExp _) = return x @@ -217,4 +214,4 @@ orderDecls decls = map unSCC (stronglyConnComp nodes) unSCC (CyclicSCC xs) = xs -- nodes = map tuplify decls - tuplify d = (d, view declName d, freeVariables d) + tuplify d = (d, view declName d, toList (freeVariables d)) diff --git a/src/Bang/TypeInfer.hs b/src/Bang/TypeInfer.hs index 03fb69d..9127674 100644 --- a/src/Bang/TypeInfer.hs +++ b/src/Bang/TypeInfer.hs @@ -1,174 +1,269 @@ -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeSynonymInstances #-} module Bang.TypeInfer(runTypeInference) where -import Bang.AST(Module) -import Bang.Monad(Compiler) - -runTypeInference :: Module -> Compiler ps Module -runTypeInference x = return x - -{- Better version -import Bang.Monad(Compiler, BangError(..), err, - runPass, getPassState, setPassState, - viewPassState, overPassState, - registerNewName, genName) -import Bang.Syntax.Location(Location, unknownLocation) -import Bang.Syntax.ParserMonad(NameDatabase(..)) -import Bang.Utils.Pretty(BangDoc) -import Control.Lens(set, view, over) -import Control.Lens.TH(makeLenses) -import Data.List(union, nub, concat) +import Bang.AST(Module, moduleDeclarations) +import Bang.AST.Declaration(Declaration(..), ValueDeclaration, + vdName, vdDeclaredType, vdValue, + tdName, tdType) +import Bang.AST.Expression(Expression(..), ConstantValue(..), + lambdaArgumentNames, lambdaBody, + constLocation, constValue, refName) +import Bang.AST.Name(Name, NameEnvironment(..), + nameLocation, nameText, ppName) +import Bang.AST.Type(Type(..), ppType, rtName, ftArgumentType, + ftResultType, taLeftType, taRightType, + mkPrimType, mkFunType, mkTypeRef, + Kind(..)) +import Bang.Monad(Compiler, BangError(..), BangWarning(..), + registerNewName, err', err, warn, + getPassState, mapPassState, runPass) +import Bang.Syntax.Location(Location, fakeLocation) +import Bang.Utils.FreeVars(CanHaveFreeVars(..)) +import Bang.Utils.Pretty(BangDoc, text') +import Control.Lens(view, over) import Data.Map.Strict(Map) import qualified Data.Map.Strict as Map -import Data.Set(Set) +import Data.Set(Set, (\\)) import qualified Data.Set as Set -import Data.Text.Lazy(Text, pack) -import Text.PrettyPrint.Annotated(text, (<+>), quotes) +import Text.PrettyPrint.Annotated(text, nest, quotes, ($+$), (<+>)) -data InferenceState = InferenceState { - _nameDatabase :: NameDatabase - } - -makeLenses ''InferenceState - -type Infer a = Compiler InferenceState a - -runInfer :: NameDatabase -> Infer a -> Compiler ps a -runInfer ndb action = snd `fmap` runPass initial action - where initial = InferenceState ndb +runTypeInference :: Module -> Compiler ps Module +runTypeInference x = + do _ <- runPass emptyEnvironment (mapM_ typeInferDecls (view moduleDeclarations x)) + return x -- ----------------------------------------------------------------------------- -data InferError = UnboundVariable Location Name - deriving (Show) +type Infer a = Compiler TypeEnvironment a -instance BangError InferError where - ppError = prettyError +getNamesTypeScheme :: Name -> Infer (Maybe Scheme) +getNamesTypeScheme n = Map.lookup n `fmap` getPassState -prettyError :: InferError -> (Maybe Location, BangDoc) -prettyError e = - case e of - UnboundVariable l n -> - (Just l, text "Unbound variable '" <+> quotes (ppName n)) +addToTypeEnvironment :: [Name] -> [Scheme] -> Infer () +addToTypeEnvironment ns schms = mapPassState (add ns schms) + where + add :: [Name] -> [Scheme] -> TypeEnvironment -> TypeEnvironment + add [] [] acc = acc + add (n:restns) (s:rschms) acc = + Map.insertWithKey errorFn n s (add restns rschms acc) + add _ _ _ = + error "Wackiness has insued." + -- + errorFn k _ _ = error ("Redefinition of " ++ show k) -- ----------------------------------------------------------------------------- -type Substitutions = Map Name Type +type Substitution = Map Name Type -noSubstitutions :: Substitutions -noSubstitutions = Map.empty +nullSubstitution :: Substitution +nullSubstitution = Map.empty -composeSubstitutions :: Substitutions -> Substitutions -> Substitutions +composeSubstitutions :: Substitution -> Substitution -> Substitution composeSubstitutions s1 s2 = Map.map (apply s1) s2 `Map.union` s1 -class Types a where - freeTypeVariables :: a -> Set Name - apply :: Substitutions -> a -> a +class ApplySubst t where + apply :: Substitution -> t -> t -instance Types Type where - freeTypeVariables t = - case t of - TypeUnit _ _ -> Set.empty - TypePrim _ _ _ -> Set.empty - TypeRef _ _ n -> Set.singleton n - TypeLambda _ _ ns e -> Set.unions (map freeTypeVariables ns) `Set.union` - freeTypeVariables e - TypeApp _ _ a b -> freeTypeVariables a `Set.union` freeTypeVariables b - apply substs t = - case t of - TypeRef _ _ n -> case Map.lookup n substs of - Nothing -> t - Just t' -> t' - TypeLambda l k ns e -> TypeLambda l k (apply substs ns) (apply substs e) - TypeApp l k a b -> TypeApp l k (apply substs a) (apply substs b) - _ -> t +instance ApplySubst Type where + apply s (TypeUnit t) = TypeUnit t + apply s (TypePrim t) = TypePrim t + apply s (TypeRef t) = case Map.lookup (view rtName t) s of + Nothing -> TypeRef t + Just t' -> t' + apply s (TypeFun t) = TypeFun (over ftArgumentType (apply s) $ + over ftResultType (apply s) t) + apply s (TypeApp t) = TypeApp (over taLeftType (apply s) $ + over taRightType (apply s) t) -instance Types a => Types [a] where - freeTypeVariables l = Set.unions (map freeTypeVariables l) - apply s = map (apply s) +instance ApplySubst a => ApplySubst [a] where + apply s = map (apply s) -- ----------------------------------------------------------------------------- -inferModule :: Module -> Infer Module -inferModule = undefined +data Scheme = Scheme [Name] Type + +instance CanHaveFreeVars Scheme where + freeVariables (Scheme ns t) = freeVariables t \\ Set.fromList ns + +instance ApplySubst Scheme where + apply s (Scheme vars t) = Scheme vars (apply s t) + +newTypeVar :: Name -> Infer Type +newTypeVar n = + do let loc = view nameLocation n + n' <- registerNewName TypeEnv (view nameText n) + return (mkTypeRef loc Unknown n') + +instantiate :: Scheme -> Infer Type +instantiate (Scheme vars t) = + do refs <- mapM newTypeVar vars + let newSubsts = Map.fromList (zip vars refs) + return (apply newSubsts t) + +-- ----------------------------------------------------------------------------- + +mostGeneralUnifier :: Type -> Type -> Infer Substitution +mostGeneralUnifier a b = + case (a, b) of + (TypeUnit _, TypeUnit _) -> return nullSubstitution + (TypePrim _, TypePrim _) -> return nullSubstitution + (TypeRef t1, t2) -> varBind (view rtName t1) t2 + (t2, TypeRef t1) -> varBind (view rtName t1) t2 + (TypeFun t1, TypeFun t2) -> do let at1 = view ftArgumentType t1 + at2 = view ftArgumentType t2 + s1 <- mostGeneralUnifier at1 at2 + let rt1 = apply s1 (view ftResultType t1) + rt2 = apply s1 (view ftResultType t2) + s2 <- mostGeneralUnifier rt1 rt2 + return (s1 `composeSubstitutions` s2) + (TypeApp t1, TypeApp t2) -> do let lt1 = view taLeftType t1 + lt2 = view taLeftType t2 + s1 <- mostGeneralUnifier lt1 lt2 + let rt1 = apply s1 (view taRightType t1) + rt2 = apply s1 (view taRightType t2) + s2 <- mostGeneralUnifier rt1 rt2 + return (s1 `composeSubstitutions` s2) + _ -> do err' (TypesDontUnify a b) + return nullSubstitution + +varBind :: Name -> Type -> Infer Substitution +varBind u t | TypeRef t' <- t, + view rtName t' == u = return nullSubstitution + | u `Set.member` freeVariables t = do err' (OccursFail u t) + return nullSubstitution + | otherwise = return (Map.singleton u t) + +-- ----------------------------------------------------------------------------- + +type TypeEnvironment = Map Name Scheme + +emptyEnvironment :: TypeEnvironment +emptyEnvironment = Map.empty + +instance ApplySubst TypeEnvironment where + apply s tenv = Map.map (apply s) tenv + +instance CanHaveFreeVars TypeEnvironment where + freeVariables tenv = freeVariables (Map.elems tenv) + +generalize :: TypeEnvironment -> Type -> Scheme +generalize env t = Scheme vars t + where vars = Set.toList (freeVariables t \\ freeVariables env) + +-- ----------------------------------------------------------------------------- + +data InferenceError = InternalError + | TypesDontUnify Type Type + | OccursFail Name Type + | UnboundVariable Name + +instance BangError InferenceError where + ppError = prettyError + +prettyError :: InferenceError -> (Maybe Location, BangDoc) +prettyError e = + case e of + InternalError -> + (Nothing, text "") + TypesDontUnify t1 t2 -> + (Nothing, text "Types don't unify:" $+$ + (nest 3 + (text "first type: " <+> ppType t1 $+$ + text "second type: " <+> ppType t2))) + OccursFail n t -> + (Just (view nameLocation n), + text "Occurs check failed:" $+$ + (nest 3 (text "Type: " <+> ppType t))) + UnboundVariable n -> + (Just (view nameLocation n), + text "Unbound variable (in type checker?):" <+> ppName n) + +data InferenceWarning = TopLevelWithoutType Name Type + | DeclarationMismatch Name Type Type + +instance BangWarning InferenceWarning where + ppWarning = prettyWarning + +prettyWarning :: InferenceWarning -> (Maybe Location, BangDoc) +prettyWarning w = + case w of + TopLevelWithoutType n t -> + (Just (view nameLocation n), + text "Variable" <+> quotes (text' (view nameText n)) <+> + text "is defined without a type." $+$ + text "Inferred type:" $+$ nest 3 (ppType t)) + DeclarationMismatch n dt it -> + (Just (view nameLocation n), + text "Mismatch between declared and inferred type of" <+> + quotes (text' (view nameText n)) $+$ + nest 3 (text "declared type:" <+> ppType dt $+$ + text "inferred type:" <+> ppType it)) + +-- ----------------------------------------------------------------------------- + +-- Infer the type of a group of declarations with cyclic dependencies. +typeInferDecls :: [Declaration] -> Infer () +typeInferDecls decls = + do (names, schemes, decls') <- getInitialSchemes decls + addToTypeEnvironment names schemes + mapM_ typeInferDecl decls' + where + getInitialSchemes [] = + return ([], [], []) + getInitialSchemes ((DeclType td) : rest) = + do (rn, rs, rd) <- getInitialSchemes rest + let n = view tdName td + s = Scheme [] (view tdType td) + return (n:rn, s:rs, rd) + getInitialSchemes ((DeclVal td) : rest) = + do (rn, rs, rd) <- getInitialSchemes rest + return (rn, rs, (td : rd)) + +typeInferDecl :: ValueDeclaration -> Infer () +typeInferDecl vd = + do (subs, t) <- typeInferExpr (view vdValue vd) + let t' = apply subs t + case view vdDeclaredType vd of + Nothing -> + warn (TopLevelWithoutType (view vdName vd) t') + Just dt -> + warn (DeclarationMismatch (view vdName vd) dt t) + +typeInferConst :: Location -> ConstantValue -> + Infer (Substitution, Type) +typeInferConst l (ConstantInt _ _) = + return (nullSubstitution, mkPrimType l "i64") +typeInferConst l (ConstantChar _) = + return (nullSubstitution, mkPrimType l "i8") -- FIXME +typeInferConst l (ConstantString _) = + return (nullSubstitution, mkPrimType l "i8*") -- FIXME +typeInferConst l (ConstantFloat _) = + return (nullSubstitution, mkPrimType l "double") + +typeInferExpr :: Expression -> Infer (Substitution, Type) +typeInferExpr expr = + case expr of + ConstExp e -> + typeInferConst (view constLocation e) (view constValue e) + RefExp e -> + do mscheme <- getNamesTypeScheme (view refName e) + case mscheme of + Nothing -> err (UnboundVariable (view refName e)) + Just scheme -> do t <- instantiate scheme + return (nullSubstitution, t) + LambdaExp e -> + do let argNames = view lambdaArgumentNames e + tvars <- mapM newTypeVar argNames + let tvars' = map (Scheme []) tvars + addToTypeEnvironment argNames tvars' + (s1, t1) <- typeInferExpr (view lambdaBody e) + return (s1, mkFunType' (apply s1 tvars) t1) + where + mkFunType' [] t = t + mkFunType' (x:rest) t = mkFunType fakeLocation x (mkFunType' rest t) + -runTypeInference :: NameDatabase -> Module -> Compiler ps Module -runTypeInference ndb mod = runInfer ndb (inferModule mod) --} --- data Scheme = Scheme [Name] Type --- --- getName :: NameEnvironment -> Text -> Infer Name --- getName env nameText = --- do namedb <- viewPassState nameDatabase --- let key = (env, nameText) --- case Map.lookup key namedb of --- Nothing -> --- do name <- registerNewName env nameText --- overPassState (set nameDatabase (Map.insert key name namedb)) --- return name --- Just name -> --- return name --- --- runTypeInference :: NameDatabase -> Module -> Compiler ps Module --- runTypeInference nameDB mod = --- snd `fmap` (runPass initialState (inferModule mod)) --- where initialState = InferenceState nameDB --- --- type Substitutions = Map Name Type --- --- nullSubst :: Substitutions --- nullSubst = Map.empty --- --- type TypeEnv = Map Name Scheme --- --- class Substitutable a where --- apply :: Substitutions -> a -> Type --- --- instance Substitutable Type where --- apply subs t = --- case t of --- TypeUnit _ _ -> t --- TypePrim _ _ _ -> t --- TypeRef _ _ n -> case Map.lookup n subs of --- Nothing -> t --- Just t' -> t' --- TypeLambda l k ats bt -> --- TypeLambda l k (map (apply subs) ats) (apply subs bt) --- TypeApp l k a b -> --- TypeApp l k (apply subs a) (apply subs b) --- TypeForAll ns t -> --- TypeForAll ns (apply subs t) --- --- instance Substitutable Name where --- apply subs n = --- case Map.lookup n subs of --- Nothing -> TypeRef unknownLocation Star n --- Just t -> t --- --- instantiate :: Scheme -> Infer Type --- instantiate = --- do --- --- inferExpression :: TypeEnv -> Expression -> --- Infer (Substitutions, Type) --- inferExpression typeEnv expr = --- case expr of --- ConstantExp s cv -> do memName <- getName TypeEnv "Memory" --- return (nullSubst, TypeRef s Star memName) --- ReferenceExp s n -> case Map.lookup n typeEnv of --- Nothing -> err (UnboundVariable s n) --- Just t -> do t' <- instantiate t --- return (nullSubst, t') --- LambdaExp s ns e -> do localTypeNames <- mapM (const (genName TypeEnv)) ns --- let localSchemes = map (Scheme [] . TypeRef s Star) localTypeNames --- localEnv = Map.fromList (zip ns localSchemes) --- typeEnv' = typeEnv `Map.union` localEnv --- (s1, t1) <- inferExpression typeEnv' e --- return (s1, TypeLambda s (Star `KindArrow` Star) --- (map (apply s1) localTypeNames) --- t1) --- --- inferModule :: Module -> Infer Module --- inferModule = undefined diff --git a/src/Bang/Utils/FreeVars.hs b/src/Bang/Utils/FreeVars.hs index 632cfc6..a2369a5 100644 --- a/src/Bang/Utils/FreeVars.hs +++ b/src/Bang/Utils/FreeVars.hs @@ -3,11 +3,18 @@ module Bang.Utils.FreeVars( ) where -import Bang.AST.Name(Name) +import Bang.AST.Name(Name) +import Data.Set(Set) +import qualified Data.Set as Set class CanHaveFreeVars a where - freeVariables :: a -> [Name] + freeVariables :: a -> Set Name instance CanHaveFreeVars a => CanHaveFreeVars (Maybe a) where freeVariables (Just x) = freeVariables x - freeVariables Nothing = [] + freeVariables Nothing = Set.empty + +instance CanHaveFreeVars a => CanHaveFreeVars [a] where + freeVariables [] = Set.empty + freeVariables (x:xs) = freeVariables x `Set.union` freeVariables xs +