diff --git a/src/Bang/Monad.hs b/src/Bang/Monad.hs index 67e13aa..5dca725 100644 --- a/src/Bang/Monad.hs +++ b/src/Bang/Monad.hs @@ -10,7 +10,7 @@ module Bang.Monad( , runCompiler , runPass , getPassState, setPassState, overPassState, viewPassState - , genName, genTypeRef, genVarRef + , registerNewName, genName, genTypeRef, genVarRef , warn, err ) where @@ -83,13 +83,14 @@ runCompiler cmd opts action = Left _ -> exit ("Unable to open file '" ++ path ++ "'") Right txt -> snd `fmap` unCompiler (action orig txt) (initialState cmd) -runPass :: s2 -> (Compiler s2 a) -> Compiler s1 a +runPass :: s2 -> (Compiler s2 a) -> Compiler s1 (s2, a) runPass s2 action = Compiler (\ cst1 -> do let cst2 = set csPassState s2 cst1 s1 = view csPassState cst1 (cst2', v) <- unCompiler action cst2 - return (set csPassState s1 cst2', v)) + let retval = (view csPassState cst2', v) + return (set csPassState s1 cst2', retval)) getPassState :: Compiler s s getPassState = Compiler (\ st -> return (st, view csPassState st)) @@ -105,6 +106,13 @@ viewPassState l = Compiler (\ st -> return (st, view (csPassState . l) st)) -- ----------------------------------------------------------------------------- +registerNewName :: NameEnvironment -> Text -> Compiler s Name +registerNewName env name = + Compiler (\ st -> + do let current = view csNextIdent st + res = Name unknownLocation env current name + return (over csNextIdent (+1) st, res)) + genName :: NameEnvironment -> Compiler s Name genName env = Compiler (\ st -> do let current = view csNextIdent st diff --git a/src/Bang/Syntax/AST.hs b/src/Bang/Syntax/AST.hs index 5595c7f..f5e7b1d 100644 --- a/src/Bang/Syntax/AST.hs +++ b/src/Bang/Syntax/AST.hs @@ -47,9 +47,8 @@ data ConstantValue = ConstantInt Word Text data Type = TypeUnit Location Kind | TypePrim Location Kind Text | TypeRef Location Kind Name - | TypeLambda Location Kind Type Type + | TypeLambda Location Kind [Type] Type | TypeApp Location Kind Type Type - | TypeForAll [Name] Type deriving (Show) instance Eq Type where @@ -58,7 +57,6 @@ instance Eq Type where (TypeRef _ _ n) == (TypeRef _ _ m) = n == m (TypeLambda _ _ at et) == (TypeLambda _ _ bt ft) = (at == bt) && (et == ft) (TypeApp _ _ at bt) == (TypeApp _ _ ct dt) = (at == ct) && (bt == dt) - (TypeForAll ns t) == (TypeForAll ms u) = (ns == ms) && (t == u) _ == _ = False kind :: Type -> Kind @@ -67,7 +65,6 @@ kind (TypePrim _ k _) = k kind (TypeRef _ k _) = k kind (TypeLambda _ k _ _) = k kind (TypeApp _ k _ _) = k -kind (TypeForAll _ t) = kind t data Kind = Star | KindArrow Kind Kind diff --git a/src/Bang/Syntax/Parser.y b/src/Bang/Syntax/Parser.y index 187a8d5..5be79fe 100644 --- a/src/Bang/Syntax/Parser.y +++ b/src/Bang/Syntax/Parser.y @@ -181,14 +181,14 @@ Type :: { Type } [] -> return result xs -> do unregisterNames TypeEnv xs - return (TypeForAll xs result) + return result } RawType :: { (Type, [Name]) } : RawType '->' BaseType {% do let (p1, names1) = $1 (p2, names2) = $3 - return (TypeLambda $2 (Star `KindArrow` Star) p1 p2, union names1 names2) + return (TypeLambda $2 (Star `KindArrow` Star) [p1] p2, union names1 names2) } | BaseType { $1 } diff --git a/src/Bang/Syntax/ParserMonad.hs b/src/Bang/Syntax/ParserMonad.hs index fbb5560..0704c96 100644 --- a/src/Bang/Syntax/ParserMonad.hs +++ b/src/Bang/Syntax/ParserMonad.hs @@ -2,6 +2,7 @@ {-# LANGUAGE TemplateHaskell #-} module Bang.Syntax.ParserMonad( Parser + , NameDatabase , runParser , addFixities , registerName @@ -21,7 +22,7 @@ import Bang.Syntax.Location(Location(..), Located(..), advanceWith', locatedAt) import Bang.Syntax.ParserError(ParserError(..)) import Bang.Syntax.Token(Token(..), Fixity) -import Control.Lens(view, set, over) +import Control.Lens(view, set, over, _1) import Control.Lens.TH(makeLenses) import Control.Monad(forM_) import Data.Char(digitToInt, isSpace) @@ -30,6 +31,8 @@ import qualified Data.Map.Strict as Map import Data.Text.Lazy(Text) import qualified Data.Text.Lazy as T +type NameDatabase = Map (NameEnvironment, Text) Name + data ParserState = ParserState { _psPrecTable :: Map Text Fixity , _psNameDatabase :: Map (NameEnvironment, Text) Name @@ -42,8 +45,9 @@ makeLenses ''ParserState type Parser a = Compiler ParserState a -runParser :: Origin -> Text -> Parser a -> Compiler ps a -runParser origin stream action = runPass pstate action +runParser :: Origin -> Text -> Parser a -> Compiler ps (NameDatabase, a) +runParser origin stream action = + over _1 (view psNameDatabase) `fmap` runPass pstate action where initInput = AlexInput initialPosition stream pstate = ParserState Map.empty Map.empty 1 origin initInput diff --git a/src/Bang/Syntax/Pretty.hs b/src/Bang/Syntax/Pretty.hs index bed681d..5b896e9 100644 --- a/src/Bang/Syntax/Pretty.hs +++ b/src/Bang/Syntax/Pretty.hs @@ -1,5 +1,9 @@ module Bang.Syntax.Pretty( ppModule + , ppDeclaration + , ppExpression + , ppType + , ppName ) where @@ -53,11 +57,8 @@ ppType t = TypeUnit _ _ -> text "()" TypePrim _ _ n -> text (unpack n) TypeRef _ _ n -> ppName n - TypeLambda _ _ a b -> ppType a <> space <> text "->" <> space <> ppType b + TypeLambda _ _ as b -> hsep (map ppType as) <> space <> text "->" <> space <> ppType b TypeApp _ _ a b -> ppType a <> space <> ppType b - TypeForAll ns s -> - text "∀" <> space <> hsep (punctuate comma (map ppName ns)) <> - space <> text "." <> space <> ppType s text' :: Text -> Doc a text' = text . unpack diff --git a/src/Bang/TypeInfer.hs b/src/Bang/TypeInfer.hs index 34c5512..875edd0 100644 --- a/src/Bang/TypeInfer.hs +++ b/src/Bang/TypeInfer.hs @@ -1,191 +1,169 @@ -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE TemplateHaskell #-} -module Bang.TypeInfer +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} +module Bang.TypeInfer(runTypeInference) where import Bang.Monad(Compiler, BangError(..), err, - getPassState, setPassState) + runPass, getPassState, setPassState, + viewPassState, overPassState, + registerNewName, genName) import Bang.Syntax.AST -import Bang.Syntax.Location(unknownLocation) -import Control.Lens(view, over) +import Bang.Syntax.Location(Location, unknownLocation) +import Bang.Syntax.ParserMonad(NameDatabase(..)) +import Bang.Syntax.Pretty(ppName) +import Bang.Utils.Pretty(BangDoc) +import Control.Lens(set, view, over) import Control.Lens.TH(makeLenses) import Data.List(union, nub, concat) import Data.Map.Strict(Map) import qualified Data.Map.Strict as Map -import Data.Text.Lazy(pack) - --- ----------------------------------------------------------------------------- - -type Substitution = Map Name Type - -class Types t where - apply :: Substitution -> t -> t - tv :: t -> [Name] - -nullSubstitution :: Substitution -nullSubstitution = Map.empty - -(⟼) :: Name -> Type -> Substitution -(⟼) = Map.singleton - -infixr 4 @@ -(@@) :: Substitution -> Substitution -> Substitution -(@@) s1 s2 = - let s2' = Map.map (\ t -> apply s1 t) s2 - in Map.union s2' s1 - --- ----------------------------------------------------------------------------- - -data InferenceError = UnificationError Type Type - | OccursCheckFails Name Type - | KindCheckFails Name Type - | MatchFailure Type Type - | UnboundIdentifier Name - | MergeFailure Substitution Substitution - deriving (Show) - -instance BangError InferenceError where - ppError = undefined +import Data.Set(Set) +import qualified Data.Set as Set +import Data.Text.Lazy(Text, pack) +import Text.PrettyPrint.Annotated(text, (<+>), quotes) data InferenceState = InferenceState { - _istCurrentSubstitution :: Substitution - , _istNextIdentifier :: Word + _nameDatabase :: NameDatabase } makeLenses ''InferenceState type Infer a = Compiler InferenceState a +runInfer :: NameDatabase -> Infer a -> Compiler ps a +runInfer ndb action = snd `fmap` runPass initial action + where initial = InferenceState ndb + -- ----------------------------------------------------------------------------- -merge :: Substitution -> Substitution -> Infer Substitution -merge s1 s2 | agree = return (Map.union s1 s2) - | otherwise = err (MergeFailure s1 s2) - where - names = Map.keys (Map.intersection s1 s2) - agree = all (\ v -> - let refv = TypeRef (error "Internal error, TypeInfer") Star v - in apply s1 refv == apply s2 refv) - names +data InferError = UnboundVariable Location Name + deriving (Show) -mostGeneralUnifier :: Type -> Type -> Infer Substitution -mostGeneralUnifier t1 t2 = - case (t1, t2) of - (TypeApp _ _ l r, TypeApp _ _ l' r') -> - do s1 <- mostGeneralUnifier l l' - s2 <- mostGeneralUnifier (apply s1 r) (apply s1 r') - return (s2 @@ s1) - (u@(TypeRef _ _ _), t) -> varBind u t - (t, u@(TypeRef _ _ _)) -> varBind u t - (TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution - _ -> err (UnificationError t1 t2) +instance BangError InferError where + ppError = prettyError -varBind :: Type -> Type -> Infer Substitution -varBind = undefined --- | TypeRef _ _ u' <- t, u' == u = return nullSubstitution --- | u `elem` tv t = err (OccursCheckFails u t) --- | k /= kind t = err (KindCheckFails u t) --- | otherwise = return (u ⟼ t) --- -match :: Type -> Type -> Infer Substitution -match t1 t2 = - case (t1, t2) of - (TypeApp _ _ l r, TypeApp _ _ l' r') -> - do sl <- match l l' - sr <- match r r' - merge sl sr - (TypeRef _ k u, t) | k == kind t -> return (u ⟼ t) - (TypePrim _ _ tc1, TypePrim _ _ tc2) | tc1 == tc2 -> return nullSubstitution - _ -> err (MatchFailure t1 t2) +prettyError :: InferError -> (Maybe Location, BangDoc) +prettyError e = + case e of + UnboundVariable l n -> + (Just l, text "Unbound variable '" <+> quotes (ppName n)) -data Scheme = Forall [Kind] Type - -instance Types Scheme where - apply s (Forall ks t) = Forall ks (apply s t) - tv (Forall _ qt) = tv qt +-- ----------------------------------------------------------------------------- -data Assumption = Name :>: Scheme +type Substitutions = Map Name Type -instance Types Assumption where - apply s (i :>: sc) = i :>: (apply s sc) - tv (_ :>: sc) = tv sc +noSubstitutions :: Substitutions +noSubstitutions = Map.empty -find :: Name -> [Assumption] -> Infer Scheme -find i [] = err (UnboundIdentifier i) -find i ((i' :>: sc) : as) | i == i' = return sc - | otherwise = find i as +composeSubstitutions :: Substitutions -> Substitutions -> Substitutions +composeSubstitutions s1 s2 = Map.map (apply s1) s2 `Map.union` s1 + +class Types a where + freeTypeVariables :: a -> Set Name + apply :: Substitutions -> a -> a instance Types Type where - apply s v@(TypeRef _ _ n) = case Map.lookup n s of - Just t -> t - Nothing -> v - apply s (TypeApp l k t u) = TypeApp l k (apply s t) (apply s u) - apply _ t = t - -- - tv (TypeRef _ _ n) = [n] - tv (TypeApp _ _ t u) = tv t `union` tv u - tv _ = [] + freeTypeVariables t = + case t of + TypeUnit _ _ -> Set.empty + TypePrim _ _ _ -> Set.empty + TypeRef _ _ n -> Set.singleton n + TypeLambda _ _ ns e -> Set.unions (map freeTypeVariables ns) `Set.union` + freeTypeVariables e + TypeApp _ _ a b -> freeTypeVariables a `Set.union` freeTypeVariables b + apply substs t = + case t of + TypeRef _ _ n -> case Map.lookup n substs of + Nothing -> t + Just t' -> t' + TypeLambda l k ns e -> TypeLambda l k (apply substs ns) (apply substs e) + TypeApp l k a b -> TypeApp l k (apply substs a) (apply substs b) + _ -> t -instance Types [Type] where - apply s = map (apply s) - tv = nub . concat . map tv +instance Types a => Types [a] where + freeTypeVariables l = Set.unions (map freeTypeVariables l) + apply s = map (apply s) -getSubstitution :: Infer Substitution -getSubstitution = view istCurrentSubstitution `fmap` getPassState +-- ----------------------------------------------------------------------------- -extendSubstitution :: Substitution -> Infer () -extendSubstitution s' = - do s <- getPassState - setPassState (over istCurrentSubstitution (s' @@) s) +inferModule :: Module -> Infer Module +inferModule = undefined -unify :: Type -> Type -> Infer () -unify t1 t2 = - do s <- getSubstitution - u <- mostGeneralUnifier (apply s t1) (apply s t2) - extendSubstitution u +runTypeInference :: NameDatabase -> Module -> Compiler ps Module +runTypeInference ndb mod = runInfer ndb (inferModule mod) -gensym :: Kind -> Infer Type -gensym k = - do s <- getPassState - setPassState (over istNextIdentifier (+1) s) - let num = view istNextIdentifier s - str = "gensym:" ++ show num - name = Name unknownLocation TypeEnv num (pack str) - return (TypeRef unknownLocation k name) - -data Predicate = IsIn String Type - deriving (Eq) - -inferConstant :: ConstantValue -> Infer ([Predicate], Type) -inferConstant c = - do v <- gensym Star - let constraint | ConstantInt _ _ <- c = IsIn "IntLike" v - | ConstantChar _ <- c = IsIn "CharLake" v - | ConstantString _ <- c = IsIn "StringLike" v - | ConstantFloat _ <- c = IsIn "FloatLike" v - return ([constraint], v) - -data ClassEnvironment = [Predicate] :=> Type - -freshInst :: Scheme -> Infer ClassEnvironment -freshInst = undefined - -inferExpression :: ClassEnvironment -> [Assumption] -> - Expression -> - Infer ([Predicate], Type) -inferExpression _classEnv assumpts expr = - case expr of - ConstantExp _ cv -> inferConstant cv - ReferenceExp _ n -> do sc <- find n assumpts - (ps :=> t) <- freshInst sc - return (ps, t) - LambdaExp _ _ _ -> error "FIXME, here" - -infer :: Module -> Infer Module -infer = undefined - -typeInfer :: Word -> Module -> Either InferenceError Module -typeInfer = undefined +-- data Scheme = Scheme [Name] Type +-- +-- getName :: NameEnvironment -> Text -> Infer Name +-- getName env nameText = +-- do namedb <- viewPassState nameDatabase +-- let key = (env, nameText) +-- case Map.lookup key namedb of +-- Nothing -> +-- do name <- registerNewName env nameText +-- overPassState (set nameDatabase (Map.insert key name namedb)) +-- return name +-- Just name -> +-- return name +-- +-- runTypeInference :: NameDatabase -> Module -> Compiler ps Module +-- runTypeInference nameDB mod = +-- snd `fmap` (runPass initialState (inferModule mod)) +-- where initialState = InferenceState nameDB +-- +-- type Substitutions = Map Name Type +-- +-- nullSubst :: Substitutions +-- nullSubst = Map.empty +-- +-- type TypeEnv = Map Name Scheme +-- +-- class Substitutable a where +-- apply :: Substitutions -> a -> Type +-- +-- instance Substitutable Type where +-- apply subs t = +-- case t of +-- TypeUnit _ _ -> t +-- TypePrim _ _ _ -> t +-- TypeRef _ _ n -> case Map.lookup n subs of +-- Nothing -> t +-- Just t' -> t' +-- TypeLambda l k ats bt -> +-- TypeLambda l k (map (apply subs) ats) (apply subs bt) +-- TypeApp l k a b -> +-- TypeApp l k (apply subs a) (apply subs b) +-- TypeForAll ns t -> +-- TypeForAll ns (apply subs t) +-- +-- instance Substitutable Name where +-- apply subs n = +-- case Map.lookup n subs of +-- Nothing -> TypeRef unknownLocation Star n +-- Just t -> t +-- +-- instantiate :: Scheme -> Infer Type +-- instantiate = +-- do +-- +-- inferExpression :: TypeEnv -> Expression -> +-- Infer (Substitutions, Type) +-- inferExpression typeEnv expr = +-- case expr of +-- ConstantExp s cv -> do memName <- getName TypeEnv "Memory" +-- return (nullSubst, TypeRef s Star memName) +-- ReferenceExp s n -> case Map.lookup n typeEnv of +-- Nothing -> err (UnboundVariable s n) +-- Just t -> do t' <- instantiate t +-- return (nullSubst, t') +-- LambdaExp s ns e -> do localTypeNames <- mapM (const (genName TypeEnv)) ns +-- let localSchemes = map (Scheme [] . TypeRef s Star) localTypeNames +-- localEnv = Map.fromList (zip ns localSchemes) +-- typeEnv' = typeEnv `Map.union` localEnv +-- (s1, t1) <- inferExpression typeEnv' e +-- return (s1, TypeLambda s (Star `KindArrow` Star) +-- (map (apply s1) localTypeNames) +-- t1) +-- +-- inferModule :: Module -> Infer Module +-- inferModule = undefined diff --git a/src/Main.hs b/src/Main.hs index bb49ff5..65fccb1 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -3,6 +3,7 @@ import Bang.Monad import Bang.Syntax.Lexer() import Bang.Syntax.Parser(runParser, parseModule) import Bang.Syntax.Pretty(ppModule) +import Bang.TypeInfer(runTypeInference) import Data.Version(showVersion) import Paths_bang(version) import Text.PrettyPrint.Annotated(render) @@ -10,28 +11,11 @@ import Text.PrettyPrint.Annotated(render) main :: IO () main = getCommand >>= \ cmd -> case cmd of - Parse o -> do mdl <- runCompiler cmd o (\ r t -> runParser r t parseModule) + Parse o -> do (_, mdl) <- runCompiler cmd o (\ r t -> runParser r t parseModule) + putStrLn (render (ppModule mdl)) + TypeCheck o -> do mdl <- runCompiler cmd o (\ r t -> + do (ndb, mdl) <- runParser r t parseModule + runTypeInference ndb mdl) putStrLn (render (ppModule mdl)) - TypeCheck _ -> undefined Help -> putStrLn helpString Version -> putStrLn ("Bang tool, version " ++ showVersion version) - --- run :: CommandsWithInputFile o => o -> (FilePath -> Text -> IO ()) -> IO () --- run opts action = --- do let path = view inputFile opts --- mtxt <- tryJust (guard . isDoesNotExistError) (T.readFile path) --- case mtxt of --- Left _ -> exit ("Unable to open file '" ++ path ++ "'") --- Right txt -> action path txt --- --- withParsed :: (Module -> IO ()) -> FilePath -> Text -> IO () --- withParsed action path body = --- case parseModule (File path) body of --- Left err -> exit (show err) --- Right mdl -> action mdl --- --- withInferred :: (Module -> IO ()) -> Module -> IO () --- withInferred action mdl = --- case typeInfer 0 mdl of --- Left err -> exit (show err) --- Right mdl' -> action mdl'