diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 24c408c..6a7b62d 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -77,7 +77,7 @@ impl Arbitrary for Program { #[derive(Debug)] pub enum TopLevel { Statement(Statement), - Function(Variable, Vec, Expression), + Function(Variable, Vec, Vec, Expression), } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b TopLevel @@ -87,7 +87,7 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - TopLevel::Function(name, args, body) => allocator + TopLevel::Function(name, args, stmts, body) => allocator .text("function") .append(allocator.space()) .append(allocator.text(name.as_ref().to_string())) @@ -378,6 +378,7 @@ where #[derive(Clone, Debug, Eq, PartialEq)] pub enum Type { Primitive(PrimitiveType), + Function(Vec, Box), } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type diff --git a/src/type_infer/ast.rs b/src/type_infer/ast.rs index 373690c..daffbf7 100644 --- a/src/type_infer/ast.rs +++ b/src/type_infer/ast.rs @@ -68,7 +68,7 @@ where #[derive(Debug)] pub enum TopLevel { Statement(Statement), - Function(Variable, Vec, Expression), + Function(Variable, Vec, Vec, Expression), } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b TopLevel @@ -78,22 +78,36 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - TopLevel::Function(name, args, body) => allocator - .text("function") - .append(allocator.space()) - .append(allocator.text(name.as_ref().to_string())) - .append(allocator.space()) - .append( - allocator - .intersperse( - args.iter().map(|x| allocator.text(x.as_ref().to_string())), - ", ", + TopLevel::Function(name, args, stmts, expr) => { + let base = allocator + .text("function") + .append(allocator.space()) + .append(allocator.text(name.as_ref().to_string())) + .append(allocator.space()) + .append( + pretty_comma_separated( + allocator, + &args + .iter() + .map(|x| allocator.text(x.as_ref().to_string())) + .collect(), ) .parens(), - ) - .append(allocator.space()) - .append(body.pretty(allocator)), + ) + .append(allocator.space()); + let mut body = allocator.nil(); + for stmt in stmts { + body = body + .append(stmt.pretty(allocator)) + .append(allocator.text(";")) + .append(allocator.hardline()); + } + body = body.append(expr.pretty(allocator)); + body = body.append(allocator.hardline()); + body = body.braces(); + base.append(body) + } TopLevel::Statement(stmt) => stmt.pretty(allocator), } } @@ -310,6 +324,7 @@ where pub enum Type { Variable(Location, ArcIntern), Primitive(PrimitiveType), + Function(Vec, Box), } impl Type { @@ -327,6 +342,14 @@ where match self { Type::Variable(_, x) => allocator.text(x.to_string()), Type::Primitive(pt) => allocator.text(format!("{}", pt)), + Type::Function(args, rettype) => { + pretty_comma_separated(allocator, &args.iter().collect()) + .parens() + .append(allocator.space()) + .append(allocator.text("->")) + .append(allocator.space()) + .append(rettype.pretty(allocator)) + } } } } @@ -336,6 +359,18 @@ impl fmt::Display for Type { match self { Type::Variable(_, x) => write!(f, "{}", x), Type::Primitive(pt) => pt.fmt(f), + Type::Function(args, ret) => { + write!(f, "(")?; + let mut argiter = args.iter().peekable(); + while let Some(arg) = argiter.next() { + arg.fmt(f)?; + if argiter.peek().is_some() { + write!(f, ",")?; + } + } + write!(f, "->")?; + ret.fmt(f) + } } } } @@ -373,3 +408,16 @@ pub fn gentype() -> Type { Type::Variable(Location::manufactured(), name) } + +fn pretty_comma_separated<'a, D, A, P>( + allocator: &'a D, + args: &Vec

, +) -> pretty::DocBuilder<'a, D, A> +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, + P: Pretty<'a, D, A>, +{ + let individuals = args.iter().map(|x| x.pretty(allocator)); + allocator.intersperse(individuals, ", ") +} diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 56c5732..bff8c2a 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -44,8 +44,41 @@ pub fn convert_top_level( bindings: &mut HashMap, Type>, ) -> Vec { match top_level { - syntax::TopLevel::Function(_, _arg_name, _) => { - unimplemented!() + syntax::TopLevel::Function(name, args, expr) => { + // First, let us figure out what we're going to name this function. If the user + // didn't provide one, we'll just call it "function:" for them. (We'll + // want a name for this function, eventually, so we might as well do it now.) + // + // If they did provide a name, see if we're shadowed. IF we are, then we'll have + // to specialize the name a bit. Otherwise we'll stick with their name. + let funname = match name { + None => ir::gensym("function"), + Some(unbound) => finalize_name(bindings, renames, unbound), + }; + + // Now we manufacture types for the inputs and outputs, and then a type for the + // function itself. We're not going to make any claims on these types, yet; they're + // all just unknown type variables we need to work out. + let argtypes: Vec = args.iter().map(|_| ir::gentype()).collect(); + let rettype = ir::gentype(); + let funtype = Type::Function(argtypes.clone(), Box::new(rettype.clone())); + + // Now let's bind these types into the environment. First, we bind our function + // namae to the function type we just generated. + bindings.insert(funname, funtype); + // And then we attach the argument names to the argument types. (We have to go + // convert all the names, first.) + let iargs: Vec> = + args.iter().map(|x| ArcIntern::new(x.to_string())).collect(); + assert_eq!(argtypes.len(), iargs.len()); + for (arg_name, arg_type) in iargs.iter().zip(argtypes) { + bindings.insert(arg_name.clone(), arg_type.clone()); + } + + let (stmts, expr, ty) = convert_expression(expr, constraint_db, renames, bindings); + constraint_db.push(Constraint::Equivalent(expr.location().clone(), rettype, ty)); + + vec![ir::TopLevel::Function(funname, iargs, stmts, expr)] } syntax::TopLevel::Statement(stmt) => { convert_statement(stmt, constraint_db, renames, bindings) @@ -93,14 +126,7 @@ fn convert_statement( syntax::Statement::Binding(loc, name, expr) => { let (mut prereqs, expr, ty) = convert_expression(expr, constraint_db, renames, bindings); - let iname = ArcIntern::new(name.to_string()); - let final_name = if bindings.contains_key(&iname) { - let new_name = ir::gensym(iname.as_str()); - renames.insert(iname, new_name.clone()); - new_name - } else { - iname - }; + let final_name = finalize_name(bindings, renames, name); bindings.insert(final_name.clone(), ty.clone()); prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr)); @@ -260,6 +286,20 @@ fn simplify_expr(expr: ir::Expression, stmts: &mut Vec) -> ir::Va } } +fn finalize_name( + bindings: &HashMap, Type>, + renames: &mut HashMap, ArcIntern>, + name: syntax::Name, +) -> ArcIntern { + if bindings.contains_key(&ArcIntern::new(name.name)) { + let new_name = ir::gensym(&name.name); + renames.insert(ArcIntern::new(name.name.to_string()), new_name.clone()); + new_name + } else { + ArcIntern::new(name.to_string()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 7f8f5d6..763574c 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -16,9 +16,14 @@ pub fn finalize_program( fn finalize_top_level(item: input::TopLevel, resolutions: &TypeResolutions) -> output::TopLevel { match item { - input::TopLevel::Function(name, args, body) => { - output::TopLevel::Function(name, args, finalize_expression(body, resolutions)) - } + input::TopLevel::Function(name, args, mut body, expr) => output::TopLevel::Function( + name, + args, + body.drain(..) + .map(|x| finalize_statement(x, resolutions)) + .collect(), + finalize_expression(expr, resolutions), + ), input::TopLevel::Statement(stmt) => { output::TopLevel::Statement(finalize_statement(stmt, resolutions)) } @@ -73,6 +78,12 @@ fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type None => panic!("Did not resolve type for type variable {}", tvar), Some(pt) => output::Type::Primitive(*pt), }, + input::Type::Function(mut args, ret) => output::Type::Function( + args.drain(..) + .map(|x| finalize_type(x, resolutions)) + .collect(), + Box::new(finalize_type(*ret, resolutions)), + ), } } @@ -89,6 +100,9 @@ fn finalize_val_or_ref( match val { input::Value::Unknown(base, value) => match new_type { + output::Type::Function(_, _) => { + panic!("Somehow inferred that a constant was a function") + } output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value( loc, new_type,