From 9d41cf0da7fffb06e5bc55c42533211e727ef3ec Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Mon, 5 Feb 2024 17:30:16 -0600 Subject: [PATCH] ran into another type inference problem --- examples/basic/function0001.ngr | 5 +++ src/backend/error.rs | 7 ++-- src/backend/into_crane.rs | 43 ++++++++++++++++++++-- src/eval.rs | 20 +++++++++++ src/ir/arbitrary.rs | 7 ++-- src/ir/ast.rs | 26 +++++++++++--- src/ir/eval.rs | 27 ++++++++++++++ src/ir/top_level.rs | 2 +- src/syntax.rs | 2 +- src/syntax/ast.rs | 6 ++++ src/syntax/eval.rs | 27 ++++++++++++++ src/syntax/parser.lalrpop | 15 ++++++++ src/syntax/pretty.rs | 7 +++- src/syntax/validate.rs | 11 ++++++ src/type_infer/convert.rs | 63 +++++++++++++++++++++++++++++++-- src/type_infer/finalize.rs | 9 +++++ src/type_infer/solve.rs | 14 +++++--- 17 files changed, 267 insertions(+), 24 deletions(-) create mode 100644 examples/basic/function0001.ngr diff --git a/examples/basic/function0001.ngr b/examples/basic/function0001.ngr new file mode 100644 index 0000000..1a845d8 --- /dev/null +++ b/examples/basic/function0001.ngr @@ -0,0 +1,5 @@ +x = 1; +function add_x(y) x + y +a = 3; +result = add_x(a); +print result; \ No newline at end of file diff --git a/src/backend/error.rs b/src/backend/error.rs index e5e2a1e..c6dc995 100644 --- a/src/backend/error.rs +++ b/src/backend/error.rs @@ -53,10 +53,9 @@ pub enum BackendError { impl From for Diagnostic { fn from(value: BackendError) -> Self { match value { - BackendError::Cranelift(me) => { - Diagnostic::error().with_message(format!("Internal cranelift error: {}", me)) - .with_notes(vec![format!("{:?}", me)]) - } + BackendError::Cranelift(me) => Diagnostic::error() + .with_message(format!("Internal cranelift error: {}", me)) + .with_notes(vec![format!("{:?}", me)]), BackendError::BuiltinError(me) => { Diagnostic::error().with_message(format!("Internal runtime function error: {}", me)) } diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 81a9142..740d101 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -4,7 +4,8 @@ use crate::eval::PrimitiveType; use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; use crate::syntax::{ConstantType, Location}; use cranelift_codegen::ir::{ - self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName + self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, MemFlags, Signature, + UserFuncName, }; use cranelift_codegen::isa::CallConv; use cranelift_codegen::Context; @@ -29,7 +30,9 @@ impl ReferenceBuilder { ReferenceBuilder::Global(ty, gv) => { let cranelift_type = ir::Type::from(*ty); let ptr_value = builder.ins().symbol_value(types::I64, *gv); - let value = builder.ins().load(cranelift_type, MemFlags::new(), ptr_value, 0); + let value = builder + .ins() + .load(cranelift_type, MemFlags::new(), ptr_value, 0); (value, *ty) } @@ -435,7 +438,8 @@ impl Backend { // Look up the value for the variable. Because this might be a // global variable (and that requires special logic), we just turn // this into an `Expression` and re-use the logic in that implementation. - let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var.clone()); + let fake_ref = + ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var.clone()); let (val, vtype) = self.compile_value_or_ref(fake_ref, variables, builder)?; let vtype_repr = builder.ins().iconst(types::I64, vtype as i64); @@ -473,6 +477,39 @@ impl Backend { variables.insert(name, ReferenceBuilder::Local(value_type, variable)); Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)) } + + Expression::Call(_, _, function, args) => { + let (arguments, _argument_types): (Vec<_>, Vec<_>) = args + .into_iter() + .map(|x| self.compile_value_or_ref(x, variables, builder)) + .collect::,BackendError>>()? + .into_iter() + .unzip(); + + match *function { + ValueOrRef::Value(_, _, _) => { + panic!("Can't use a value for a function") + } + + ValueOrRef::Ref(_, result_type, name) => match self.defined_functions.get(&name) { + None => panic!("Couldn't find function {} to call", name), + Some(function) => { + let func_ref = self.module.declare_func_in_func(*function, builder.func); + let call = builder.ins().call(func_ref, &arguments); + let results = builder.inst_results(call); + + match results { + [] => Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)), + [result] => match result_type { + Type::Primitive(ct) => Ok((*result, ct.into())), + Type::Function(_, _) => panic!("return value is a function?"), + } + _ => panic!("don't support multi-value returns yet"), + } + } + } + } + } } } diff --git a/src/eval.rs b/src/eval.rs index f0fee04..182f640 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -39,6 +39,7 @@ mod primtype; mod value; use cranelift_module::ModuleError; +use internment::ArcIntern; pub use primop::PrimOpError; pub use primtype::PrimitiveType; pub use value::Value; @@ -76,6 +77,15 @@ pub enum EvalError { UnknownPrimType(#[from] UnknownPrimType), #[error("Variable lookup failed for {1} at {0:?}")] LookupFailed(crate::syntax::Location, String), + #[error("Attempted to call something that wasn't a function at {0:?} (it was a {1})")] + NotAFunction(crate::syntax::Location, Value), + #[error("Wrong argument call for function ({1:?}) at {0:?}; expected {2}, saw {3}")] + WrongArgCount( + crate::syntax::Location, + Option>, + usize, + usize, + ), } impl PartialEq> for EvalError { @@ -129,6 +139,16 @@ impl PartialEq> for EvalError { EvalError::UnknownPrimType(b) => a == b, _ => false, }, + + EvalError::NotAFunction(a, b) => match other { + EvalError::NotAFunction(x, y) => a == x && b == y, + _ => false, + }, + + EvalError::WrongArgCount(a, b, c, d) => match other { + EvalError::WrongArgCount(w, x, y, z) => a == w && b == x && c == y && d == z, + _ => false, + }, } } } diff --git a/src/ir/arbitrary.rs b/src/ir/arbitrary.rs index acf3cc8..1ac45bb 100644 --- a/src/ir/arbitrary.rs +++ b/src/ir/arbitrary.rs @@ -1,5 +1,7 @@ use crate::eval::PrimitiveType; -use crate::ir::{Expression, Primitive, Program, TopLevel, Type, TypeWithVoid, Value, ValueOrRef, Variable}; +use crate::ir::{ + Expression, Primitive, Program, TopLevel, Type, TypeWithVoid, Value, ValueOrRef, Variable, +}; use crate::syntax::Location; use crate::util::scoped_map::ScopedMap; use proptest::strategy::{NewTree, Strategy, ValueTree}; @@ -300,7 +302,8 @@ fn generate_random_expression( if !next_type.is_void() { let name = generate_random_name(rng); env.insert(name.clone(), next_type.clone()); - next = Expression::Bind(Location::manufactured(), name, next_type, Box::new(next)); + next = + Expression::Bind(Location::manufactured(), name, next_type, Box::new(next)); } stmts.push(next); } diff --git a/src/ir/ast.rs b/src/ir/ast.rs index c065414..cfdc1f4 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -160,6 +160,7 @@ pub enum Expression { Primitive(Location, Type, Primitive, Vec>), Block(Location, Type, Vec>), Print(Location, Variable), + Call(Location, Type, Box>, Vec>), Bind(Location, Variable, Type, Box>), } @@ -173,6 +174,7 @@ impl Expression { Expression::Primitive(_, t, _, _) => t.clone(), Expression::Block(_, t, _) => t.clone(), Expression::Print(_, _) => Type::void(), + Expression::Call(_, t, _, _) => t.clone(), Expression::Bind(_, _, _, _) => Type::void(), } } @@ -186,6 +188,7 @@ impl Expression { Expression::Primitive(l, _, _, _) => l, Expression::Block(l, _, _) => l, Expression::Print(l, _) => l, + Expression::Call(l, _, _, _) => l, Expression::Bind(l, _, _, _) => l, } } @@ -221,6 +224,12 @@ where Expression::Primitive(_, _, op, exprs) => { allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) } + Expression::Call(_, _, fun, args) => { + let args = args.iter().map(|x| x.pretty(allocator)); + let comma_sepped_args = + allocator.intersperse(args, crate::syntax::pretty::CommaSep {}); + fun.pretty(allocator).append(comma_sepped_args.parens()) + } Expression::Block(_, _, exprs) => match exprs.split_last() { None => allocator.text("()"), Some((last, &[])) => last.pretty(allocator), @@ -660,13 +669,20 @@ impl TryFrom for Type { fn try_from(value: TypeOrVar) -> Result { match value { TypeOrVar::Function(args, ret) => { - let args = args - .into_iter() + let converted_args = args + .iter() + .cloned() .map(Type::try_from) - .collect::>()?; - let ret = Type::try_from(*ret)?; + .collect::>(); + let converted_ret = Type::try_from((*ret).clone()); + + if let Ok(args) = converted_args { + if let Ok(ret) = converted_ret { + return Ok(Type::Function(args, Box::new(ret))); + } + } - Ok(Type::Function(args, Box::new(ret))) + Err(TypeOrVar::Function(args, ret)) } TypeOrVar::Primitive(t) => Ok(Type::Primitive(t)), diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 30bf920..27f1e84 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -103,6 +103,33 @@ where env.insert(name.clone(), value); Ok(Value::Void) } + + Expression::Call(loc, _, fun, args) => { + let function = fun.eval(env)?; + + match function { + Value::Closure(name, mut env, arguments, body) => { + if args.len() != arguments.len() { + return Err(EvalError::WrongArgCount( + loc.clone(), + name, + arguments.len(), + args.len(), + )); + } + + env.new_scope(); + for (name, value) in arguments.into_iter().zip(args.into_iter()) { + let value = value.eval(&mut env)?; + env.insert(name, value); + } + let result = body.eval(&mut env, stdout)?; + env.release_scope(); + Ok(result) + } + _ => Err(EvalError::NotAFunction(loc.clone(), function)), + } + } } } } diff --git a/src/ir/top_level.rs b/src/ir/top_level.rs index c3459fb..728472d 100644 --- a/src/ir/top_level.rs +++ b/src/ir/top_level.rs @@ -38,7 +38,7 @@ impl Expression { let mut tlvs = expr.get_top_level_variables(); tlvs.insert(name.clone(), ty.clone()); tlvs - }, + } _ => HashMap::new(), } } diff --git a/src/syntax.rs b/src/syntax.rs index 0047573..b532f02 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -37,7 +37,7 @@ lalrpop_mod!( parser, "/syntax/parser.rs" ); -mod pretty; +pub mod pretty; mod validate; #[cfg(test)] diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index de005ec..89928e8 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -124,6 +124,7 @@ pub enum Expression { Reference(Location, String), Cast(Location, String, Box), Primitive(Location, String, Vec), + Call(Location, Box, Vec), Block(Location, Vec), } @@ -146,6 +147,10 @@ impl PartialEq for Expression { Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2, _ => false, }, + Expression::Call(_, f1, a1) => match other { + Expression::Call(_, f2, a2) => f1 == f2 && a1 == a2, + _ => false, + }, Expression::Block(_, stmts1) => match other { Expression::Block(_, stmts2) => stmts1 == stmts2, _ => false, @@ -162,6 +167,7 @@ impl Expression { Expression::Reference(loc, _) => loc, Expression::Cast(loc, _, _) => loc, Expression::Primitive(loc, _, _) => loc, + Expression::Call(loc, _, _) => loc, Expression::Block(loc, _) => loc, } } diff --git a/src/syntax/eval.rs b/src/syntax/eval.rs index 6093633..e84a0b9 100644 --- a/src/syntax/eval.rs +++ b/src/syntax/eval.rs @@ -118,6 +118,33 @@ impl Expression { Ok(Value::calculate(op, arg_values)?) } + Expression::Call(loc, fun, args) => { + let function = fun.eval(stdout, env)?; + + match function { + Value::Closure(name, mut closure_env, arguments, body) => { + if args.len() != arguments.len() { + return Err(EvalError::WrongArgCount( + loc.clone(), + name, + arguments.len(), + args.len(), + )); + } + + closure_env.new_scope(); + for (name, value) in arguments.into_iter().zip(args.iter()) { + let value = value.eval(stdout, env)?; + closure_env.insert(name, value); + } + let result = body.eval(stdout, &mut closure_env)?; + closure_env.release_scope(); + Ok(result) + } + _ => Err(EvalError::NotAFunction(loc.clone(), function)), + } + } + Expression::Block(_, stmts) => { let mut result = Value::Void; diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index 8c66e63..03ac419 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -205,9 +205,24 @@ UnaryExpression: Expression = { Expression::Primitive(Location::new(file_idx, l..le), "-".to_string(), vec![e]), "<" "> ">" => Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)), + CallExpression, +} + +CallExpression: Expression = { + "(" ")" => + Expression::Call(Location::new(file_idx, s..e), Box::new(f), args), AtomicExpression, } +CallArguments: Vec = { + => vec![], + => vec![e], + "," => { + args.push(e); + args + } +} + // finally, we describe our lowest-level expressions as "atomic", because // they cannot be further divided into parts AtomicExpression: Expression = { diff --git a/src/syntax/pretty.rs b/src/syntax/pretty.rs index 859f51f..6cafc43 100644 --- a/src/syntax/pretty.rs +++ b/src/syntax/pretty.rs @@ -105,6 +105,11 @@ where let comma_sepped_args = allocator.intersperse(args, CommaSep {}); call.append(comma_sepped_args.parens()) } + Expression::Call(_, fun, args) => { + let args = args.iter().map(|x| x.pretty(allocator)); + let comma_sepped_args = allocator.intersperse(args, CommaSep {}); + fun.pretty(allocator).append(comma_sepped_args.parens()) + } Expression::Block(_, stmts) => match stmts.split_last() { None => allocator.text("()"), Some((last, &[])) => last.pretty(allocator), @@ -167,7 +172,7 @@ fn type_suffix(x: &Option) -> &'static str { } #[derive(Clone, Copy)] -struct CommaSep {} +pub struct CommaSep {} impl<'a, D, A> Pretty<'a, D, A> for CommaSep where diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index 1503c65..331d5e6 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -212,6 +212,17 @@ impl Expression { (errors, warnings) } + Expression::Call(_, func, args) => { + let (mut errors, mut warnings) = func.validate(variable_map); + + for arg in args.iter() { + let (mut e, mut w) = arg.validate(variable_map); + errors.append(&mut e); + warnings.append(&mut w); + } + + (errors, warnings) + } Expression::Block(_, stmts) => { let mut errors = vec![]; let mut warnings = vec![]; diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 7e06013..9f2f98a 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -286,9 +286,7 @@ fn convert_expression( let (aexp, atype) = convert_expression(arg, constraint_db, renames, bindings); let (aprereqs, asimple) = simplify_expr(aexp); - if let Some(prereq) = aprereqs { - prereqs.push(prereq); - } + merge_prereq(&mut prereqs, aprereqs); nargs.push(asimple); atypes.push(atype); } @@ -313,6 +311,59 @@ fn convert_expression( } } + syntax::Expression::Call(loc, fun, args) => { + let return_type = ir::TypeOrVar::new(); + let arg_types = args + .iter() + .map(|_| ir::TypeOrVar::new()) + .collect::>(); + + let (new_fun, new_fun_type) = + convert_expression(*fun, constraint_db, renames, bindings); + let target_fun_type = + ir::TypeOrVar::Function(arg_types.clone(), Box::new(return_type.clone())); + constraint_db.push(Constraint::Equivalent( + loc.clone(), + new_fun_type, + target_fun_type, + )); + let mut prereqs = vec![]; + + let (fun_prereqs, fun) = simplify_expr(new_fun); + merge_prereq(&mut prereqs, fun_prereqs); + + let new_args = args + .into_iter() + .zip(arg_types.into_iter()) + .map(|(arg, target_type)| { + let (new_arg, inferred_type) = + convert_expression(arg, constraint_db, renames, bindings); + let location = new_arg.location().clone(); + let (arg_prereq, new_valref) = simplify_expr(new_arg); + merge_prereq(&mut prereqs, arg_prereq); + constraint_db.push(Constraint::Equivalent( + location, + inferred_type, + target_type, + )); + new_valref + }) + .collect(); + + let last_call = + ir::Expression::Call(loc.clone(), return_type.clone(), Box::new(fun), new_args); + + if prereqs.is_empty() { + (last_call, return_type) + } else { + prereqs.push(last_call); + ( + ir::Expression::Block(loc, return_type.clone(), prereqs), + return_type, + ) + } + } + syntax::Expression::Block(loc, stmts) => { let mut ret_type = ir::TypeOrVar::Primitive(PrimitiveType::Void); let mut exprs = vec![]; @@ -381,6 +432,12 @@ fn finalize_name( } } +fn merge_prereq(left: &mut Vec, prereq: Option) { + if let Some(item) = prereq { + left.push(item) + } +} + #[cfg(test)] mod tests { // use super::*; diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 1b9e703..4173794 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -91,6 +91,15 @@ fn finalize_expression( Expression::Print(loc, var) => Expression::Print(loc, var), + Expression::Call(loc, ty, fun, args) => Expression::Call( + loc, + finalize_type(ty, resolutions), + Box::new(finalize_val_or_ref(*fun, resolutions)), + args.into_iter() + .map(|x| finalize_val_or_ref(x, resolutions)) + .collect(), + ), + Expression::Bind(loc, var, ty, subexp) => Expression::Bind( loc, var, diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index 0b1b7f8..88ea802 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -192,7 +192,7 @@ impl From for Diagnostic { } TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => { loc.labelled_error("internal error").with_message(format!( - "could not determine if {} and {:#?} were equivalent", + "could not determine if {} and {} were equivalent", a, b )) } @@ -264,7 +264,7 @@ pub fn solve_constraints( // constraints. Internal to the loop, we have a check that will make sure that we // do (eventually) stop. while changed_something && !constraint_db.is_empty() { - println!("CONSTRAINT:"); + println!("\n\n\nCONSTRAINT:"); for constraint in constraint_db.iter() { println!(" {}", constraint); } @@ -300,8 +300,9 @@ pub fn solve_constraints( Constraint::IsSomething(_, TypeOrVar::Variable(_, ref name)) => { if resolutions.get(name).is_none() { constraint_db.push(constraint); + } else { + changed_something = true; } - changed_something = true; } // Case #1a: We have two primitive types. If they're equal, we've discharged this @@ -355,12 +356,15 @@ pub fn solve_constraints( // type. Constraint::Equivalent(loc, t, TypeOrVar::Variable(vloc, name)) | Constraint::Equivalent(loc, TypeOrVar::Variable(vloc, name), t) => { + println!("IN THIS CASE with {}", name); match resolutions.get(&name) { None => match t.try_into() { Ok(real_type) => { + println!(" HERE with {} and {}", name, real_type); resolutions.insert(name, real_type); } Err(variable_type) => { + println!(" REJECTED INTO RETURN with {} and {}", name, variable_type); constraint_db.push(Constraint::Equivalent( loc, variable_type, @@ -369,7 +373,9 @@ pub fn solve_constraints( continue; } }, - Some(t2) if &t == t2 => {} + Some(t2) if &t == t2 => { + println!(" MATCHED at {} == {}", t, t2); + } Some(t2) => errors.push(TypeInferenceError::NotEquivalent( loc, t,