From 3687785540341bc7cd6aeb7858cd0c71fd3a233d Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Mon, 19 Jun 2023 21:16:28 -0700 Subject: [PATCH] [checkpoint] Start the switch to type inference. --- src/backend/into_crane.rs | 76 +++++----- src/compiler.rs | 3 +- src/ir.rs | 2 +- src/ir/ast.rs | 80 ++++++++--- src/ir/eval.rs | 20 +-- src/ir/from_syntax.rs | 213 --------------------------- src/ir/strings.rs | 4 +- src/ir/type_infer.rs | 293 ++++++++++++++++++++++++++++++++++++++ src/repl.rs | 6 +- src/syntax/validate.rs | 11 +- 10 files changed, 424 insertions(+), 284 deletions(-) delete mode 100644 src/ir/from_syntax.rs create mode 100644 src/ir/type_infer.rs diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index fa038dc..759eab1 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -125,7 +125,7 @@ impl Backend { // Print statements are fairly easy to compile: we just lookup the // output buffer, the address of the string to print, and the value // of whatever variable we're printing. Then we just call print. - Statement::Print(ann, var) => { + Statement::Print(ann, t, var) => { // Get the output buffer (or null) from our general compilation context. let buffer_ptr = self.output_buffer_ptr(); let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); @@ -137,7 +137,7 @@ impl Backend { // Look up the value for the variable. Because this might be a // global variable (and that requires special logic), we just turn // this into an `Expression` and re-use the logic in that implementation. - let (val, vtype) = ValueOrRef::Ref(ann, var).into_crane( + let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane( &mut builder, &variable_table, &pre_defined_symbols, @@ -163,7 +163,7 @@ impl Backend { } // Variable binding is a little more con - Statement::Binding(_, var_name, value) => { + Statement::Binding(_, var_name, _, value) => { // Kick off to the `Expression` implementation to see what value we're going // to bind to this variable. let (val, etype) = @@ -254,50 +254,62 @@ impl Expression { Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables), Expression::Cast(_, target_type, expr) => { - let (val, val_type) = expr.into_crane(builder, local_variables, global_variables)?; + let (val, val_type) = + expr.into_crane(builder, local_variables, global_variables)?; match (val_type, &target_type) { (ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)), - (ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => - Ok((builder.ins().sextend(types::I16, val), ConstantType::I16)), - (ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => - Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)), - (ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => - Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)), + (ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => { + Ok((builder.ins().sextend(types::I16, val), ConstantType::I16)) + } + (ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => { + Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)) + } + (ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)) + } (ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)), - (ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => - Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)), - (ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => - Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)), + (ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => { + Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)) + } + (ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)) + } (ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)), - (ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => - Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)), + (ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)) + } (ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)), (ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)), - (ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => - Ok((builder.ins().uextend(types::I16, val), ConstantType::U16)), - (ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => - Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)), - (ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => - Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)), + (ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => { + Ok((builder.ins().uextend(types::I16, val), ConstantType::U16)) + } + (ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => { + Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)) + } + (ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)) + } (ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)), - (ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => - Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)), - (ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => - Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)), + (ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => { + Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)) + } + (ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)) + } (ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)), - (ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => - Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)), + (ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)) + } (ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)), - _ => Err(BackendError::InvalidTypeCast { from: val_type.into(), to: target_type, @@ -305,7 +317,7 @@ impl Expression { } } - Expression::Primitive(_, prim, mut vals) => { + Expression::Primitive(_, _, prim, mut vals) => { let mut values = vec![]; let mut first_type = None; @@ -357,7 +369,7 @@ impl ValueOrRef { match self { // Values are pretty straightforward to compile, mostly because we only // have one type of variable, and it's an integer type. - ValueOrRef::Value(_, val) => match val { + ValueOrRef::Value(_, _, val) => match val { Value::I8(_, v) => { Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8)) } @@ -387,7 +399,7 @@ impl ValueOrRef { )), }, - ValueOrRef::Ref(_, name) => { + ValueOrRef::Ref(_, _, name) => { // first we see if this is a local variable (which is nicer, from an // optimization point of view.) if let Some((local_var, etype)) = local_variables.get(&name) { diff --git a/src/compiler.rs b/src/compiler.rs index 41cc037..a793b7f 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,5 +1,4 @@ use crate::backend::Backend; -use crate::ir::Program as IR; use crate::syntax::Program as Syntax; use codespan_reporting::{ diagnostic::Diagnostic, @@ -101,7 +100,7 @@ impl Compiler { } // Now that we've validated it, turn it into IR. - let ir = IR::from(syntax); + let ir = syntax.type_infer(); // Finally, send all this to Cranelift for conversion into an object file. let mut backend = Backend::object_file(Triple::host())?; diff --git a/src/ir.rs b/src/ir.rs index 88454e4..5ad3fa7 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -14,7 +14,7 @@ //! come to for analysis and optimization work. mod ast; mod eval; -mod from_syntax; mod strings; +mod type_infer; pub use ast::*; diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 8cf69be..8cde98b 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -8,7 +8,7 @@ use proptest::{ prelude::Arbitrary, strategy::{BoxedStrategy, Strategy}, }; -use std::fmt; +use std::{fmt, str::FromStr}; /// We're going to represent variables as interned strings. /// @@ -61,7 +61,7 @@ impl Arbitrary for Program { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { crate::syntax::Program::arbitrary_with(args) - .prop_map(Program::from) + .prop_map(syntax::Program::type_infer) .boxed() } } @@ -78,8 +78,8 @@ impl Arbitrary for Program { /// #[derive(Debug)] pub enum Statement { - Binding(Location, Variable, Expression), - Print(Location, Variable), + Binding(Location, Variable, Type, Expression), + Print(Location, Type, Variable), } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement @@ -89,13 +89,13 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - Statement::Binding(_, var, expr) => allocator + Statement::Binding(_, var, _, expr) => allocator .text(var.as_ref().to_string()) .append(allocator.space()) .append(allocator.text("=")) .append(allocator.space()) .append(expr.pretty(allocator)), - Statement::Print(_, var) => allocator + Statement::Print(_, _, var) => allocator .text("print") .append(allocator.space()) .append(allocator.text(var.as_ref().to_string())), @@ -119,7 +119,30 @@ where pub enum Expression { Atomic(ValueOrRef), Cast(Location, Type, ValueOrRef), - Primitive(Location, Primitive, Vec), + Primitive(Location, Type, Primitive, Vec), +} + +impl Expression { + /// Return a reference to the type of the expression, as inferred or recently + /// computed. + pub fn type_of(&self) -> &Type { + match self { + Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t, + Expression::Atomic(ValueOrRef::Value(_, t, _)) => t, + Expression::Cast(_, t, _) => t, + Expression::Primitive(_, t, _, _) => t, + } + } + + /// Return a reference to the location associated with the expression. + pub fn location(&self) -> &Location { + match self { + Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l, + Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, + Expression::Cast(l, _, _) => l, + Expression::Primitive(l, _, _, _) => l, + } + } } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression @@ -135,10 +158,10 @@ where .append(t.pretty(allocator)) .append(allocator.text(">")) .append(e.pretty(allocator)), - Expression::Primitive(_, op, exprs) if exprs.len() == 1 => { + Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => { op.pretty(allocator).append(exprs[0].pretty(allocator)) } - Expression::Primitive(_, op, exprs) if exprs.len() == 2 => { + Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => { let left = exprs[0].pretty(allocator); let right = exprs[1].pretty(allocator); @@ -148,7 +171,7 @@ where .append(right) .parens() } - Expression::Primitive(_, op, exprs) => { + Expression::Primitive(_, _, op, exprs) => { allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) } } @@ -169,10 +192,10 @@ pub enum Primitive { Divide, } -impl<'a> TryFrom<&'a str> for Primitive { - type Error = String; +impl FromStr for Primitive { + type Err = String; - fn try_from(value: &str) -> Result { + fn from_str(value: &str) -> Result { match value { "+" => Ok(Primitive::Plus), "-" => Ok(Primitive::Minus), @@ -203,10 +226,10 @@ where /// This is the type used to guarantee that we don't nest expressions /// at this level. Instead, expressions that take arguments take one /// of these, which can only be a constant or a reference. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum ValueOrRef { - Value(Location, Value), - Ref(Location, ArcIntern), + Value(Location, Type, Value), + Ref(Location, Type, ArcIntern), } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef @@ -216,8 +239,8 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - ValueOrRef::Value(_, v) => v.pretty(allocator), - ValueOrRef::Ref(_, v) => allocator.text(v.as_ref().to_string()), + ValueOrRef::Value(_, _, v) => v.pretty(allocator), + ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), } } } @@ -246,6 +269,22 @@ pub enum Value { U64(Option, u64), } +impl Value { + /// Return the type described by this value + pub fn type_of(&self) -> Type { + match self { + Value::I8(_, _) => Type::Primitive(PrimitiveType::I8), + Value::I16(_, _) => Type::Primitive(PrimitiveType::I16), + Value::I32(_, _) => Type::Primitive(PrimitiveType::I32), + Value::I64(_, _) => Type::Primitive(PrimitiveType::I64), + Value::U8(_, _) => Type::Primitive(PrimitiveType::U8), + Value::U16(_, _) => Type::Primitive(PrimitiveType::U16), + Value::U32(_, _) => Type::Primitive(PrimitiveType::U32), + Value::U64(_, _) => Type::Primitive(PrimitiveType::U64), + } + } +} + impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value where A: 'a, @@ -289,8 +328,9 @@ where } } -#[derive(Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Type { + Variable(Location, ArcIntern), Primitive(PrimitiveType), } @@ -301,6 +341,7 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { + Type::Variable(_, x) => allocator.text(x.to_string()), Type::Primitive(pt) => allocator.text(format!("{}", pt)), } } @@ -309,6 +350,7 @@ where impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + Type::Variable(_, x) => write!(f, "{}", x), Type::Primitive(pt) => pt.fmt(f), } } diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 420c939..0d1b201 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -14,12 +14,12 @@ impl Program { for stmt in self.statements.iter() { match stmt { - Statement::Binding(_, name, value) => { + Statement::Binding(_, name, _, value) => { let actual_value = value.eval(&env)?; env = env.extend(name.clone(), actual_value); } - Statement::Print(_, name) => { + Statement::Print(_, _, name) => { let value = env.lookup(name.clone())?; let line = format!("{} = {}\n", name, value); stdout.push_str(&line); @@ -40,12 +40,16 @@ impl Expression { let value = valref.eval(env)?; match t { + Type::Variable(_, _) => unimplemented!("how to cast to a type variable?"), Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), } } - Expression::Primitive(_, op, args) => { - let arg_values = args.iter().map(|x| x.eval(env)).collect::, EvalError>>()?; + Expression::Primitive(_, _, op, args) => { + let arg_values = args + .iter() + .map(|x| x.eval(env)) + .collect::, EvalError>>()?; // and then finally we call `calculate` to run them. trust me, it's nice // to not have to deal with all the nonsense hidden under `calculate`. @@ -63,7 +67,7 @@ impl Expression { impl ValueOrRef { fn eval(&self, env: &EvalEnvironment) -> Result { match self { - ValueOrRef::Value(_, v) => match v { + ValueOrRef::Value(_, _, v) => match v { super::Value::I8(_, v) => Ok(Value::I8(*v)), super::Value::I16(_, v) => Ok(Value::I16(*v)), super::Value::I32(_, v) => Ok(Value::I32(*v)), @@ -74,7 +78,7 @@ impl ValueOrRef { super::Value::U64(_, v) => Ok(Value::U64(*v)), }, - ValueOrRef::Ref(_, n) => Ok(env.lookup(n.clone())?), + ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?), } } } @@ -82,7 +86,7 @@ impl ValueOrRef { #[test] fn two_plus_three() { let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); - let ir = Program::from(input); + let ir = input.type_infer(); let output = ir.eval().expect("runs successfully"); assert_eq!("x = 5u64\n", &output); } @@ -91,7 +95,7 @@ fn two_plus_three() { fn lotsa_math() { let input = crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); - let ir = Program::from(input); + let ir = input.type_infer(); let output = ir.eval().expect("runs successfully"); assert_eq!("x = 7u64\n", &output); } diff --git a/src/ir/from_syntax.rs b/src/ir/from_syntax.rs deleted file mode 100644 index 3f174c6..0000000 --- a/src/ir/from_syntax.rs +++ /dev/null @@ -1,213 +0,0 @@ -use internment::ArcIntern; -use std::str::FromStr; -use std::sync::atomic::AtomicUsize; - -use crate::eval::PrimitiveType; -use crate::ir::ast as ir; -use crate::syntax; - -use super::ValueOrRef; - -impl From for ir::Program { - /// We implement the top-level conversion of a syntax::Program into an - /// ir::Program using just the standard `From::from`, because we don't - /// need to return any arguments and we shouldn't produce any errors. - /// Technically there's an `unwrap` deep under the hood that we could - /// float out, but the validator really should've made sure that never - /// happens, so we're just going to assume. - fn from(mut value: syntax::Program) -> Self { - let mut statements = Vec::new(); - - for stmt in value.statements.drain(..) { - statements.append(&mut stmt.simplify()); - } - - ir::Program { statements } - } -} - -impl From for ir::Program { - /// One interesting thing about this conversion is that there isn't - /// a natural translation from syntax::Statement to ir::Statement, - /// because the syntax version can have nested expressions and the - /// IR version can't. - /// - /// As a result, we can naturally convert a syntax::Statement into - /// an ir::Program, because we can allow the additional binding - /// sites to be generated, instead. And, bonus, it turns out that - /// this is what we wanted anyways. - fn from(value: syntax::Statement) -> Self { - ir::Program { - statements: value.simplify(), - } - } -} - -impl syntax::Statement { - /// Simplify a syntax::Statement into a series of ir::Statements. - /// - /// The reason this function is one-to-many is because we may have to - /// introduce new binding sites in order to avoid having nested - /// expressions. Nested expressions, like `(1 + 2) * 3`, are allowed - /// in syntax::Expression but are expressly *not* allowed in - /// ir::Expression. So this pass converts them into bindings, like - /// this: - /// - /// x = (1 + 2) * 3; - /// - /// ==> - /// - /// x:1 = 1 + 2; - /// x:2 = x:1 * 3; - /// x = x:2 - /// - /// Thus ensuring that things are nice and simple. Note that the - /// binding of `x:2` is not, strictly speaking, necessary, but it - /// makes the code below much easier to read. - fn simplify(self) -> Vec { - let mut new_statements = vec![]; - - match self { - // Print statements we don't have to do much with - syntax::Statement::Print(loc, name) => { - new_statements.push(ir::Statement::Print(loc, ArcIntern::new(name))) - } - - // Bindings, however, may involve a single expression turning into - // a series of statements and then an expression. - syntax::Statement::Binding(loc, name, value) => { - let (mut prereqs, new_value) = value.rebind(&name); - new_statements.append(&mut prereqs); - new_statements.push(ir::Statement::Binding( - loc, - ArcIntern::new(name), - new_value.into(), - )) - } - } - - new_statements - } -} - -impl syntax::Expression { - /// This actually does the meat of the simplification work, here, by rebinding - /// any nested expressions into their own variables. We have this return - /// `ValueOrRef` in all cases because it makes for slighly less code; in the - /// case when we actually want an `Expression`, we can just use `into()`. - fn rebind(self, base_name: &str) -> (Vec, ir::ValueOrRef) { - match self { - // Values just convert in the obvious way, and require no prereqs - syntax::Expression::Value(loc, val) => (vec![], ValueOrRef::Value(loc, val.into())), - - // Similarly, references just convert in the obvious way, and require - // no prereqs - syntax::Expression::Reference(loc, name) => { - (vec![], ValueOrRef::Ref(loc, ArcIntern::new(name))) - } - - syntax::Expression::Cast(loc, t, expr) => { - let (mut prereqs, new_expr) = expr.rebind(base_name); - let new_name = gensym(base_name); - prereqs.push(ir::Statement::Binding( - loc.clone(), - new_name.clone(), - ir::Expression::Cast( - loc.clone(), - ir::Type::Primitive(PrimitiveType::from_str(&t).unwrap()), - new_expr, - ), - )); - (prereqs, ValueOrRef::Ref(loc, new_name)) - } - - // Primitive expressions are where we do the real work. - syntax::Expression::Primitive(loc, prim, mut expressions) => { - // generate a fresh new name for the binding site we're going to - // introduce, basing the name on wherever we came from; so if this - // expression was bound to `x` originally, it might become `x:23`. - // - // gensym is guaranteed to give us a name that is unused anywhere - // else in the program. - let new_name = gensym(base_name); - let mut prereqs = Vec::new(); - let mut new_exprs = Vec::new(); - - // here we loop through every argument, and recurse on the expressions - // we find. that will give us any new binding sites that *they* introduce, - // and a simple value or reference that we can use in our result. - for expr in expressions.drain(..) { - let (mut cur_prereqs, arg) = expr.rebind(new_name.as_str()); - prereqs.append(&mut cur_prereqs); - new_exprs.push(arg); - } - - // now we're going to use those new arguments to run the primitive, binding - // the results to the new variable we introduced. - let prim = - ir::Primitive::try_from(prim.as_str()).expect("is valid primitive function"); - prereqs.push(ir::Statement::Binding( - loc.clone(), - new_name.clone(), - ir::Expression::Primitive(loc.clone(), prim, new_exprs), - )); - - // and finally, we can return all the new bindings, and a reference to - // the variable we just introduced to hold the value of the primitive - // invocation. - (prereqs, ValueOrRef::Ref(loc, new_name)) - } - } - } -} - -impl From for ir::Value { - fn from(value: syntax::Value) -> Self { - match value { - syntax::Value::Number(base, ty, val) => match ty { - None => ir::Value::U64(base, val), - Some(syntax::ConstantType::I8) => ir::Value::I8(base, val as i8), - Some(syntax::ConstantType::I16) => ir::Value::I16(base, val as i16), - Some(syntax::ConstantType::I32) => ir::Value::I32(base, val as i32), - Some(syntax::ConstantType::I64) => ir::Value::I64(base, val as i64), - Some(syntax::ConstantType::U8) => ir::Value::U8(base, val as u8), - Some(syntax::ConstantType::U16) => ir::Value::U16(base, val as u16), - Some(syntax::ConstantType::U32) => ir::Value::U32(base, val as u32), - Some(syntax::ConstantType::U64) => ir::Value::U64(base, val), - }, - } - } -} - -impl From for ir::Primitive { - fn from(value: String) -> Self { - value.try_into().unwrap() - } -} - -/// Generate a fresh new name based on the given name. -/// -/// The new name is guaranteed to be unique across the entirety of the -/// execution. This is achieved by using characters in the variable name -/// that would not be valid input, and by including a counter that is -/// incremented on every invocation. -fn gensym(name: &str) -> ArcIntern { - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let new_name = format!( - "<{}:{}>", - name, - COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) - ); - ArcIntern::new(new_name) -} - -proptest::proptest! { - #[test] - fn translation_maintains_semantics(input: syntax::Program) { - let syntax_result = input.eval(); - let ir = ir::Program::from(input); - let ir_result = ir.eval(); - assert_eq!(syntax_result, ir_result); - } -} diff --git a/src/ir/strings.rs b/src/ir/strings.rs index f7b291e..70f939f 100644 --- a/src/ir/strings.rs +++ b/src/ir/strings.rs @@ -21,12 +21,12 @@ impl Program { impl Statement { fn register_strings(&self, string_set: &mut HashSet>) { match self { - Statement::Binding(_, name, expr) => { + Statement::Binding(_, name, _, expr) => { string_set.insert(name.clone()); expr.register_strings(string_set); } - Statement::Print(_, name) => { + Statement::Print(_, _, name) => { string_set.insert(name.clone()); } } diff --git a/src/ir/type_infer.rs b/src/ir/type_infer.rs new file mode 100644 index 0000000..92cad82 --- /dev/null +++ b/src/ir/type_infer.rs @@ -0,0 +1,293 @@ +use internment::ArcIntern; +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::atomic::AtomicUsize; + +use crate::eval::PrimitiveType; +use crate::ir::ast as ir; +use crate::ir::ast::Type; +use crate::syntax::{self, ConstantType, Location}; + +enum Constraint { + /// The given type must be printable using the `print` built-in + Printable(Location, Type), + /// The provided numeric value fits in the given constant type + FitsInNumType(Location, ConstantType, u64), + /// The given primitive has the proper arguments types associated with it + ProperPrimitiveArgs(Location, ir::Primitive, Vec, Type), + /// The given type can be casted to the target type safely + CanCastTo(Location, Type, Type), +} + +/// This function takes a syntactic program and converts it into the IR version of the +/// program, with appropriate type variables introduced and their constraints added to +/// the given database. +/// +/// If the input function has been validated (which it should be), then this should run +/// into no error conditions. However, if you failed to validate the input, then this +/// function can panic. +fn convert_program( + mut program: syntax::Program, + constraint_db: &mut Vec, +) -> ir::Program { + let mut statements = Vec::new(); + let mut renames = HashMap::new(); + let mut bindings = HashMap::new(); + + for stmt in program.statements.drain(..) { + statements.append(&mut convert_statement( + stmt, + constraint_db, + &mut renames, + &mut bindings, + )); + } + + ir::Program { statements } +} + +/// This function takes a syntactic statements and converts it into a series of +/// IR statements, adding type variables and constraints as necessary. +/// +/// We generate a series of statements because we're going to flatten all +/// incoming expressions so that they are no longer recursive. This will +/// generate a bunch of new bindings for all the subexpressions, which we +/// return as a bundle. +/// +/// See the safety warning on [`convert_program`]! This function assumes that +/// you have run [`Statement::validate`], and will trigger panics in error +/// conditions if you have run that and had it come back clean. +fn convert_statement( + statement: syntax::Statement, + constraint_db: &mut Vec, + renames: &mut HashMap, ArcIntern>, + bindings: &mut HashMap, Type>, +) -> Vec { + match statement { + syntax::Statement::Print(loc, name) => { + let iname = ArcIntern::new(name); + let final_name = renames.get(&iname).map(Clone::clone).unwrap_or_else(|| iname.clone()); + let varty = bindings + .get(&final_name) + .expect("print variable defined before use") + .clone(); + + constraint_db.push(Constraint::Printable(loc.clone(), varty.clone())); + + vec![ir::Statement::Print(loc, varty, iname)] + } + + syntax::Statement::Binding(loc, name, expr) => { + let (mut prereqs, expr, ty) = + convert_expression(expr, constraint_db, renames, bindings); + let iname = ArcIntern::new(name); + let final_name = if bindings.contains_key(&iname) { + let new_name = gensym(iname.as_str()); + renames.insert(iname, new_name.clone()); + new_name + } else { + iname + }; + + bindings.insert(final_name.clone(), ty.clone()); + prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr)); + prereqs + } + } +} + +/// This function takes a syntactic expression and converts it into a series +/// of IR statements, adding type variables and constraints as necessary. +/// +/// We generate a series of statements because we're going to flatten all +/// incoming expressions so that they are no longer recursive. This will +/// generate a bunch of new bindings for all the subexpressions, which we +/// return as a bundle. +/// +/// See the safety warning on [`convert_program`]! This function assumes that +/// you have run [`Statement::validate`], and will trigger panics in error +/// conditions if you have run that and had it come back clean. +fn convert_expression( + expression: syntax::Expression, + constraint_db: &mut Vec, + renames: &HashMap, ArcIntern>, + bindings: &mut HashMap, Type>, +) -> (Vec, ir::Expression, Type) { + match expression { + syntax::Expression::Value(loc, val) => { + let newval = match val { + syntax::Value::Number(base, mctype, value) => { + if let Some(suggested_type) = mctype { + constraint_db.push(Constraint::FitsInNumType(loc.clone(), suggested_type, value)); + } + + match mctype { + None => ir::Value::U64(base, value), + Some(ConstantType::U8) => ir::Value::U8(base, value as u8), + Some(ConstantType::U16) => ir::Value::U16(base, value as u16), + Some(ConstantType::U32) => ir::Value::U32(base, value as u32), + Some(ConstantType::U64) => ir::Value::U64(base, value), + Some(ConstantType::I8) => ir::Value::I8(base, value as i8), + Some(ConstantType::I16) => ir::Value::I16(base, value as i16), + Some(ConstantType::I32) => ir::Value::I32(base, value as i32), + Some(ConstantType::I64) => ir::Value::I64(base, value as i64), + } + } + }; + let valtype = newval.type_of(); + + ( + vec![], + ir::Expression::Atomic(ir::ValueOrRef::Value(loc, valtype.clone(), newval)), + valtype, + ) + } + + syntax::Expression::Reference(loc, name) => { + let iname = ArcIntern::new(name); + let final_name = renames.get(&iname).cloned().unwrap_or(iname); + let rtype = bindings + .get(&final_name) + .cloned() + .expect("variable bound before use"); + let refexp = + ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name)); + + (vec![], refexp, rtype) + } + + syntax::Expression::Cast(loc, target, expr) => { + let (mut stmts, nexpr, etype) = + convert_expression(*expr, constraint_db, renames, bindings); + let val_or_ref = simplify_expr(nexpr, &mut stmts); + let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast"); + let target_type = Type::Primitive(target_prim_type); + let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); + + constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone())); + + (stmts, res, target_type) + } + + syntax::Expression::Primitive(loc, fun, mut args) => { + let primop = ir::Primitive::from_str(&fun).expect("valid primitive"); + let mut stmts = vec![]; + let mut nargs = vec![]; + let mut atypes = vec![]; + let ret_type = gentype(); + + for arg in args.drain(..) { + let (mut astmts, aexp, atype) = convert_expression(arg, constraint_db, renames, bindings); + + stmts.append(&mut astmts); + nargs.push(simplify_expr(aexp, &mut stmts)); + atypes.push(atype); + } + + constraint_db.push(Constraint::ProperPrimitiveArgs(loc.clone(), primop, atypes.clone(), ret_type.clone())); + + (stmts, ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs), ret_type) + } + } +} + +fn simplify_expr(expr: ir::Expression, stmts: &mut Vec) -> ir::ValueOrRef { + match expr { + ir::Expression::Atomic(v_or_ref) => v_or_ref, + expr => { + let etype = expr.type_of().clone(); + let loc = expr.location().clone(); + let nname = gensym("g"); + let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr); + + stmts.push(nbinding); + ir::ValueOrRef::Ref(loc, etype, nname) + } + } +} + +impl syntax::Program { + /// Infer the types for the syntactic AST, returning either a type-checked program in + /// the IR, or a series of type errors encountered during inference. + /// + /// You really should have made sure that this program was validated before running + /// this method, otherwise you may experience panics during operation. + pub fn type_infer(self) -> ir::Program { + let mut constraint_db = vec![]; + let program = convert_program(self, &mut constraint_db); + let mut changed_something = true; + + // We want to run this inference endlessly, until either we have solved all of our + // constraints or we've gotten stuck somewhere. + while constraint_db.len() > 0 && changed_something { + // Set this to false at the top of the loop. We'll set this to true if we make + // progress in any way further down, but having this here prevents us from going + // into an infinite look when we can't figure stuff out. + changed_something = false; + // This is sort of a double-buffering thing; we're going to rename constraint_db + // and then set it to a new empty vector, which we'll add to as we find new + // constraints or find ourselves unable to solve existing ones. + let mut local_constraints = constraint_db; + constraint_db = vec![]; + + for constraint in local_constraints.drain(..) { + match constraint { + // Currently, all of our types are printable + Constraint::Printable(_loc, _ty) => {} + + Constraint::FitsInNumType(loc, ctype, val) => unimplemented!(), + + Constraint::ProperPrimitiveArgs(loc, prim, args, ret) => unimplemented!(), + + Constraint::CanCastTo(loc, from_type, to_type) => unimplemented!(), + } + } + } + + program + } +} + +/// Generate a fresh new name based on the given name. +/// +/// The new name is guaranteed to be unique across the entirety of the +/// execution. This is achieved by using characters in the variable name +/// that would not be valid input, and by including a counter that is +/// incremented on every invocation. +fn gensym(name: &str) -> ArcIntern { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let new_name = format!( + "<{}:{}>", + name, + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + ); + ArcIntern::new(new_name) +} + +/// Generate a fresh new type; this will be a unique new type variable. +/// +/// The new name is guaranteed to be unique across the entirety of the +/// execution. This is achieved by using characters in the variable name +/// that would not be valid input, and by including a counter that is +/// incremented on every invocation. +fn gentype() -> Type { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let new_name = ArcIntern::new(format!( + "t<{}>", + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + )); + + Type::Variable(Location::manufactured(), new_name) +} + +proptest::proptest! { + #[test] + fn translation_maintains_semantics(input: syntax::Program) { + let syntax_result = input.eval(); + let ir = input.type_infer(); + let ir_result = ir.eval(); + assert_eq!(syntax_result, ir_result); + } +} diff --git a/src/repl.rs b/src/repl.rs index d199fbb..4e32694 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -1,5 +1,4 @@ use crate::backend::{Backend, BackendError}; -use crate::ir::Program as IR; use crate::syntax::{ConstantType, Location, ParserError, Statement}; use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::files::SimpleFiles; @@ -155,7 +154,10 @@ impl REPL { return Ok(()); } - let ir = IR::from(syntax); + let ir = crate::syntax::Program { + statements: vec![syntax], + } + .type_infer(); let name = format!("line{}", line_no); let function_id = self.jitter.compile_function(&name, ir)?; self.jitter.module.finalize_definitions()?; diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index 9223e97..cdfca71 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -1,9 +1,10 @@ -use crate::{syntax::{Expression, Location, Program, Statement}, eval::PrimitiveType}; +use crate::{ + eval::PrimitiveType, + syntax::{Expression, Location, Program, Statement}, +}; use codespan_reporting::diagnostic::Diagnostic; use std::{collections::HashMap, str::FromStr}; -use super::location; - /// An error we found while validating the input program. /// /// These errors indicate that we should stop trying to compile @@ -136,7 +137,7 @@ impl Expression { ), Expression::Cast(location, t, expr) => { let (mut errs, warns) = expr.validate(variable_map); - + if PrimitiveType::from_str(t).is_err() { errs.push(Error::UnknownType(location.clone(), t.clone())) } @@ -173,4 +174,4 @@ fn cast_checks_are_reasonable() { assert!(bad_warns.is_empty()); assert_eq!(bad_errs.len(), 1); assert!(matches!(bad_errs[0], Error::UnknownType(_, ref x) if x == "apple")); -} \ No newline at end of file +}