From 7d4f182a67fa4136262e391984e3d25abb8d4304 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Tue, 16 Apr 2024 16:20:31 -0700 Subject: [PATCH] Clean up primitive handling, finally. --- src/backend/into_crane.rs | 208 ++++++++++++++++++++----------------- src/eval/primop.rs | 9 +- src/eval/primtype.rs | 8 ++ src/eval/value.rs | 11 ++ src/ir/arbitrary.rs | 29 +++++- src/ir/ast.rs | 13 +-- src/ir/eval.rs | 52 +++++----- src/ir/pretty.rs | 23 +--- src/repl.rs | 17 +-- src/syntax.rs | 56 +++++----- src/syntax/arbitrary.rs | 26 +++-- src/syntax/ast.rs | 57 ++++------ src/syntax/eval.rs | 66 +++++------- src/syntax/parser.lalrpop | 67 ++++++------ src/syntax/pretty.rs | 37 +------ src/syntax/validate.rs | 52 +--------- src/type_infer/convert.rs | 129 ++++++++++------------- src/type_infer/finalize.rs | 16 +-- src/type_infer/solve.rs | 73 ------------- 19 files changed, 399 insertions(+), 550 deletions(-) diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index cfcd47d..a92a90c 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -374,43 +374,6 @@ impl Backend { } } - Expression::Primitive(_, ret_type, prim, mut vals) => { - let mut values = vec![]; - let mut first_type = None; - - for val in vals.drain(..) { - let (compiled, compiled_type) = - self.compile_value_or_ref(val, variables, builder)?; - - if let Some(leftmost_type) = first_type { - assert_eq!(leftmost_type, compiled_type); - } else { - first_type = Some(compiled_type); - } - - values.push(compiled); - } - - let first_type = first_type.expect("primitive op has at least one argument"); - - // then we just need to tell Cranelift how to do each of our primitives! Much - // like Statements, above, we probably want to eventually shuffle this off into - // a separate function (maybe something off `Primitive`), but for now it's simple - // enough that we just do the `match` here. - match prim { - Primitive::Plus => Ok((builder.ins().iadd(values[0], values[1]), first_type)), - Primitive::Minus if values.len() == 2 => { - Ok((builder.ins().isub(values[0], values[1]), first_type)) - } - Primitive::Minus => Ok((builder.ins().ineg(values[0]), first_type)), - Primitive::Times => Ok((builder.ins().imul(values[0], values[1]), first_type)), - Primitive::Divide if ret_type.is_signed() => { - Ok((builder.ins().sdiv(values[0], values[1]), first_type)) - } - Primitive::Divide => Ok((builder.ins().udiv(values[0], values[1]), first_type)), - } - } - Expression::Construct(_, ty, _, fields) => { let Type::Structure(type_fields) = ty else { panic!("Got to backend with non-structure type in structure construction?!"); @@ -493,60 +456,6 @@ impl Backend { } }, - Expression::Print(_, var) => { - // Get the output buffer (or null) from our general compilation context. - let buffer_ptr = self.output_buffer_ptr(); - let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); - - // Get a reference to the string we want to print. - let var_name = match var { - ValueOrRef::Ref(_, _, ref name) => name.as_ref(), - ValueOrRef::Value(_, _, _) => "", - }; - let string_data_id = self.string_reference(var_name)?; - let local_name_ref = self - .module - .declare_data_in_func(string_data_id, builder.func); - let name_ptr = builder.ins().symbol_value(types::I64, local_name_ref); - - let var_type = var.type_of(); - let (val, _) = self.compile_value_or_ref(var, variables, builder)?; - - let (repr_val, casted_val) = match var_type { - Type::Structure(_) => (ConstantType::I64 as i64, val), - Type::Function(_, _) => (ConstantType::I64 as i64, val), - Type::Primitive(pt) => { - let constant_type = pt.into(); - - let new_val = match constant_type { - ConstantType::U64 | ConstantType::I64 | ConstantType::Void => val, - ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => { - builder.ins().sextend(types::I64, val) - } - ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => { - builder.ins().uextend(types::I64, val) - } - }; - - (constant_type as i64, new_val) - } - }; - - let vtype_repr = builder.ins().iconst(types::I64, repr_val); - - // Finally, we can generate the call to print. - let print_func_ref = self.runtime_functions.include_runtime_function( - "print", - &mut self.module, - builder.func, - )?; - builder.ins().call( - print_func_ref, - &[buffer_ptr, name_ptr, vtype_repr, casted_val], - ); - Ok((builder.ins().iconst(types::I64, 0), VOID_REPR_TYPE)) - } - Expression::Bind(_, name, _, expr) => { let (value, value_type) = self.compile_expression(*expr, variables, builder)?; let variable = self.generate_local(); @@ -576,8 +485,15 @@ impl Backend { Ok((value, value_type)) } - Expression::Call(_, _, function, args) => { - let (arguments, _argument_types): (Vec<_>, Vec<_>) = args + Expression::Call(_, final_type, function, args) => { + // Get a reference to the string we want to print. + let var_name = match args[0] { + ValueOrRef::Ref(_, _, ref name) => name.to_string(), + ValueOrRef::Value(_, _, _) => "".to_string(), + ValueOrRef::Primitive(_, _, n) => format!("", n), + }; + let var_type = args[0].type_of(); + let (arguments, argument_types): (Vec<_>, Vec<_>) = args .into_iter() .map(|x| self.compile_value_or_ref(x, variables, builder)) .collect::, BackendError>>()? @@ -608,6 +524,109 @@ impl Backend { } } } + + ValueOrRef::Primitive(_, _, prim) => match prim { + Primitive::Plus => { + assert_eq!(2, arguments.len()); + Ok(( + builder.ins().iadd(arguments[0], arguments[1]), + argument_types[0], + )) + } + + Primitive::Minus => { + assert_eq!(2, arguments.len()); + Ok(( + builder.ins().isub(arguments[0], arguments[1]), + argument_types[0], + )) + } + + Primitive::Times => { + assert_eq!(2, arguments.len()); + Ok(( + builder.ins().imul(arguments[0], arguments[1]), + argument_types[0], + )) + } + + Primitive::Divide if final_type.is_signed() => { + assert_eq!(2, arguments.len()); + Ok(( + builder.ins().sdiv(arguments[0], arguments[1]), + argument_types[0], + )) + } + + Primitive::Divide => { + assert_eq!(2, arguments.len()); + Ok(( + builder.ins().udiv(arguments[0], arguments[1]), + argument_types[0], + )) + } + + Primitive::Negate => { + assert_eq!(1, arguments.len()); + Ok(( + builder.ins().ineg(arguments[0]), + argument_types[0], + )) + } + + Primitive::Print => { + // Get the output buffer (or null) from our general compilation context. + let buffer_ptr = self.output_buffer_ptr(); + let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); + + assert_eq!(1, arguments.len()); + let string_data_id = self.string_reference(&var_name)?; + let local_name_ref = self + .module + .declare_data_in_func(string_data_id, builder.func); + let name_ptr = builder.ins().symbol_value(types::I64, local_name_ref); + + let (repr_val, casted_val) = match var_type { + Type::Structure(_) => (ConstantType::I64 as i64, arguments[0]), + Type::Function(_, _) => (ConstantType::I64 as i64, arguments[0]), + Type::Primitive(pt) => { + let constant_type = pt.into(); + + let new_val = match constant_type { + ConstantType::U64 + | ConstantType::I64 + | ConstantType::Void => arguments[0], + ConstantType::I8 + | ConstantType::I16 + | ConstantType::I32 => { + builder.ins().sextend(types::I64, arguments[0]) + } + ConstantType::U8 + | ConstantType::U16 + | ConstantType::U32 => { + builder.ins().uextend(types::I64, arguments[0]) + } + }; + + (constant_type as i64, new_val) + } + }; + + let vtype_repr = builder.ins().iconst(types::I64, repr_val); + + // Finally, we can generate the call to print. + let print_func_ref = self.runtime_functions.include_runtime_function( + "print", + &mut self.module, + builder.func, + )?; + builder.ins().call( + print_func_ref, + &[buffer_ptr, name_ptr, vtype_repr, casted_val], + ); + Ok((builder.ins().iconst(types::I64, 0), VOID_REPR_TYPE)) + } + }, } } } @@ -668,6 +687,9 @@ impl Backend { Ok((value, *ctype)) } }, + ValueOrRef::Primitive(_, _, _) => { + unimplemented!() + } } } } diff --git a/src/eval/primop.rs b/src/eval/primop.rs index aeecb57..a122980 100644 --- a/src/eval/primop.rs +++ b/src/eval/primop.rs @@ -87,7 +87,7 @@ macro_rules! run_op { impl Value { fn unary_op(operation: &str, value: &Value) -> Result, PrimOpError> { match operation { - "-" => match value { + "negate" => match value { Value::I8(x) => Ok(Value::I8(x.wrapping_neg())), Value::I16(x) => Ok(Value::I16(x.wrapping_neg())), Value::I32(x) => Ok(Value::I32(x.wrapping_neg())), @@ -192,9 +192,10 @@ impl Value { right.clone(), )), }, - Value::Closure(_, _, _, _) | Value::Structure(_, _) | Value::Void => { - Err(PrimOpError::BadTypeFor(operation.to_string(), left.clone())) - } + Value::Closure(_, _, _, _) + | Value::Structure(_, _) + | Value::Primitive(_) + | Value::Void => Err(PrimOpError::BadTypeFor(operation.to_string(), left.clone())), } } diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index 0d9abc6..19a8dd4 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -42,12 +42,17 @@ impl Display for PrimitiveType { } } +#[allow(clippy::enum_variant_names)] #[derive(Clone, Debug, PartialEq, thiserror::Error)] pub enum ValuePrimitiveTypeError { #[error("Could not convert function value to primitive type (possible function name: {0:?}")] CannotConvertFunction(Option), #[error("Could not convert structure value to primitive type (possible function name: {0:?}")] CannotConvertStructure(Option), + #[error( + "Could not convert primitive operator to primitive type (possible function name: {0:?}" + )] + CannotConvertPrimitive(String), } impl<'a, IR> TryFrom<&'a Value> for PrimitiveType { @@ -72,6 +77,9 @@ impl<'a, IR> TryFrom<&'a Value> for PrimitiveType { Value::Structure(name, _) => Err(ValuePrimitiveTypeError::CannotConvertStructure( name.as_ref().map(|x| (**x).clone()), )), + Value::Primitive(prim) => Err(ValuePrimitiveTypeError::CannotConvertPrimitive( + prim.clone(), + )), } } } diff --git a/src/eval/value.rs b/src/eval/value.rs index da54910..99e8248 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -31,6 +31,7 @@ pub enum Value { Option>, HashMap, Value>, ), + Primitive(String), } impl Value { @@ -59,8 +60,13 @@ impl Value { name.clone(), fields.iter().map(|(n, v)| (n.clone(), v.strip())).collect(), ), + Value::Primitive(name) => Value::Primitive(name.clone()), } } + + pub fn primitive(name: S) -> Self { + Value::Primitive(name.to_string()) + } } fn format_value(value: &Value, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -87,6 +93,7 @@ fn format_value(value: &Value, f: &mut fmt::Formatter<'_>) -> fmt::Resul } write!(f, " }}") } + Value::Primitive(n) => write!(f, "{}", n), } } @@ -165,6 +172,10 @@ impl PartialEq> for Value { } _ => false, }, + Value::Primitive(n1) => match other { + Value::Primitive(n2) => n1 == n2, + _ => false, + }, } } } diff --git a/src/ir/arbitrary.rs b/src/ir/arbitrary.rs index b704eaa..072afc2 100644 --- a/src/ir/arbitrary.rs +++ b/src/ir/arbitrary.rs @@ -347,12 +347,25 @@ fn generate_random_expression( Some((operator, arg_count)) => { let primop = Primitive::from_str(operator).expect("chose valid primitive"); let mut args = vec![base_expr]; + let mut argtys = vec![]; while args.len() < *arg_count { args.push(generate_random_valueref(rng, env, Some(primty))); + argtys.push(Type::Primitive(primty)); } - Expression::Primitive(Location::manufactured(), out_type, primop, args) + let primtype = Type::Function(argtys, Box::new(Type::Primitive(primty))); + + Expression::Call( + Location::manufactured(), + out_type, + Box::new(ValueOrRef::Primitive( + Location::manufactured(), + primtype, + primop, + )), + args, + ) } }, Type::Structure(_) => unimplemented!(), @@ -377,9 +390,19 @@ fn generate_random_expression( } else { let (variable, var_type) = possible_variables.choose(rng).unwrap(); - Expression::Print( + Expression::Call( Location::manufactured(), - ValueOrRef::Ref(Location::manufactured(), var_type.clone(), variable.clone()), + Type::void(), + Box::new(ValueOrRef::Primitive( + Location::manufactured(), + Type::void(), + Primitive::Print, + )), + vec![ValueOrRef::Ref( + Location::manufactured(), + var_type.clone(), + variable.clone(), + )], ) } } diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 0e58076..24ab128 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -106,7 +106,6 @@ impl TopLevel { pub enum Expression { Atomic(ValueOrRef), Cast(Location, Type, ValueOrRef), - Primitive(Location, Type, Primitive, Vec>), Construct( Location, Type, @@ -115,7 +114,6 @@ pub enum Expression { ), FieldRef(Location, Type, Type, ValueOrRef, ArcIntern), Block(Location, Type, Vec>), - Print(Location, ValueOrRef), Call(Location, Type, Box>, Vec>), Bind(Location, Variable, Type, Box>), } @@ -127,11 +125,9 @@ impl Expression { match self { Expression::Atomic(x) => x.type_of(), Expression::Cast(_, t, _) => t.clone(), - Expression::Primitive(_, t, _, _) => t.clone(), Expression::Construct(_, t, _, _) => t.clone(), Expression::FieldRef(_, t, _, _, _) => t.clone(), Expression::Block(_, t, _) => t.clone(), - Expression::Print(_, _) => Type::void(), Expression::Call(_, t, _, _) => t.clone(), Expression::Bind(_, _, _, _) => Type::void(), } @@ -142,12 +138,11 @@ impl Expression { match self { Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l, Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, + Expression::Atomic(ValueOrRef::Primitive(l, _, _)) => l, Expression::Cast(l, _, _) => l, - Expression::Primitive(l, _, _, _) => l, Expression::Construct(l, _, _, _) => l, Expression::FieldRef(l, _, _, _, _) => l, Expression::Block(l, _, _) => l, - Expression::Print(l, _) => l, Expression::Call(l, _, _, _) => l, Expression::Bind(l, _, _, _) => l, } @@ -166,6 +161,8 @@ pub enum Primitive { Minus, Times, Divide, + Print, + Negate, } impl FromStr for Primitive { @@ -177,6 +174,8 @@ impl FromStr for Primitive { "-" => Ok(Primitive::Minus), "*" => Ok(Primitive::Times), "/" => Ok(Primitive::Divide), + "print" => Ok(Primitive::Print), + "negate" => Ok(Primitive::Negate), _ => Err(format!("Illegal primitive {}", value)), } } @@ -191,6 +190,7 @@ impl FromStr for Primitive { pub enum ValueOrRef { Value(Location, Type, Value), Ref(Location, Type, ArcIntern), + Primitive(Location, Type, Primitive), } impl ValueOrRef { @@ -198,6 +198,7 @@ impl ValueOrRef { match self { ValueOrRef::Ref(_, t, _) => t.clone(), ValueOrRef::Value(_, t, _) => t.clone(), + ValueOrRef::Primitive(_, t, _) => t.clone(), } } } diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 63748b0..d74db44 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -1,4 +1,4 @@ -use super::{Primitive, Type, ValueOrRef}; +use super::{Type, ValueOrRef}; use crate::eval::{EvalError, Value}; use crate::ir::{Expression, Program, TopLevel, Variable}; use crate::util::scoped_map::ScopedMap; @@ -65,22 +65,6 @@ where } } - Expression::Primitive(_, _, op, args) => { - let arg_values = args - .iter() - .map(|x| x.eval(env)) - .collect::>, IREvalError>>()?; - - // and then finally we call `calculate` to run them. trust me, it's nice - // to not have to deal with all the nonsense hidden under `calculate`. - match op { - Primitive::Plus => Ok(Value::calculate("+", arg_values)?), - Primitive::Minus => Ok(Value::calculate("-", arg_values)?), - Primitive::Times => Ok(Value::calculate("*", arg_values)?), - Primitive::Divide => Ok(Value::calculate("/", arg_values)?), - } - } - Expression::Construct(_, _, name, fields) => { let mut result_fields = HashMap::with_capacity(fields.len()); @@ -114,16 +98,6 @@ where Ok(result) } - Expression::Print(_, value) => { - let n = match value { - ValueOrRef::Ref(_, _, ref name) => name.as_str(), - ValueOrRef::Value(_, _, _) => "", - }; - let value = value.eval(env)?; - stdout.push_str(&format!("{} = {}\n", n, value)); - Ok(Value::Void) - } - Expression::Bind(_, name, _, value) => { let value = value.eval(env, stdout)?; env.insert(name.clone(), value.clone()); @@ -153,6 +127,28 @@ where closure_env.release_scope(); Ok(result) } + + Value::Primitive(name) if name == "print" => { + if let [ValueOrRef::Ref(loc, ty, name)] = &args[..] { + let value = ValueOrRef::Ref(loc.clone(), ty.clone(), name.clone()).eval(env)?; + let addendum = format!("{} = {}\n", name, value); + + stdout.push_str(&addendum); + Ok(Value::Void) + } else { + panic!("Non-reference/non-singleton argument to 'print'"); + } + } + + Value::Primitive(name) => { + let values = args + .iter() + .map(|x| x.eval(env)) + .collect::>()?; + println!("primitive {}: args {:?}", name, values); + Value::calculate(name.as_str(), values).map_err(Into::into) + } + _ => Err(EvalError::NotAFunction(loc.clone(), function)), } } @@ -179,6 +175,8 @@ impl ValueOrRef { .get(n) .cloned() .ok_or_else(|| EvalError::LookupFailed(loc.clone(), n.to_string())), + + ValueOrRef::Primitive(_, _, prim) => Ok(Value::primitive(prim)), } } } diff --git a/src/ir/pretty.rs b/src/ir/pretty.rs index 129d8f9..ab55bc3 100644 --- a/src/ir/pretty.rs +++ b/src/ir/pretty.rs @@ -53,9 +53,6 @@ impl Expression { .append(t.pretty(allocator)) .append(allocator.text(">")) .append(e.pretty(allocator)), - Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => { - op.pretty(allocator).append(exprs[0].pretty(allocator)) - } Expression::Construct(_, _, name, fields) => { let inner = allocator .intersperse( @@ -78,19 +75,6 @@ impl Expression { .text(".") .append(allocator.text(field.to_string())), ), - Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => { - let left = exprs[0].pretty(allocator); - let right = exprs[1].pretty(allocator); - - left.append(allocator.space()) - .append(op.pretty(allocator)) - .append(allocator.space()) - .append(right) - .parens() - } - Expression::Primitive(_, _, op, exprs) => { - allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) - } Expression::Call(_, _, fun, args) => { let args = args.iter().map(|x| x.pretty(allocator)); let comma_sepped_args = allocator.intersperse(args, allocator.text(",")); @@ -119,10 +103,6 @@ impl Expression { result.append(last).append(allocator.text("}")) } }, - Expression::Print(_, var) => allocator - .text("print") - .append(allocator.space()) - .append(var.pretty(allocator)), Expression::Bind(_, var, ty, expr) => allocator .text(var.as_ref().to_string()) .append(allocator.space()) @@ -144,6 +124,8 @@ impl Primitive { Primitive::Minus => allocator.text("-"), Primitive::Times => allocator.text("*"), Primitive::Divide => allocator.text("/"), + Primitive::Print => allocator.text("print"), + Primitive::Negate => allocator.text("negate"), } } } @@ -153,6 +135,7 @@ impl ValueOrRef { match self { ValueOrRef::Value(_, _, v) => v.pretty(allocator), ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), + ValueOrRef::Primitive(_, _, p) => p.pretty(allocator), } } } diff --git a/src/repl.rs b/src/repl.rs index 58225f7..d2730ef 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -1,5 +1,5 @@ use crate::backend::{Backend, BackendError}; -use crate::syntax::{ConstantType, Expression, Location, ParserError, Statement, TopLevel}; +use crate::syntax::{ConstantType, Expression, Location, ParserError, TopLevel}; use crate::type_infer::TypeInferenceResult; use crate::util::scoped_map::ScopedMap; use codespan_reporting::diagnostic::Diagnostic; @@ -130,7 +130,7 @@ impl REPL { let syntax = TopLevel::parse(entry, source)?; let program = match syntax { - TopLevel::Statement(Statement::Expression(Expression::Binding(loc, name, expr))) => { + TopLevel::Expression(Expression::Binding(loc, name, expr)) => { // if this is a variable binding, and we've never defined this variable before, // we should tell cranelift about it. this is optimistic; if we fail to compile, // then we won't use this definition until someone tries again. @@ -142,12 +142,15 @@ impl REPL { crate::syntax::Program { items: vec![ - TopLevel::Statement(Statement::Expression(Expression::Binding( + TopLevel::Expression(Expression::Binding(loc.clone(), name.clone(), expr)), + TopLevel::Expression(Expression::Call( loc.clone(), - name.clone(), - expr, - ))), - TopLevel::Statement(Statement::Print(loc, name)), + Box::new(Expression::Primitive( + loc.clone(), + crate::syntax::Name::manufactured("print"), + )), + vec![Expression::Reference(loc.clone(), name.name)], + )), ], } } diff --git a/src/syntax.rs b/src/syntax.rs index 751c36d..26d0b36 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -285,35 +285,39 @@ fn order_of_operations() { assert_eq!( Program::from_str(muladd1).unwrap(), Program { - items: vec![TopLevel::Statement(Statement::Expression( - Expression::Binding( - Location::new(testfile, 0..1), - Name::manufactured("x"), + items: vec![TopLevel::Expression(Expression::Binding( + Location::new(testfile, 0..1), + Name::manufactured("x"), + Box::new(Expression::Call( + Location::new(testfile, 6..7), Box::new(Expression::Primitive( Location::new(testfile, 6..7), - "+".to_string(), - vec![ - Expression::Value( - Location::new(testfile, 4..5), - Value::Number(None, None, 1), - ), - Expression::Primitive( + Name::manufactured("+") + )), + vec![ + Expression::Value( + Location::new(testfile, 4..5), + Value::Number(None, None, 1), + ), + Expression::Call( + Location::new(testfile, 10..11), + Box::new(Expression::Primitive( Location::new(testfile, 10..11), - "*".to_string(), - vec![ - Expression::Value( - Location::new(testfile, 8..9), - Value::Number(None, None, 2), - ), - Expression::Value( - Location::new(testfile, 12..13), - Value::Number(None, None, 3), - ), - ] - ) - ] - )) - ) + Name::manufactured("*") + )), + vec![ + Expression::Value( + Location::new(testfile, 8..9), + Value::Number(None, None, 2), + ), + Expression::Value( + Location::new(testfile, 12..13), + Value::Number(None, None, 3), + ), + ] + ) + ] + )) ))], } ); diff --git a/src/syntax/arbitrary.rs b/src/syntax/arbitrary.rs index 9abfa1f..44bd323 100644 --- a/src/syntax/arbitrary.rs +++ b/src/syntax/arbitrary.rs @@ -1,4 +1,4 @@ -use crate::syntax::ast::{ConstantType, Expression, Name, Program, Statement, TopLevel, Value}; +use crate::syntax::ast::{ConstantType, Expression, Name, Program, TopLevel, Value}; use crate::syntax::location::Location; use proptest::sample::select; use proptest::{ @@ -72,19 +72,26 @@ impl Arbitrary for Program { genenv.bindings.insert(psi.name.clone(), psi.binding_type); items.push( expr.prop_map(move |expr| { - TopLevel::Statement(Statement::Expression(Expression::Binding( + TopLevel::Expression(Expression::Binding( Location::manufactured(), psi.name.clone(), Box::new(expr), - ))) + )) }) .boxed(), ); } else { let printers = genenv.bindings.keys().map(|n| { - Just(TopLevel::Statement(Statement::Print( + Just(TopLevel::Expression(Expression::Call( Location::manufactured(), - Name::manufactured(n), + Box::new(Expression::Primitive( + Location::manufactured(), + Name::manufactured("print"), + )), + vec![Expression::Reference( + Location::manufactured(), + n.to_string(), + )], ))) }); items.push(Union::new(printers).boxed()); @@ -186,7 +193,14 @@ impl Arbitrary for Expression { while args.len() > count { args.pop(); } - Expression::Primitive(Location::manufactured(), oper.to_string(), args) + Expression::Call( + Location::manufactured(), + Box::new(Expression::Primitive( + Location::manufactured(), + Name::manufactured(oper), + )), + args, + ) }) }) .boxed() diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index f128ecb..726ad5a 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -26,7 +26,7 @@ pub struct Program { /// and functions #[derive(Clone, Debug, PartialEq)] pub enum TopLevel { - Statement(Statement), + Expression(Expression), Function( Option, Vec<(Name, Option)>, @@ -86,38 +86,6 @@ impl fmt::Display for Name { } } -/// A parsed statement. -/// -/// Statements are guaranteed to be syntactically valid, but may be -/// complete nonsense at the semantic level. Which is to say, all the -/// print statements were correctly formatted, and all the variables -/// referenced are definitely valid symbols, but they may not have -/// been defined or anything. -/// -/// Note that equivalence testing on statements is independent of -/// source location; it is testing if the two statements say the same -/// thing, not if they are the exact same statement. -#[derive(Clone, Debug)] -pub enum Statement { - Print(Location, Name), - Expression(Expression), -} - -impl PartialEq for Statement { - fn eq(&self, other: &Self) -> bool { - match self { - Statement::Print(_, name1) => match other { - Statement::Print(_, name2) => name1 == name2, - _ => false, - }, - Statement::Expression(e1) => match other { - Statement::Expression(e2) => e1 == e2, - _ => false, - }, - } - } -} - /// An expression in the underlying syntax. /// /// Like statements, these expressions are guaranteed to have been @@ -131,12 +99,25 @@ pub enum Expression { Reference(Location, String), FieldRef(Location, Box, Name), Cast(Location, String, Box), - Primitive(Location, String, Vec), + Primitive(Location, Name), Call(Location, Box, Vec), - Block(Location, Vec), + Block(Location, Vec), Binding(Location, Name, Box), } +impl Expression { + pub fn primitive(loc: Location, name: &str, args: Vec) -> Expression { + Expression::Call( + loc.clone(), + Box::new(Expression::Primitive( + loc.clone(), + Name::new(name, loc.clone()), + )), + args, + ) + } +} + impl PartialEq for Expression { fn eq(&self, other: &Self) -> bool { match self { @@ -160,8 +141,8 @@ impl PartialEq for Expression { Expression::Cast(_, t2, e2) => t1 == t2 && e1 == e2, _ => false, }, - Expression::Primitive(_, prim1, args1) => match other { - Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2, + Expression::Primitive(_, prim1) => match other { + Expression::Primitive(_, prim2) => prim1 == prim2, _ => false, }, Expression::Call(_, f1, a1) => match other { @@ -189,7 +170,7 @@ impl Expression { Expression::Reference(loc, _) => loc, Expression::FieldRef(loc, _, _) => loc, Expression::Cast(loc, _, _) => loc, - Expression::Primitive(loc, _, _) => loc, + Expression::Primitive(loc, _) => loc, Expression::Call(loc, _, _) => loc, Expression::Block(loc, _) => loc, Expression::Binding(loc, _, _) => loc, diff --git a/src/syntax/eval.rs b/src/syntax/eval.rs index 7546392..f4635f7 100644 --- a/src/syntax/eval.rs +++ b/src/syntax/eval.rs @@ -1,5 +1,5 @@ use crate::eval::{EvalError, PrimitiveType, Value}; -use crate::syntax::{ConstantType, Expression, Name, Program, Statement, TopLevel}; +use crate::syntax::{ConstantType, Expression, Name, Program, TopLevel}; use crate::util::scoped_map::ScopedMap; use internment::ArcIntern; use std::collections::HashMap; @@ -40,7 +40,7 @@ impl Program { } } - TopLevel::Statement(stmt) => last_result = stmt.eval(&mut stdout, &mut env)?, + TopLevel::Expression(expr) => last_result = expr.eval(&mut stdout, &mut env)?, TopLevel::Structure(_, _, _) => { last_result = Value::Void; @@ -52,32 +52,6 @@ impl Program { } } -impl Statement { - fn eval( - &self, - stdout: &mut String, - env: &mut ScopedMap, Value>, - ) -> Result, EvalError> { - match self { - Statement::Print(loc, name) => { - let value = env - .get(&name.clone().intern()) - .ok_or_else(|| EvalError::LookupFailed(loc.clone(), name.name.clone()))?; - let value = if let Value::Number(x) = value { - Value::U64(*x) - } else { - value.clone() - }; - let line = format!("{} = {}\n", name, value); - stdout.push_str(&line); - Ok(Value::Void) - } - - Statement::Expression(e) => e.eval(stdout, env), - } - } -} - impl Expression { fn eval( &self, @@ -144,16 +118,7 @@ impl Expression { Ok(target_type.safe_cast(&value)?) } - Expression::Primitive(_, op, args) => { - let mut arg_values = Vec::with_capacity(args.len()); - - for arg in args.iter() { - // yay, recursion! makes this pretty straightforward - arg_values.push(arg.eval(stdout, env)?); - } - - Ok(Value::calculate(op, arg_values)?) - } + Expression::Primitive(_, op) => Ok(Value::primitive(op.name.clone())), Expression::Call(loc, fun, args) => { let function = fun.eval(stdout, env)?; @@ -178,6 +143,31 @@ impl Expression { closure_env.release_scope(); Ok(result) } + + Value::Primitive(name) if name == "print" => { + if let [Expression::Reference(_, name)] = &args[..] { + let value = Expression::Reference(loc.clone(), name.clone()).eval(stdout, env)?; + let value = match value { + Value::Number(x) => Value::U64(x), + x => x, + }; + let addendum = format!("{} = {}\n", name, value); + + stdout.push_str(&addendum); + Ok(Value::Void) + } else { + panic!("Non-reference/non-singleton argument to 'print': {:?}", args); + } + } + + Value::Primitive(name) => { + let values = args + .iter() + .map(|x| x.eval(stdout, env)) + .collect::>()?; + Value::calculate(name.as_str(), values).map_err(Into::into) + } + _ => Err(EvalError::NotAFunction(loc.clone(), function)), } } diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index 3c37a92..ed42ca4 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -9,7 +9,7 @@ //! eventually want to leave lalrpop behind.) //! use crate::syntax::{Location, ParserError}; -use crate::syntax::ast::{Program,TopLevel,Statement,Expression,Value,Name,Type}; +use crate::syntax::ast::{Program,TopLevel,Expression,Value,Name,Type}; use crate::syntax::tokens::{ConstantType, Token}; use internment::ArcIntern; @@ -81,7 +81,7 @@ ProgramTopLevel: Vec = { pub TopLevel: TopLevel = { => f, => s, - ";" => TopLevel::Statement(s), + ";" => TopLevel::Expression(s), } Function: TopLevel = { @@ -132,34 +132,6 @@ TypeName: Name = { Name::new(v, Location::new(file_idx, name_start..name_end)), } -Statements: Vec = { - // a statement is either a set of statements followed by another - // statement (note, here, that you can name the result of a sub-parse - // using ) ... - ";" => { - stmts.push(stmt); - stmts - }, - - => { - vec![stmt] - } -} - -#[inline] -Statement: Statement = { - // A statement can just be a print statement. - "print" "> => - Statement::Print( - Location::new(file_idx, ls..le), - Name::new(v, Location::new(file_idx, name_start..name_end)), - ), - - // A statement can just be an expression. - => - Statement::Expression(e), -} - // Expressions! Expressions are a little fiddly, because we're going to // use a little bit of a trick to make sure that we get operator precedence // right. The trick works by creating a top-level `Expression` grammar entry @@ -198,6 +170,21 @@ BindingExpression: Expression = { Box::new(e), ), + PrintExpression, +} + +PrintExpression: Expression = { + "print" => + Expression::Call( + Location::new(file_idx, ls..le), + Box::new( + Expression::Primitive( + Location::new(file_idx, ls..pe), + Name::new("print", Location::new(file_idx, ls..pe)), + ), + ), + vec![e], + ), ConstructorExpression, } @@ -214,24 +201,24 @@ FieldSetter: (Name, Expression) = { // we group addition and subtraction under the heading "additive" AdditiveExpression: Expression = { "+" => - Expression::Primitive(Location::new(file_idx, ls..le), "+".to_string(), vec![e1, e2]), + Expression::primitive(Location::new(file_idx, ls..le), "+", vec![e1, e2]), "-" => - Expression::Primitive(Location::new(file_idx, ls..le), "-".to_string(), vec![e1, e2]), + Expression::primitive(Location::new(file_idx, ls..le), "-", vec![e1, e2]), MultiplicativeExpression, } // similarly, we group multiplication and division under "multiplicative" MultiplicativeExpression: Expression = { "*" => - Expression::Primitive(Location::new(file_idx, ls..le), "*".to_string(), vec![e1, e2]), + Expression::primitive(Location::new(file_idx, ls..le), "*", vec![e1, e2]), "/" => - Expression::Primitive(Location::new(file_idx, ls..le), "/".to_string(), vec![e1, e2]), + Expression::primitive(Location::new(file_idx, ls..le), "/", vec![e1, e2]), UnaryExpression, } UnaryExpression: Expression = { "-" => - Expression::Primitive(Location::new(file_idx, l..le), "-".to_string(), vec![e]), + Expression::primitive(Location::new(file_idx, l..le), "negate", vec![e]), "<" "> ">" => Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)), CallExpression, @@ -257,13 +244,21 @@ AtomicExpression: Expression = { // just a number "> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)), // this expression could actually be a block! - "{" "}" => Expression::Block(Location::new(file_idx, s..e), stmts), + "{" ";"? "}" => Expression::Block(Location::new(file_idx, s..e), exprs), "{" "}" => Expression::Block(Location::new(file_idx, s..e), vec![]), // finally, let people parenthesize expressions and get back to a // lower precedence "(" ")" => e, } +Expressions: Vec = { + => vec![e], + ";" => { + exps.push(e); + exps + } +} + // Lifted from the LALRPop book, a comma-separated list of T that may or // may not conclude with a comma. Comma: Vec = { diff --git a/src/syntax/pretty.rs b/src/syntax/pretty.rs index 249c571..81eabe2 100644 --- a/src/syntax/pretty.rs +++ b/src/syntax/pretty.rs @@ -1,4 +1,4 @@ -use crate::syntax::ast::{ConstantType, Expression, Program, Statement, TopLevel, Type, Value}; +use crate::syntax::ast::{ConstantType, Expression, Program, TopLevel, Type, Value}; use crate::util::pretty::{derived_display, Allocator}; use pretty::{DocAllocator, DocBuilder}; @@ -20,7 +20,7 @@ impl Program { impl TopLevel { pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { match self { - TopLevel::Statement(stmt) => stmt.pretty(allocator), + TopLevel::Expression(expr) => expr.pretty(allocator), TopLevel::Function(name, args, rettype, body) => allocator .text("function") .append(allocator.space()) @@ -87,18 +87,6 @@ impl TopLevel { } } -impl Statement { - pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { - match self { - Statement::Print(_, var) => allocator - .text("print") - .append(allocator.space()) - .append(allocator.text(var.to_string())), - Statement::Expression(e) => e.pretty(allocator), - } - } -} - impl Expression { pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { match self { @@ -129,25 +117,7 @@ impl Expression { .text(t.clone()) .angles() .append(e.pretty(allocator)), - Expression::Primitive(_, op, exprs) if exprs.len() == 1 => allocator - .text(op.to_string()) - .append(exprs[0].pretty(allocator)), - Expression::Primitive(_, op, exprs) if exprs.len() == 2 => { - let left = exprs[0].pretty(allocator); - let right = exprs[1].pretty(allocator); - - left.append(allocator.space()) - .append(allocator.text(op.to_string())) - .append(allocator.space()) - .append(right) - .parens() - } - Expression::Primitive(_, op, exprs) => { - let call = allocator.text(op.to_string()); - let args = exprs.iter().map(|x| x.pretty(allocator)); - let comma_sepped_args = allocator.intersperse(args, allocator.text(",")); - call.append(comma_sepped_args.parens()) - } + Expression::Primitive(_, op) => allocator.text(op.name.clone()), Expression::Call(_, fun, args) => { let args = args.iter().map(|x| x.pretty(allocator)); let comma_sepped_args = allocator.intersperse(args, allocator.text(",")); @@ -245,6 +215,5 @@ impl Type { derived_display!(Program); derived_display!(TopLevel); -derived_display!(Statement); derived_display!(Expression); derived_display!(Value); diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index 3c868d0..8ee2792 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -1,6 +1,6 @@ use crate::{ eval::PrimitiveType, - syntax::{Expression, Location, Program, Statement, TopLevel}, + syntax::{Expression, Location, Program, TopLevel}, util::scoped_map::ScopedMap, }; use codespan_reporting::diagnostic::Diagnostic; @@ -132,47 +132,12 @@ impl TopLevel { bound_variables.release_scope(); result } - TopLevel::Statement(stmt) => stmt.validate(bound_variables), + TopLevel::Expression(expr) => expr.validate(bound_variables), TopLevel::Structure(_, _, _) => (vec![], vec![]), } } } -impl Statement { - /// Validate that the statement makes semantic sense, not just syntactic sense. - /// - /// This checks for things like references to variables that don't exist, for - /// example, and generates warnings for things that are inadvisable but not - /// actually a problem. Since statements appear in a broader context, you'll - /// need to provide the set of variables that are bound where this statement - /// occurs. We use a `HashMap` to map these bound locations to the locations - /// where their bound, because these locations are handy when generating errors - /// and warnings. - fn validate( - &self, - bound_variables: &mut ScopedMap, - ) -> (Vec, Vec) { - let mut errors = vec![]; - let mut warnings = vec![]; - - match self { - Statement::Print(_, var) if bound_variables.contains_key(&var.name) => {} - Statement::Print(loc, var) => { - errors.push(Error::UnboundVariable(loc.clone(), var.to_string())) - } - - Statement::Expression(e) => { - let (mut exp_errors, mut exp_warnings) = e.validate(bound_variables); - - errors.append(&mut exp_errors); - warnings.append(&mut exp_warnings); - } - } - - (errors, warnings) - } -} - impl Expression { fn validate( &self, @@ -207,18 +172,7 @@ impl Expression { (errs, warns) } - Expression::Primitive(_, _, args) => { - let mut errors = vec![]; - let mut warnings = vec![]; - - for expr in args.iter() { - let (mut err, mut warn) = expr.validate(variable_map); - errors.append(&mut err); - warnings.append(&mut warn); - } - - (errors, warnings) - } + Expression::Primitive(_, _) => (vec![], vec![]), Expression::Call(_, func, args) => { let (mut errors, mut warnings) = func.validate(variable_map); diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 1fdf95d..fb69cba 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -9,7 +9,7 @@ use std::str::FromStr; enum TopLevelItem { Type(ArcIntern, ir::TypeOrVar), - Expression(ir::TopLevel), + Value(ir::TopLevel), } /// This function takes a syntactic program and converts it into the IR version of the @@ -32,7 +32,7 @@ pub fn convert_program( let tli = convert_top_level(item, &mut constraint_db, &mut renames, &mut bindings); match tli { - TopLevelItem::Expression(item) => items.push(item), + TopLevelItem::Value(item) => items.push(item), TopLevelItem::Type(name, decl) => { let _ = type_definitions.insert(name, decl); } @@ -137,7 +137,7 @@ fn convert_top_level( // Remember to exit this scoping level! renames.release_scope(); - TopLevelItem::Expression(ir::TopLevel::Function( + TopLevelItem::Value(ir::TopLevel::Function( function_name, arginfo, rettype, @@ -145,8 +145,8 @@ fn convert_top_level( )) } - syntax::TopLevel::Statement(stmt) => TopLevelItem::Expression(ir::TopLevel::Statement( - convert_statement(stmt, constraint_db, renames, bindings), + syntax::TopLevel::Expression(expr) => TopLevelItem::Value(ir::TopLevel::Statement( + convert_expression(expr, constraint_db, renames, bindings).0, )), syntax::TopLevel::Structure(_loc, name, fields) => { @@ -161,46 +161,6 @@ fn convert_top_level( } } -/// This function takes a syntactic statements and converts it into a series of -/// IR statements, adding type variables and constraints as necessary. -/// -/// We generate a series of statements because we're going to flatten all -/// incoming expressions so that they are no longer recursive. This will -/// generate a bunch of new bindings for all the subexpressions, which we -/// return as a bundle. -/// -/// See the safety warning on [`convert_program`]! This function assumes that -/// you have run [`Statement::validate`], and will trigger panics in error -/// conditions if you have run that and had it come back clean. -fn convert_statement( - statement: syntax::Statement, - constraint_db: &mut Vec, - renames: &mut ScopedMap, ArcIntern>, - bindings: &mut HashMap, ir::TypeOrVar>, -) -> ir::Expression { - match statement { - syntax::Statement::Print(loc, name) => { - let iname = ArcIntern::new(name.to_string()); - let final_name = renames - .get(&iname) - .cloned() - .unwrap_or_else(|| iname.clone()); - let varty = bindings - .get(&final_name) - .expect("print variable defined before use") - .clone(); - - constraint_db.push(Constraint::Printable(loc.clone(), varty.clone())); - - ir::Expression::Print(loc.clone(), ir::ValueOrRef::Ref(loc, varty, final_name)) - } - - syntax::Statement::Expression(e) => { - convert_expression(e, constraint_db, renames, bindings).0 - } - } -} - /// This function takes a syntactic expression and converts it into a series /// of IR statements, adding type variables and constraints as necessary. /// @@ -359,39 +319,54 @@ fn convert_expression( (finalize_expression(prereqs, res), target_type) } - syntax::Expression::Primitive(loc, fun, mut args) => { - let primop = ir::Primitive::from_str(&fun).expect("valid primitive"); - let mut prereqs = vec![]; - let mut nargs = vec![]; - let mut atypes = vec![]; - let ret_type = ir::TypeOrVar::new(); + syntax::Expression::Primitive(loc, name) => { + let primop = ir::Primitive::from_str(&name.name).expect("valid primitive"); - for arg in args.drain(..) { - let (aexp, atype) = convert_expression(arg, constraint_db, renames, bindings); - let (aprereqs, asimple) = simplify_expr(aexp); + match primop { + ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => { + let numeric_type = ir::TypeOrVar::new_located(loc.clone()); + constraint_db.push(Constraint::NumericType(loc.clone(), numeric_type.clone())); + let funtype = ir::TypeOrVar::Function( + vec![numeric_type.clone(), numeric_type.clone()], + Box::new(numeric_type.clone()), + ); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + (ir::Expression::Atomic(result_value), funtype) + } - merge_prereq(&mut prereqs, aprereqs); - nargs.push(asimple); - atypes.push(atype); - } + ir::Primitive::Minus => { + let numeric_type = ir::TypeOrVar::new_located(loc.clone()); + constraint_db.push(Constraint::NumericType(loc.clone(), numeric_type.clone())); + let funtype = ir::TypeOrVar::Function( + vec![numeric_type.clone(), numeric_type.clone()], + Box::new(numeric_type.clone()), + ); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + (ir::Expression::Atomic(result_value), funtype) + } - constraint_db.push(Constraint::ProperPrimitiveArgs( - loc.clone(), - primop, - atypes.clone(), - ret_type.clone(), - )); + ir::Primitive::Print => { + let arg_type = ir::TypeOrVar::new_located(loc.clone()); + constraint_db.push(Constraint::Printable(loc.clone(), arg_type.clone())); + let funtype = ir::TypeOrVar::Function( + vec![arg_type], + Box::new(ir::TypeOrVar::Primitive(PrimitiveType::Void)), + ); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + (ir::Expression::Atomic(result_value), funtype) + } - let last_call = ir::Expression::Primitive(loc.clone(), ret_type.clone(), primop, nargs); - - if prereqs.is_empty() { - (last_call, ret_type) - } else { - prereqs.push(last_call); - ( - ir::Expression::Block(loc, ret_type.clone(), prereqs), - ret_type, - ) + ir::Primitive::Negate => { + let arg_type = ir::TypeOrVar::new_located(loc.clone()); + constraint_db.push(Constraint::NumericType(loc.clone(), arg_type.clone())); + constraint_db.push(Constraint::IsSigned(loc.clone(), arg_type.clone())); + let funtype = ir::TypeOrVar::Function( + vec![arg_type.clone()], + Box::new(arg_type), + ); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + (ir::Expression::Atomic(result_value), funtype) + } } } @@ -444,10 +419,10 @@ fn convert_expression( let mut ret_type = ir::TypeOrVar::Primitive(PrimitiveType::Void); let mut exprs = vec![]; - for statement in stmts { - let expr = convert_statement(statement, constraint_db, renames, bindings); + for xpr in stmts { + let (expr, expr_type) = convert_expression(xpr, constraint_db, renames, bindings); - ret_type = expr.type_of(); + ret_type = expr_type; exprs.push(expr); } diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 8949c26..7f12adb 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -54,15 +54,6 @@ fn finalize_expression( finalize_val_or_ref(val_or_ref, resolutions), ), - Expression::Primitive(loc, ty, prim, mut args) => Expression::Primitive( - loc, - finalize_type(ty, resolutions), - prim, - args.drain(..) - .map(|x| finalize_val_or_ref(x, resolutions)) - .collect(), - ), - Expression::Construct(loc, ty, name, fields) => Expression::Construct( loc, finalize_type(ty, resolutions), @@ -97,10 +88,6 @@ fn finalize_expression( Expression::Block(loc, finalize_type(ty, resolutions), final_exprs) } - Expression::Print(loc, var) => { - Expression::Print(loc, finalize_val_or_ref(var, resolutions)) - } - Expression::Call(loc, ty, fun, args) => Expression::Call( loc, finalize_type(ty, resolutions), @@ -147,6 +134,9 @@ fn finalize_val_or_ref( ) -> ValueOrRef { match valref { ValueOrRef::Ref(loc, ty, var) => ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var), + ValueOrRef::Primitive(loc, ty, prim) => { + ValueOrRef::Primitive(loc, finalize_type(ty, resolutions), prim) + } ValueOrRef::Value(loc, ty, val) => { let new_type = finalize_type(ty, resolutions); diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index b18d284..15dc067 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -12,8 +12,6 @@ pub enum Constraint { Printable(Location, TypeOrVar), /// The provided numeric value fits in the given constant type FitsInNumType(Location, TypeOrVar, u64), - /// The given primitive has the proper arguments types associated with it - ProperPrimitiveArgs(Location, Primitive, Vec, TypeOrVar), /// The given type can be casted to the target type safely CanCastTo(Location, TypeOrVar, TypeOrVar), /// The given type has the given field in it, and the type of that field @@ -41,13 +39,6 @@ impl fmt::Display for Constraint { match self { Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty), Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty), - Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => { - write!(f, "PRIM {} {} -> {}", op, args[0], ret) - } - Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => { - write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret) - } - Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret), Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2), Constraint::TypeHasField(_, ty1, field, ty2) => { write!(f, "FIELD {}.{} -> {}", ty1, field, ty2) @@ -84,10 +75,6 @@ impl Constraint { Constraint::IsSigned(_, ty) => ty.replace(name, replace_with), Constraint::IsSomething(_, ty) => ty.replace(name, replace_with), Constraint::NumericType(_, ty) => ty.replace(name, replace_with), - Constraint::ProperPrimitiveArgs(_, _, args, ret) => { - ret.replace(name, replace_with) - | args.iter_mut().any(|x| x.replace(name, replace_with)) - } Constraint::NamedTypeIs(_, name, ty) => ty.replace(name, replace_with), } } @@ -269,12 +256,6 @@ impl From for Diagnostic { TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc .labelled_error("internal error") .with_message(format!("Could not determine if type {} was printable", ty)), - TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => { - loc.labelled_error("internal error").with_message(format!( - "Could not tell if primitive {} received the proper argument types", - prim - )) - } TypeInferenceError::CouldNotSolve(Constraint::IsSomething(loc, _)) => { loc.labelled_error("could not infer type") .with_message("Could not find *any* type information; is this an unused function argument?") @@ -621,60 +602,6 @@ pub fn solve_constraints( changed_something = true; } - Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim { - Primitive::Plus | Primitive::Minus | Primitive::Times | Primitive::Divide - if args.len() == 2 => - { - let right = args.pop().expect("2>0"); - let left = args.pop().expect("2>1"); - - new_constraints.push(Constraint::NumericType(loc.clone(), left.clone())); - new_constraints.push(Constraint::Equivalent( - loc.clone(), - left.clone(), - right, - )); - new_constraints.push(Constraint::Equivalent(loc.clone(), left, ret)); - changed_something = true; - all_constraints_solved = false; - tracing::trace!(primitive = %prim, "we expanded out a binary primitive operation"); - } - - Primitive::Minus if args.len() == 1 => { - let value = args.pop().expect("1>0"); - new_constraints.push(Constraint::NumericType(loc.clone(), value.clone())); - new_constraints.push(Constraint::IsSigned(loc.clone(), value.clone())); - new_constraints.push(Constraint::Equivalent(loc, value, ret)); - changed_something = true; - all_constraints_solved = false; - tracing::trace!(primitive = %prim, "we expanded out a unary primitive operation"); - } - - Primitive::Plus | Primitive::Times | Primitive::Divide => { - errors.push(TypeInferenceError::WrongPrimitiveArity( - loc, - prim, - 2, - 2, - args.len(), - )); - changed_something = true; - tracing::trace!(primitive = %prim, provided_arity = args.len(), "binary primitive operation arity is wrong"); - } - - Primitive::Minus => { - errors.push(TypeInferenceError::WrongPrimitiveArity( - loc, - prim, - 1, - 2, - args.len(), - )); - changed_something = true; - tracing::trace!(primitive = %prim, provided_arity = args.len(), "unary primitive operation arity is wrong"); - } - }, - // Some equivalences we can/should solve directly Constraint::Equivalent( loc,