diff --git a/build.rs b/build.rs index db2b647..89f3bf9 100644 --- a/build.rs +++ b/build.rs @@ -68,7 +68,10 @@ fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> { " assert_eq!(errors.len(), 0, \"file should have no validation errors\");" )?; writeln!(f, " let syntax_result = syntax.eval();")?; - writeln!(f, " let ir = IR::from(syntax);")?; + writeln!( + f, + " let ir = syntax.type_infer().expect(\"example is typed correctly\");" + )?; writeln!( f, " assert_eq!(syntax_result, ir.eval(), \"syntax equivalent to IR\");" diff --git a/src/compiler.rs b/src/compiler.rs index a793b7f..14eae47 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,5 +1,5 @@ -use crate::backend::Backend; use crate::syntax::Program as Syntax; +use crate::{backend::Backend, ir::TypeInferenceResult}; use codespan_reporting::{ diagnostic::Diagnostic, files::SimpleFiles, @@ -99,8 +99,38 @@ impl Compiler { return Ok(None); } - // Now that we've validated it, turn it into IR. - let ir = syntax.type_infer(); + // Now that we've validated it, let's do type inference, potentially turning + // into IR while we're at it. + let ir = match syntax.type_infer() { + TypeInferenceResult::Failure { + mut errors, + mut warnings, + } => { + let messages = errors + .drain(..) + .map(Into::into) + .chain(warnings.drain(..).map(Into::into)); + + for message in messages { + self.emit(message); + } + + return Ok(None); + } + + TypeInferenceResult::Success { + result, + mut warnings, + } => { + let messages = warnings.drain(..).map(Into::into); + + for message in messages { + self.emit(message); + } + + result + } + }; // Finally, send all this to Cranelift for conversion into an object file. let mut backend = Backend::object_file(Triple::host())?; diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index 87559a1..59ec352 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -80,6 +80,62 @@ impl FromStr for PrimitiveType { } impl PrimitiveType { + /// Return true if this type can be safely cast into the target type. + pub fn can_cast_to(&self, target: &PrimitiveType) -> bool { + match self { + PrimitiveType::U8 => match target { + PrimitiveType::U8 => true, + PrimitiveType::U16 => true, + PrimitiveType::U32 => true, + PrimitiveType::U64 => true, + PrimitiveType::I16 => true, + PrimitiveType::I32 => true, + PrimitiveType::I64 => true, + _ => false, + }, + PrimitiveType::U16 => match target { + PrimitiveType::U16 => true, + PrimitiveType::U32 => true, + PrimitiveType::U64 => true, + PrimitiveType::I32 => true, + PrimitiveType::I64 => true, + _ => false, + }, + PrimitiveType::U32 => match target { + PrimitiveType::U32 => true, + PrimitiveType::U64 => true, + PrimitiveType::I64 => true, + _ => false, + }, + PrimitiveType::U64 => match target { + PrimitiveType::U64 => true, + _ => false, + }, + PrimitiveType::I8 => match target { + PrimitiveType::I8 => true, + PrimitiveType::I16 => true, + PrimitiveType::I32 => true, + PrimitiveType::I64 => true, + _ => false, + }, + PrimitiveType::I16 => match target { + PrimitiveType::I16 => true, + PrimitiveType::I32 => true, + PrimitiveType::I64 => true, + _ => false, + }, + PrimitiveType::I32 => match target { + PrimitiveType::I32 => true, + PrimitiveType::I64 => true, + _ => false, + }, + PrimitiveType::I64 => match target { + PrimitiveType::I64 => true, + _ => false, + }, + } + } + /// Try to cast the given value to this type, returning the new value. /// /// Returns an error if the cast is not safe *in* *general*. This means that diff --git a/src/examples.rs b/src/examples.rs index 46dd175..82a864f 100644 --- a/src/examples.rs +++ b/src/examples.rs @@ -1,5 +1,4 @@ use crate::backend::Backend; -use crate::ir::Program as IR; use crate::syntax::Program as Syntax; use codespan_reporting::files::SimpleFiles; use cranelift_jit::JITModule; diff --git a/src/ir.rs b/src/ir.rs index 5ad3fa7..cc006d4 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -18,3 +18,4 @@ mod strings; mod type_infer; pub use ast::*; +pub use type_infer::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning}; diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 8cde98b..037b51f 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -61,7 +61,10 @@ impl Arbitrary for Program { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { crate::syntax::Program::arbitrary_with(args) - .prop_map(syntax::Program::type_infer) + .prop_map(|x| { + x.type_infer() + .expect("arbitrary_with should generate type-correct programs") + }) .boxed() } } @@ -330,7 +333,6 @@ where #[derive(Clone, Debug, Eq, PartialEq)] pub enum Type { - Variable(Location, ArcIntern), Primitive(PrimitiveType), } @@ -341,7 +343,6 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - Type::Variable(_, x) => allocator.text(x.to_string()), Type::Primitive(pt) => allocator.text(format!("{}", pt)), } } @@ -350,7 +351,6 @@ where impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Type::Variable(_, x) => write!(f, "{}", x), Type::Primitive(pt) => pt.fmt(f), } } diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 0d1b201..6841508 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -40,7 +40,6 @@ impl Expression { let value = valref.eval(env)?; match t { - Type::Variable(_, _) => unimplemented!("how to cast to a type variable?"), Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), } } @@ -86,7 +85,7 @@ impl ValueOrRef { #[test] fn two_plus_three() { let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); - let ir = input.type_infer(); + let ir = input.type_infer().expect("test should be type-valid"); let output = ir.eval().expect("runs successfully"); assert_eq!("x = 5u64\n", &output); } @@ -95,7 +94,7 @@ fn two_plus_three() { fn lotsa_math() { let input = crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); - let ir = input.type_infer(); + let ir = input.type_infer().expect("test should be type-valid"); let output = ir.eval().expect("runs successfully"); assert_eq!("x = 7u64\n", &output); } diff --git a/src/ir/type_infer.rs b/src/ir/type_infer.rs index 92cad82..341e931 100644 --- a/src/ir/type_infer.rs +++ b/src/ir/type_infer.rs @@ -1,292 +1,35 @@ -use internment::ArcIntern; -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::atomic::AtomicUsize; +mod ast; +mod convert; +mod finalize; +mod solve; -use crate::eval::PrimitiveType; +use self::convert::convert_program; +use self::finalize::finalize_program; +use self::solve::solve_constraints; +pub use self::solve::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning}; use crate::ir::ast as ir; -use crate::ir::ast::Type; -use crate::syntax::{self, ConstantType, Location}; - -enum Constraint { - /// The given type must be printable using the `print` built-in - Printable(Location, Type), - /// The provided numeric value fits in the given constant type - FitsInNumType(Location, ConstantType, u64), - /// The given primitive has the proper arguments types associated with it - ProperPrimitiveArgs(Location, ir::Primitive, Vec, Type), - /// The given type can be casted to the target type safely - CanCastTo(Location, Type, Type), -} - -/// This function takes a syntactic program and converts it into the IR version of the -/// program, with appropriate type variables introduced and their constraints added to -/// the given database. -/// -/// If the input function has been validated (which it should be), then this should run -/// into no error conditions. However, if you failed to validate the input, then this -/// function can panic. -fn convert_program( - mut program: syntax::Program, - constraint_db: &mut Vec, -) -> ir::Program { - let mut statements = Vec::new(); - let mut renames = HashMap::new(); - let mut bindings = HashMap::new(); - - for stmt in program.statements.drain(..) { - statements.append(&mut convert_statement( - stmt, - constraint_db, - &mut renames, - &mut bindings, - )); - } - - ir::Program { statements } -} - -/// This function takes a syntactic statements and converts it into a series of -/// IR statements, adding type variables and constraints as necessary. -/// -/// We generate a series of statements because we're going to flatten all -/// incoming expressions so that they are no longer recursive. This will -/// generate a bunch of new bindings for all the subexpressions, which we -/// return as a bundle. -/// -/// See the safety warning on [`convert_program`]! This function assumes that -/// you have run [`Statement::validate`], and will trigger panics in error -/// conditions if you have run that and had it come back clean. -fn convert_statement( - statement: syntax::Statement, - constraint_db: &mut Vec, - renames: &mut HashMap, ArcIntern>, - bindings: &mut HashMap, Type>, -) -> Vec { - match statement { - syntax::Statement::Print(loc, name) => { - let iname = ArcIntern::new(name); - let final_name = renames.get(&iname).map(Clone::clone).unwrap_or_else(|| iname.clone()); - let varty = bindings - .get(&final_name) - .expect("print variable defined before use") - .clone(); - - constraint_db.push(Constraint::Printable(loc.clone(), varty.clone())); - - vec![ir::Statement::Print(loc, varty, iname)] - } - - syntax::Statement::Binding(loc, name, expr) => { - let (mut prereqs, expr, ty) = - convert_expression(expr, constraint_db, renames, bindings); - let iname = ArcIntern::new(name); - let final_name = if bindings.contains_key(&iname) { - let new_name = gensym(iname.as_str()); - renames.insert(iname, new_name.clone()); - new_name - } else { - iname - }; - - bindings.insert(final_name.clone(), ty.clone()); - prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr)); - prereqs - } - } -} - -/// This function takes a syntactic expression and converts it into a series -/// of IR statements, adding type variables and constraints as necessary. -/// -/// We generate a series of statements because we're going to flatten all -/// incoming expressions so that they are no longer recursive. This will -/// generate a bunch of new bindings for all the subexpressions, which we -/// return as a bundle. -/// -/// See the safety warning on [`convert_program`]! This function assumes that -/// you have run [`Statement::validate`], and will trigger panics in error -/// conditions if you have run that and had it come back clean. -fn convert_expression( - expression: syntax::Expression, - constraint_db: &mut Vec, - renames: &HashMap, ArcIntern>, - bindings: &mut HashMap, Type>, -) -> (Vec, ir::Expression, Type) { - match expression { - syntax::Expression::Value(loc, val) => { - let newval = match val { - syntax::Value::Number(base, mctype, value) => { - if let Some(suggested_type) = mctype { - constraint_db.push(Constraint::FitsInNumType(loc.clone(), suggested_type, value)); - } - - match mctype { - None => ir::Value::U64(base, value), - Some(ConstantType::U8) => ir::Value::U8(base, value as u8), - Some(ConstantType::U16) => ir::Value::U16(base, value as u16), - Some(ConstantType::U32) => ir::Value::U32(base, value as u32), - Some(ConstantType::U64) => ir::Value::U64(base, value), - Some(ConstantType::I8) => ir::Value::I8(base, value as i8), - Some(ConstantType::I16) => ir::Value::I16(base, value as i16), - Some(ConstantType::I32) => ir::Value::I32(base, value as i32), - Some(ConstantType::I64) => ir::Value::I64(base, value as i64), - } - } - }; - let valtype = newval.type_of(); - - ( - vec![], - ir::Expression::Atomic(ir::ValueOrRef::Value(loc, valtype.clone(), newval)), - valtype, - ) - } - - syntax::Expression::Reference(loc, name) => { - let iname = ArcIntern::new(name); - let final_name = renames.get(&iname).cloned().unwrap_or(iname); - let rtype = bindings - .get(&final_name) - .cloned() - .expect("variable bound before use"); - let refexp = - ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name)); - - (vec![], refexp, rtype) - } - - syntax::Expression::Cast(loc, target, expr) => { - let (mut stmts, nexpr, etype) = - convert_expression(*expr, constraint_db, renames, bindings); - let val_or_ref = simplify_expr(nexpr, &mut stmts); - let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast"); - let target_type = Type::Primitive(target_prim_type); - let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); - - constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone())); - - (stmts, res, target_type) - } - - syntax::Expression::Primitive(loc, fun, mut args) => { - let primop = ir::Primitive::from_str(&fun).expect("valid primitive"); - let mut stmts = vec![]; - let mut nargs = vec![]; - let mut atypes = vec![]; - let ret_type = gentype(); - - for arg in args.drain(..) { - let (mut astmts, aexp, atype) = convert_expression(arg, constraint_db, renames, bindings); - - stmts.append(&mut astmts); - nargs.push(simplify_expr(aexp, &mut stmts)); - atypes.push(atype); - } - - constraint_db.push(Constraint::ProperPrimitiveArgs(loc.clone(), primop, atypes.clone(), ret_type.clone())); - - (stmts, ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs), ret_type) - } - } -} - -fn simplify_expr(expr: ir::Expression, stmts: &mut Vec) -> ir::ValueOrRef { - match expr { - ir::Expression::Atomic(v_or_ref) => v_or_ref, - expr => { - let etype = expr.type_of().clone(); - let loc = expr.location().clone(); - let nname = gensym("g"); - let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr); - - stmts.push(nbinding); - ir::ValueOrRef::Ref(loc, etype, nname) - } - } -} +use crate::syntax; impl syntax::Program { /// Infer the types for the syntactic AST, returning either a type-checked program in /// the IR, or a series of type errors encountered during inference. - /// + /// /// You really should have made sure that this program was validated before running /// this method, otherwise you may experience panics during operation. - pub fn type_infer(self) -> ir::Program { + pub fn type_infer(self) -> TypeInferenceResult { let mut constraint_db = vec![]; let program = convert_program(self, &mut constraint_db); - let mut changed_something = true; + let inference_result = solve_constraints(constraint_db); - // We want to run this inference endlessly, until either we have solved all of our - // constraints or we've gotten stuck somewhere. - while constraint_db.len() > 0 && changed_something { - // Set this to false at the top of the loop. We'll set this to true if we make - // progress in any way further down, but having this here prevents us from going - // into an infinite look when we can't figure stuff out. - changed_something = false; - // This is sort of a double-buffering thing; we're going to rename constraint_db - // and then set it to a new empty vector, which we'll add to as we find new - // constraints or find ourselves unable to solve existing ones. - let mut local_constraints = constraint_db; - constraint_db = vec![]; - - for constraint in local_constraints.drain(..) { - match constraint { - // Currently, all of our types are printable - Constraint::Printable(_loc, _ty) => {} - - Constraint::FitsInNumType(loc, ctype, val) => unimplemented!(), - - Constraint::ProperPrimitiveArgs(loc, prim, args, ret) => unimplemented!(), - - Constraint::CanCastTo(loc, from_type, to_type) => unimplemented!(), - } - } - } - - program + inference_result.map(|type_renames| finalize_program(program, type_renames)) } } -/// Generate a fresh new name based on the given name. -/// -/// The new name is guaranteed to be unique across the entirety of the -/// execution. This is achieved by using characters in the variable name -/// that would not be valid input, and by including a counter that is -/// incremented on every invocation. -fn gensym(name: &str) -> ArcIntern { - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let new_name = format!( - "<{}:{}>", - name, - COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) - ); - ArcIntern::new(new_name) -} - -/// Generate a fresh new type; this will be a unique new type variable. -/// -/// The new name is guaranteed to be unique across the entirety of the -/// execution. This is achieved by using characters in the variable name -/// that would not be valid input, and by including a counter that is -/// incremented on every invocation. -fn gentype() -> Type { - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let new_name = ArcIntern::new(format!( - "t<{}>", - COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) - )); - - Type::Variable(Location::manufactured(), new_name) -} - proptest::proptest! { #[test] fn translation_maintains_semantics(input: syntax::Program) { let syntax_result = input.eval(); - let ir = input.type_infer(); + let ir = input.type_infer().expect("arbitrary should generate type-safe programs"); let ir_result = ir.eval(); assert_eq!(syntax_result, ir_result); } diff --git a/src/ir/type_infer/ast.rs b/src/ir/type_infer/ast.rs new file mode 100644 index 0000000..87d3cbc --- /dev/null +++ b/src/ir/type_infer/ast.rs @@ -0,0 +1,330 @@ +pub use crate::ir::ast::Primitive; +/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going +/// to want to use while we're doing type inference, but don't want to keep around +/// afterwards. These are: +/// +/// * A notion of a type variable +/// * An unknown numeric constant form +/// +use crate::{ + eval::PrimitiveType, + syntax::{self, ConstantType, Location}, +}; +use internment::ArcIntern; +use pretty::{DocAllocator, Pretty}; +use std::fmt; +use std::sync::atomic::AtomicUsize; + +/// We're going to represent variables as interned strings. +/// +/// These should be fast enough for comparison that it's OK, since it's going to end up +/// being pretty much the pointer to the string. +type Variable = ArcIntern; + +/// The representation of a program within our IR. For now, this is exactly one file. +/// +/// In addition, for the moment there's not really much of interest to hold here besides +/// the list of statements read from the file. Order is important. In the future, you +/// could imagine caching analysis information in this structure. +/// +/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used +/// to print the structure whenever possible, especially if you value your or your +/// user's time. The latter is useful for testing that conversions of `Program` retain +/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be +/// syntactically valid, although they may contain runtime issue like over- or underflow. +#[derive(Debug)] +pub struct Program { + // For now, a program is just a vector of statements. In the future, we'll probably + // extend this to include a bunch of other information, but for now: just a list. + pub(crate) statements: Vec, +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + let mut result = allocator.nil(); + + for stmt in self.statements.iter() { + // there's probably a better way to do this, rather than constantly + // adding to the end, but this works. + result = result + .append(stmt.pretty(allocator)) + .append(allocator.text(";")) + .append(allocator.hardline()); + } + + result + } +} + +/// The representation of a statement in the language. +/// +/// For now, this is either a binding site (`x = 4`) or a print statement +/// (`print x`). Someday, though, more! +/// +/// As with `Program`, this type implements [`Pretty`], which should +/// be used to display the structure whenever possible. It does not +/// implement [`Arbitrary`], though, mostly because it's slightly +/// complicated to do so. +/// +#[derive(Debug)] +pub enum Statement { + Binding(Location, Variable, Type, Expression), + Print(Location, Type, Variable), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + Statement::Binding(_, var, _, expr) => allocator + .text(var.as_ref().to_string()) + .append(allocator.space()) + .append(allocator.text("=")) + .append(allocator.space()) + .append(expr.pretty(allocator)), + Statement::Print(_, _, var) => allocator + .text("print") + .append(allocator.space()) + .append(allocator.text(var.as_ref().to_string())), + } + } +} + +/// The representation of an expression. +/// +/// Note that expressions, like everything else in this syntax tree, +/// supports [`Pretty`], and it's strongly encouraged that you use +/// that trait/module when printing these structures. +/// +/// Also, Expressions at this point in the compiler are explicitly +/// defined so that they are *not* recursive. By this point, if an +/// expression requires some other data (like, for example, invoking +/// a primitive), any subexpressions have been bound to variables so +/// that the referenced data will always either be a constant or a +/// variable reference. +#[derive(Debug)] +pub enum Expression { + Atomic(ValueOrRef), + Cast(Location, Type, ValueOrRef), + Primitive(Location, Type, Primitive, Vec), +} + +impl Expression { + /// Return a reference to the type of the expression, as inferred or recently + /// computed. + pub fn type_of(&self) -> &Type { + match self { + Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t, + Expression::Atomic(ValueOrRef::Value(_, t, _)) => t, + Expression::Cast(_, t, _) => t, + Expression::Primitive(_, t, _, _) => t, + } + } + + /// Return a reference to the location associated with the expression. + pub fn location(&self) -> &Location { + match self { + Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l, + Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, + Expression::Cast(l, _, _) => l, + Expression::Primitive(l, _, _, _) => l, + } + } +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + Expression::Atomic(x) => x.pretty(allocator), + Expression::Cast(_, t, e) => allocator + .text("<") + .append(t.pretty(allocator)) + .append(allocator.text(">")) + .append(e.pretty(allocator)), + Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => { + op.pretty(allocator).append(exprs[0].pretty(allocator)) + } + Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => { + let left = exprs[0].pretty(allocator); + let right = exprs[1].pretty(allocator); + + left.append(allocator.space()) + .append(op.pretty(allocator)) + .append(allocator.space()) + .append(right) + .parens() + } + Expression::Primitive(_, _, op, exprs) => { + allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) + } + } + } +} + +/// An expression that is always either a value or a reference. +/// +/// This is the type used to guarantee that we don't nest expressions +/// at this level. Instead, expressions that take arguments take one +/// of these, which can only be a constant or a reference. +#[derive(Clone, Debug)] +pub enum ValueOrRef { + Value(Location, Type, Value), + Ref(Location, Type, ArcIntern), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + ValueOrRef::Value(_, _, v) => v.pretty(allocator), + ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), + } + } +} + +impl From for Expression { + fn from(value: ValueOrRef) -> Self { + Expression::Atomic(value) + } +} + +/// A constant in the IR. +/// +/// The optional argument in numeric types is the base that was used by the +/// user to input the number. By retaining it, we can ensure that if we need +/// to print the number back out, we can do so in the form that the user +/// entered it. +#[derive(Clone, Debug)] +pub enum Value { + Unknown(Option, u64), + I8(Option, i8), + I16(Option, i16), + I32(Option, i32), + I64(Option, i64), + U8(Option, u8), + U16(Option, u16), + U32(Option, u32), + U64(Option, u64), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + let pretty_internal = |opt_base: &Option, x, t| { + syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator) + }; + + let pretty_internal_signed = |opt_base, x: i64, t| { + let base = pretty_internal(opt_base, x.unsigned_abs(), t); + + allocator.text("-").append(base) + }; + + match self { + Value::Unknown(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::U64) + } + Value::I8(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I8) + } + Value::I16(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I16) + } + Value::I32(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I32) + } + Value::I64(opt_base, value) => { + pretty_internal_signed(opt_base, *value, ConstantType::I64) + } + Value::U8(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U8) + } + Value::U16(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U16) + } + Value::U32(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U32) + } + Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Type { + Variable(Location, ArcIntern), + Primitive(PrimitiveType), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + Type::Variable(_, x) => allocator.text(x.to_string()), + Type::Primitive(pt) => allocator.text(format!("{}", pt)), + } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Variable(_, x) => write!(f, "{}", x), + Type::Primitive(pt) => pt.fmt(f), + } + } +} + +/// Generate a fresh new name based on the given name. +/// +/// The new name is guaranteed to be unique across the entirety of the +/// execution. This is achieved by using characters in the variable name +/// that would not be valid input, and by including a counter that is +/// incremented on every invocation. +pub fn gensym(name: &str) -> ArcIntern { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let new_name = format!( + "<{}:{}>", + name, + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + ); + ArcIntern::new(new_name) +} + +/// Generate a fresh new type; this will be a unique new type variable. +/// +/// The new name is guaranteed to be unique across the entirety of the +/// execution. This is achieved by using characters in the variable name +/// that would not be valid input, and by including a counter that is +/// incremented on every invocation. +pub fn gentype() -> Type { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let new_name = ArcIntern::new(format!( + "t<{}>", + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + )); + + Type::Variable(Location::manufactured(), new_name) +} diff --git a/src/ir/type_infer/convert.rs b/src/ir/type_infer/convert.rs new file mode 100644 index 0000000..8b61584 --- /dev/null +++ b/src/ir/type_infer/convert.rs @@ -0,0 +1,232 @@ +use super::ast as ir; +use super::ast::Type; +use crate::eval::PrimitiveType; +use crate::ir::type_infer::solve::Constraint; +use crate::syntax::{self, ConstantType}; +use internment::ArcIntern; +use std::collections::HashMap; +use std::str::FromStr; + +/// This function takes a syntactic program and converts it into the IR version of the +/// program, with appropriate type variables introduced and their constraints added to +/// the given database. +/// +/// If the input function has been validated (which it should be), then this should run +/// into no error conditions. However, if you failed to validate the input, then this +/// function can panic. +pub fn convert_program( + mut program: syntax::Program, + constraint_db: &mut Vec, +) -> ir::Program { + let mut statements = Vec::new(); + let mut renames = HashMap::new(); + let mut bindings = HashMap::new(); + + for stmt in program.statements.drain(..) { + statements.append(&mut convert_statement( + stmt, + constraint_db, + &mut renames, + &mut bindings, + )); + } + + ir::Program { statements } +} + +/// This function takes a syntactic statements and converts it into a series of +/// IR statements, adding type variables and constraints as necessary. +/// +/// We generate a series of statements because we're going to flatten all +/// incoming expressions so that they are no longer recursive. This will +/// generate a bunch of new bindings for all the subexpressions, which we +/// return as a bundle. +/// +/// See the safety warning on [`convert_program`]! This function assumes that +/// you have run [`Statement::validate`], and will trigger panics in error +/// conditions if you have run that and had it come back clean. +fn convert_statement( + statement: syntax::Statement, + constraint_db: &mut Vec, + renames: &mut HashMap, ArcIntern>, + bindings: &mut HashMap, Type>, +) -> Vec { + match statement { + syntax::Statement::Print(loc, name) => { + let iname = ArcIntern::new(name); + let final_name = renames + .get(&iname) + .map(Clone::clone) + .unwrap_or_else(|| iname.clone()); + let varty = bindings + .get(&final_name) + .expect("print variable defined before use") + .clone(); + + constraint_db.push(Constraint::Printable(loc.clone(), varty.clone())); + + vec![ir::Statement::Print(loc, varty, iname)] + } + + syntax::Statement::Binding(loc, name, expr) => { + let (mut prereqs, expr, ty) = + convert_expression(expr, constraint_db, renames, bindings); + let iname = ArcIntern::new(name); + let final_name = if bindings.contains_key(&iname) { + let new_name = ir::gensym(iname.as_str()); + renames.insert(iname, new_name.clone()); + new_name + } else { + iname + }; + + bindings.insert(final_name.clone(), ty.clone()); + prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr)); + prereqs + } + } +} + +/// This function takes a syntactic expression and converts it into a series +/// of IR statements, adding type variables and constraints as necessary. +/// +/// We generate a series of statements because we're going to flatten all +/// incoming expressions so that they are no longer recursive. This will +/// generate a bunch of new bindings for all the subexpressions, which we +/// return as a bundle. +/// +/// See the safety warning on [`convert_program`]! This function assumes that +/// you have run [`Statement::validate`], and will trigger panics in error +/// conditions if you have run that and had it come back clean. +fn convert_expression( + expression: syntax::Expression, + constraint_db: &mut Vec, + renames: &HashMap, ArcIntern>, + bindings: &mut HashMap, Type>, +) -> (Vec, ir::Expression, Type) { + match expression { + syntax::Expression::Value(loc, val) => match val { + syntax::Value::Number(base, mctype, value) => { + let (newval, newtype) = match mctype { + None => { + let newtype = ir::gentype(); + let newval = ir::Value::Unknown(base, value); + + constraint_db.push(Constraint::NumericType(loc.clone(), newtype.clone())); + (newval, newtype) + } + Some(ConstantType::U8) => ( + ir::Value::U8(base, value as u8), + ir::Type::Primitive(PrimitiveType::U8), + ), + Some(ConstantType::U16) => ( + ir::Value::U16(base, value as u16), + ir::Type::Primitive(PrimitiveType::U16), + ), + Some(ConstantType::U32) => ( + ir::Value::U32(base, value as u32), + ir::Type::Primitive(PrimitiveType::U32), + ), + Some(ConstantType::U64) => ( + ir::Value::U64(base, value), + ir::Type::Primitive(PrimitiveType::U64), + ), + Some(ConstantType::I8) => ( + ir::Value::I8(base, value as i8), + ir::Type::Primitive(PrimitiveType::I8), + ), + Some(ConstantType::I16) => ( + ir::Value::I16(base, value as i16), + ir::Type::Primitive(PrimitiveType::I16), + ), + Some(ConstantType::I32) => ( + ir::Value::I32(base, value as i32), + ir::Type::Primitive(PrimitiveType::I32), + ), + Some(ConstantType::I64) => ( + ir::Value::I64(base, value as i64), + ir::Type::Primitive(PrimitiveType::I64), + ), + }; + + constraint_db.push(Constraint::FitsInNumType(loc.clone(), newtype.clone(), value)); + ( + vec![], + ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)), + newtype, + ) + } + }, + + syntax::Expression::Reference(loc, name) => { + let iname = ArcIntern::new(name); + let final_name = renames.get(&iname).cloned().unwrap_or(iname); + let rtype = bindings + .get(&final_name) + .cloned() + .expect("variable bound before use"); + let refexp = + ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name)); + + (vec![], refexp, rtype) + } + + syntax::Expression::Cast(loc, target, expr) => { + let (mut stmts, nexpr, etype) = + convert_expression(*expr, constraint_db, renames, bindings); + let val_or_ref = simplify_expr(nexpr, &mut stmts); + let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast"); + let target_type = Type::Primitive(target_prim_type); + let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); + + constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone())); + + (stmts, res, target_type) + } + + syntax::Expression::Primitive(loc, fun, mut args) => { + let primop = ir::Primitive::from_str(&fun).expect("valid primitive"); + let mut stmts = vec![]; + let mut nargs = vec![]; + let mut atypes = vec![]; + let ret_type = ir::gentype(); + + for arg in args.drain(..) { + let (mut astmts, aexp, atype) = + convert_expression(arg, constraint_db, renames, bindings); + + stmts.append(&mut astmts); + nargs.push(simplify_expr(aexp, &mut stmts)); + atypes.push(atype); + } + + constraint_db.push(Constraint::ProperPrimitiveArgs( + loc.clone(), + primop, + atypes.clone(), + ret_type.clone(), + )); + + ( + stmts, + ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs), + ret_type, + ) + } + } +} + +fn simplify_expr(expr: ir::Expression, stmts: &mut Vec) -> ir::ValueOrRef { + match expr { + ir::Expression::Atomic(v_or_ref) => v_or_ref, + expr => { + let etype = expr.type_of().clone(); + let loc = expr.location().clone(); + let nname = ir::gensym("g"); + let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr); + + stmts.push(nbinding); + ir::ValueOrRef::Ref(loc, etype, nname) + } + } +} diff --git a/src/ir/type_infer/finalize.rs b/src/ir/type_infer/finalize.rs new file mode 100644 index 0000000..cc8c5de --- /dev/null +++ b/src/ir/type_infer/finalize.rs @@ -0,0 +1,197 @@ +use super::ast as input; +use crate::{eval::PrimitiveType, ir as output}; +use internment::ArcIntern; +use std::collections::HashMap; + +pub fn finalize_program( + mut program: input::Program, + type_renames: HashMap, input::Type>, +) -> output::Program { + output::Program { + statements: program + .statements + .drain(..) + .map(|x| finalize_statement(x, &type_renames)) + .collect(), + } +} + +fn finalize_statement( + statement: input::Statement, + type_renames: &HashMap, input::Type>, +) -> output::Statement { + match statement { + input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding( + loc, + var, + finalize_type(ty, type_renames), + finalize_expression(expr, type_renames), + ), + input::Statement::Print(loc, ty, var) => { + output::Statement::Print(loc, finalize_type(ty, type_renames), var) + } + } +} + +fn finalize_expression( + expression: input::Expression, + type_renames: &HashMap, input::Type>, +) -> output::Expression { + match expression { + input::Expression::Atomic(val_or_ref) => { + output::Expression::Atomic(finalize_val_or_ref(val_or_ref, type_renames)) + } + input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast( + loc, + finalize_type(target, type_renames), + finalize_val_or_ref(val_or_ref, type_renames), + ), + input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive( + loc, + finalize_type(ty, type_renames), + prim, + args.drain(..) + .map(|x| finalize_val_or_ref(x, type_renames)) + .collect(), + ), + } +} + +fn finalize_type( + ty: input::Type, + type_renames: &HashMap, input::Type>, +) -> output::Type { + match ty { + input::Type::Primitive(x) => output::Type::Primitive(x), + input::Type::Variable(loc, name) => match type_renames.get(&name) { + Some(input::Type::Primitive(x)) => output::Type::Primitive(*x), + res => panic!( + "ACK! Internal error cleaning up temporary type name at {:?}: got {:?}", + loc, res + ), + }, + } +} + +fn finalize_val_or_ref( + valref: input::ValueOrRef, + type_renames: &HashMap, input::Type>, +) -> output::ValueOrRef { + match valref { + input::ValueOrRef::Ref(loc, ty, var) => { + output::ValueOrRef::Ref(loc, finalize_type(ty, type_renames), var) + } + input::ValueOrRef::Value(loc, ty, val) => { + let new_type = finalize_type(ty, type_renames); + + match val { + input::Value::Unknown(base, value) => match new_type { + output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::U8(base, value as u8), + ), + output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::U16(base, value as u16), + ), + output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::U32(base, value as u32), + ), + output::Type::Primitive(PrimitiveType::U64) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::U64(base, value), + ), + output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I8(base, value as i8), + ), + output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I16(base, value as i16), + ), + output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I32(base, value as i32), + ), + output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I64(base, value as i64), + ), + }, + + input::Value::U8(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U8) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value)) + } + + input::Value::U16(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U16) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value)) + } + + input::Value::U32(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U32) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value)) + } + + input::Value::U64(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U64) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value)) + } + + input::Value::I8(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I8) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value)) + } + + input::Value::I16(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I16) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value)) + } + + input::Value::I32(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I32) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value)) + } + + input::Value::I64(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I64) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value)) + } + } + } + } +} diff --git a/src/ir/type_infer/solve.rs b/src/ir/type_infer/solve.rs new file mode 100644 index 0000000..66cbedc --- /dev/null +++ b/src/ir/type_infer/solve.rs @@ -0,0 +1,341 @@ +use super::ast as ir; +use super::ast::Type; +use crate::{eval::PrimitiveType, syntax::Location}; +use codespan_reporting::diagnostic::Diagnostic; +use internment::ArcIntern; +use std::collections::HashMap; + +pub enum Constraint { + /// The given type must be printable using the `print` built-in + Printable(Location, Type), + /// The provided numeric value fits in the given constant type + FitsInNumType(Location, Type, u64), + /// The given primitive has the proper arguments types associated with it + ProperPrimitiveArgs(Location, ir::Primitive, Vec, Type), + /// The given type can be casted to the target type safely + CanCastTo(Location, Type, Type), + /// The given type must be some numeric type + NumericType(Location, Type), + /// The two types should be equivalent + Equivalent(Location, Type, Type), +} + +pub enum TypeInferenceResult { + Success { + result: Result, + warnings: Vec, + }, + Failure { + errors: Vec, + warnings: Vec, + }, +} + +impl TypeInferenceResult { + // If this was a successful type inference, run the function over the result to + // create a new result. + pub fn map(self, f: F) -> TypeInferenceResult + where + F: FnOnce(R) -> U, + { + match self { + TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success { + result: f(result), + warnings, + }, + + TypeInferenceResult::Failure { errors, warnings } => { + TypeInferenceResult::Failure { errors, warnings } + } + } + } + + // Return the final result, or panic if it's not a success + pub fn expect(self, msg: &str) -> R { + match self { + TypeInferenceResult::Success { result, .. } => result, + TypeInferenceResult::Failure { .. } => { + panic!("tried to get value from failed type inference: {}", msg) + } + } + } +} + +pub enum TypeInferenceError { + ConstantTooLarge(Location, PrimitiveType, u64), + NotEquivalent(Location, PrimitiveType, PrimitiveType), + CannotSafelyCast(Location, PrimitiveType, PrimitiveType), + WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize), + CouldNotSolve(Constraint), +} + +impl From for Diagnostic { + fn from(value: TypeInferenceError) -> Self { + unimplemented!() + } +} + +pub enum TypeInferenceWarning { + DefaultedTo(Location, Type), +} + +impl From for Diagnostic { + fn from(value: TypeInferenceWarning) -> Self { + unimplemented!() + } +} + +pub fn solve_constraints( + mut constraint_db: Vec, +) -> TypeInferenceResult, Type>> { + let mut type_renames = HashMap::new(); + let mut errors = vec![]; + let mut warnings = vec![]; + let mut changed_something = true; + + // We want to run this inference endlessly, until either we have solved all of our + // constraints. Internal to the loop, we have a check that will make sure that we + // do (eventually) stop. + while changed_something && !constraint_db.is_empty() { + // Set this to false at the top of the loop. We'll set this to true if we make + // progress in any way further down, but having this here prevents us from going + // into an infinite look when we can't figure stuff out. + changed_something = false; + // This is sort of a double-buffering thing; we're going to rename constraint_db + // and then set it to a new empty vector, which we'll add to as we find new + // constraints or find ourselves unable to solve existing ones. + let mut local_constraints = constraint_db; + constraint_db = vec![]; + + // OK. First thing we're going to do is run through all of our constraints, + // and see if we can solve any, or reduce them to theoretically more simple + // constraints. + for constraint in local_constraints.drain(..) { + match constraint { + // Currently, all of our types are printable + Constraint::Printable(_loc, _ty) => changed_something = true, + + // If the first type is a variable, then we rename it to the second. + Constraint::Equivalent(_, Type::Variable(_, n), second) => { + type_renames.insert(n, second); + changed_something = true; + } + + // If the second type is a variable (I guess the first wasn't, then rename it to the first) + Constraint::Equivalent(_, first, Type::Variable(_, n)) => { + type_renames.insert(n, first); + changed_something = true; + } + + // Otherwise, we ar testing if two concrete types are equivalent, and just need to see + // if they're the same. + Constraint::Equivalent(loc, Type::Primitive(a), Type::Primitive(b)) => { + if a != b { + errors.push(TypeInferenceError::NotEquivalent(loc, a, b)); + } + changed_something = true; + } + + // Make sure that the provided number fits within the provided constant type. For the + // moment, we're going to call an error here a failure, although this could be a + // warning in the future. + Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => { + let is_too_big = match ctype { + PrimitiveType::U8 => (u8::MAX as u64) < val, + PrimitiveType::U16 => (u16::MAX as u64) < val, + PrimitiveType::U32 => (u32::MAX as u64) < val, + PrimitiveType::U64 => false, + PrimitiveType::I8 => (i8::MAX as u64) < val, + PrimitiveType::I16 => (i16::MAX as u64) < val, + PrimitiveType::I32 => (i32::MAX as u64) < val, + PrimitiveType::I64 => (i64::MAX as u64) < val, + }; + + if is_too_big { + errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val)); + } + + changed_something = true; + } + + // If we have a non-constant type, then let's see if we can advance this to a constant + // type + Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => { + match type_renames.get(&var) { + None => constraint_db.push(Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val)), + Some(nt) => { + constraint_db.push(Constraint::FitsInNumType(loc, nt.clone(), val)); + changed_something = true; + } + } + } + + // If the left type in a "can cast to" check is a variable, let's see if we can advance + // it into something more tangible + Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => { + match type_renames.get(&var) { + None => constraint_db.push(Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type)), + Some(nt) => { + constraint_db.push(Constraint::CanCastTo(loc, nt.clone(), to_type)); + changed_something = true; + } + } + } + + // If the right type in a "can cast to" check is a variable, same deal + Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => { + match type_renames.get(&var) { + None => constraint_db.push(Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var))), + Some(nt) => { + constraint_db.push(Constraint::CanCastTo(loc, from_type, nt.clone())); + changed_something = true; + } + } + } + + // If both of them are types, then we can actually do the test. yay! + Constraint::CanCastTo( + loc, + Type::Primitive(from_type), + Type::Primitive(to_type), + ) => { + if !from_type.can_cast_to(&to_type) { + errors.push(TypeInferenceError::CannotSafelyCast( + loc, from_type, to_type, + )); + } + changed_something = true; + } + + // As per usual, if we're trying to test if a type variable is numeric, first + // we try to advance it to a primitive + Constraint::NumericType(loc, Type::Variable(vloc, var)) => { + match type_renames.get(&var) { + None => constraint_db.push(Constraint::NumericType(loc, Type::Variable(vloc, var))), + Some(nt) => { + constraint_db.push(Constraint::NumericType(loc, nt.clone())); + changed_something = true; + } + } + } + + // Of course, if we get to a primitive type, then it's true, because all of our + // primitive types are numbers + Constraint::NumericType(_, Type::Primitive(_)) => { + changed_something = true; + } + + // OK, this one could be a little tricky if we tried to do it all at once, but + // instead what we're going to do here is just use this constraint to generate + // a bunch more constraints, and then go have the engine solve those. The only + // real errors we're going to come up with here are "arity errors"; errors we + // find by discovering that the number of arguments provided doesn't make sense + // given the primitive being used. + Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim { + ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide + if args.len() != 2 => + { + errors.push(TypeInferenceError::WrongPrimitiveArity( + loc, + prim, + 2, + 2, + args.len(), + )); + changed_something = true; + } + + ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => { + let right = args.pop().expect("2 > 0"); + let left = args.pop().expect("2 > 1"); + + // technically testing that both are numeric is redundant, but it might give + // a slightly helpful type error if we do both. + constraint_db.push(Constraint::NumericType(loc.clone(), left.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), right.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); + constraint_db.push(Constraint::Equivalent(loc.clone(), left.clone(), right)); + constraint_db.push(Constraint::Equivalent(loc, left, ret)); + changed_something = true; + } + + ir::Primitive::Minus if args.is_empty() || args.len() > 2 => { + errors.push(TypeInferenceError::WrongPrimitiveArity( + loc, + prim, + 1, + 2, + args.len(), + )); + changed_something = true; + } + + ir::Primitive::Minus if args.len() == 1 => { + let arg = args.pop().expect("1 > 0"); + constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); + constraint_db.push(Constraint::Equivalent(loc, arg, ret)); + changed_something = true; + } + + ir::Primitive::Minus => { + let right = args.pop().expect("2 > 0"); + let left = args.pop().expect("2 > 1"); + + // technically testing that both are numeric is redundant, but it might give + // a slightly helpful type error if we do both. + constraint_db.push(Constraint::NumericType(loc.clone(), left.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), right.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); + constraint_db.push(Constraint::Equivalent(loc.clone(), left.clone(), right)); + constraint_db.push(Constraint::Equivalent(loc.clone(), left, ret)); + changed_something = true; + } + }, + } + } + + // If that didn't actually come up with anything, and we just recycled all the constraints + // back into the database unchanged, then let's take a look for cases in which we just + // wanted something we didn't know to be a number. Basically, those are cases where the + // user just wrote a number, but didn't tell us what type it was, and there isn't enough + // information in the context to tell us. If that happens, we'll just set that type to + // be u64, and warn the user that we did so. + if !changed_something && !constraint_db.is_empty() { + local_constraints = constraint_db; + constraint_db = vec![]; + + for constraint in local_constraints.drain(..) { + match constraint { + Constraint::NumericType(loc, Type::Variable(_, name)) => { + let resty = Type::Primitive(PrimitiveType::U64); + type_renames.insert(name, resty.clone()); + warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty)); + changed_something = true; + } + + _ => constraint_db.push(constraint), + } + } + } + } + + // OK, we left our loop. Which means that either we solved everything, or we didn't. + // If we didn't, turn the unsolved constraints into type inference errors, and add + // them to our error list. + let mut unsolved_constraint_errors = constraint_db + .drain(..) + .map(TypeInferenceError::CouldNotSolve) + .collect(); + errors.append(&mut unsolved_constraint_errors); + + // How'd we do? + if errors.is_empty() { + TypeInferenceResult::Success { + result: type_renames, + warnings, + } + } else { + TypeInferenceResult::Failure { errors, warnings } + } +} diff --git a/src/repl.rs b/src/repl.rs index 4e32694..a4e2849 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -1,4 +1,5 @@ use crate::backend::{Backend, BackendError}; +use crate::ir::TypeInferenceResult; use crate::syntax::{ConstantType, Location, ParserError, Statement}; use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::files::SimpleFiles; @@ -128,18 +129,32 @@ impl REPL { .source(); let syntax = Statement::parse(entry, source)?; - // if this is a variable binding, and we've never defined this variable before, - // we should tell cranelift about it. this is optimistic; if we fail to compile, - // then we won't use this definition until someone tries again. - if let Statement::Binding(_, ref name, _) = syntax { - if !self.variable_binding_sites.contains_key(name.as_str()) { - self.jitter.define_string(name)?; - self.jitter - .define_variable(name.clone(), ConstantType::U64)?; + let program = match syntax { + Statement::Binding(loc, name, expr) => { + // if this is a variable binding, and we've never defined this variable before, + // we should tell cranelift about it. this is optimistic; if we fail to compile, + // then we won't use this definition until someone tries again. + if !self.variable_binding_sites.contains_key(&name) { + self.jitter.define_string(&name)?; + self.jitter + .define_variable(name.clone(), ConstantType::U64)?; + } + + crate::syntax::Program { + statements: vec![ + Statement::Binding(loc.clone(), name.clone(), expr), + Statement::Print(loc, name), + ], + } } + + nonbinding => crate::syntax::Program { + statements: vec![nonbinding], + }, }; - let (mut errors, mut warnings) = syntax.validate(&mut self.variable_binding_sites); + let (mut errors, mut warnings) = + program.validate_with_bindings(&mut self.variable_binding_sites); let stop = !errors.is_empty(); let messages = errors .drain(..) @@ -154,16 +169,39 @@ impl REPL { return Ok(()); } - let ir = crate::syntax::Program { - statements: vec![syntax], + match program.type_infer() { + TypeInferenceResult::Failure { + mut errors, + mut warnings, + } => { + let messages = errors + .drain(..) + .map(Into::into) + .chain(warnings.drain(..).map(Into::into)); + + for message in messages { + self.emit_diagnostic(message)?; + } + + Ok(()) + } + + TypeInferenceResult::Success { + result, + mut warnings, + } => { + for message in warnings.drain(..).map(Into::into) { + self.emit_diagnostic(message)?; + } + let name = format!("line{}", line_no); + let function_id = self.jitter.compile_function(&name, result)?; + self.jitter.module.finalize_definitions()?; + let compiled_bytes = self.jitter.bytes(function_id); + let compiled_function = + unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; + compiled_function(); + Ok(()) + } } - .type_infer(); - let name = format!("line{}", line_no); - let function_id = self.jitter.compile_function(&name, ir)?; - self.jitter.module.finalize_definitions()?; - let compiled_bytes = self.jitter.bytes(function_id); - let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; - compiled_function(); - Ok(()) } } diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index cdfca71..a3589f6 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -65,12 +65,24 @@ impl Program { /// example, and generates warnings for things that are inadvisable but not /// actually a problem. pub fn validate(&self) -> (Vec, Vec) { + let mut bound_variables = HashMap::new(); + self.validate_with_bindings(&mut bound_variables) + } + + /// Validate that the program makes semantic sense, not just syntactic sense. + /// + /// This checks for things like references to variables that don't exist, for + /// example, and generates warnings for things that are inadvisable but not + /// actually a problem. + pub fn validate_with_bindings( + &self, + bound_variables: &mut HashMap, + ) -> (Vec, Vec) { let mut errors = vec![]; let mut warnings = vec![]; - let mut bound_variables = HashMap::new(); for stmt in self.statements.iter() { - let (mut new_errors, mut new_warnings) = stmt.validate(&mut bound_variables); + let (mut new_errors, mut new_warnings) = stmt.validate(bound_variables); errors.append(&mut new_errors); warnings.append(&mut new_warnings); } @@ -89,7 +101,7 @@ impl Statement { /// occurs. We use a `HashMap` to map these bound locations to the locations /// where their bound, because these locations are handy when generating errors /// and warnings. - pub fn validate( + fn validate( &self, bound_variables: &mut HashMap, ) -> (Vec, Vec) {