[checkpoint] Start the switch to type inference.

This commit is contained in:
2023-06-19 21:16:28 -07:00
parent b931ba5b17
commit 3687785540
10 changed files with 424 additions and 284 deletions

View File

@@ -125,7 +125,7 @@ impl<M: Module> Backend<M> {
// Print statements are fairly easy to compile: we just lookup the // Print statements are fairly easy to compile: we just lookup the
// output buffer, the address of the string to print, and the value // output buffer, the address of the string to print, and the value
// of whatever variable we're printing. Then we just call print. // of whatever variable we're printing. Then we just call print.
Statement::Print(ann, var) => { Statement::Print(ann, t, var) => {
// Get the output buffer (or null) from our general compilation context. // Get the output buffer (or null) from our general compilation context.
let buffer_ptr = self.output_buffer_ptr(); let buffer_ptr = self.output_buffer_ptr();
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
@@ -137,7 +137,7 @@ impl<M: Module> Backend<M> {
// Look up the value for the variable. Because this might be a // Look up the value for the variable. Because this might be a
// global variable (and that requires special logic), we just turn // global variable (and that requires special logic), we just turn
// this into an `Expression` and re-use the logic in that implementation. // this into an `Expression` and re-use the logic in that implementation.
let (val, vtype) = ValueOrRef::Ref(ann, var).into_crane( let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane(
&mut builder, &mut builder,
&variable_table, &variable_table,
&pre_defined_symbols, &pre_defined_symbols,
@@ -163,7 +163,7 @@ impl<M: Module> Backend<M> {
} }
// Variable binding is a little more con // Variable binding is a little more con
Statement::Binding(_, var_name, value) => { Statement::Binding(_, var_name, _, value) => {
// Kick off to the `Expression` implementation to see what value we're going // Kick off to the `Expression` implementation to see what value we're going
// to bind to this variable. // to bind to this variable.
let (val, etype) = let (val, etype) =
@@ -254,50 +254,62 @@ impl Expression {
Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables), Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables),
Expression::Cast(_, target_type, expr) => { Expression::Cast(_, target_type, expr) => {
let (val, val_type) = expr.into_crane(builder, local_variables, global_variables)?; let (val, val_type) =
expr.into_crane(builder, local_variables, global_variables)?;
match (val_type, &target_type) { match (val_type, &target_type) {
(ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)), (ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)),
(ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => (ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => {
Ok((builder.ins().sextend(types::I16, val), ConstantType::I16)), Ok((builder.ins().sextend(types::I16, val), ConstantType::I16))
(ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => }
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)), (ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => {
(ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)), }
(ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)), (ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)),
(ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => (ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)), Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
(ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => }
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)), (ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)), (ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)),
(ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => (ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)), Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)), (ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)), (ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => (ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => {
Ok((builder.ins().uextend(types::I16, val), ConstantType::U16)), Ok((builder.ins().uextend(types::I16, val), ConstantType::U16))
(ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => }
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)), (ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => {
(ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)), }
(ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)), (ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)),
(ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => (ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)), Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
(ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => }
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)), (ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)), (ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)),
(ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => (ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)), Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)), (ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)),
_ => Err(BackendError::InvalidTypeCast { _ => Err(BackendError::InvalidTypeCast {
from: val_type.into(), from: val_type.into(),
to: target_type, to: target_type,
@@ -305,7 +317,7 @@ impl Expression {
} }
} }
Expression::Primitive(_, prim, mut vals) => { Expression::Primitive(_, _, prim, mut vals) => {
let mut values = vec![]; let mut values = vec![];
let mut first_type = None; let mut first_type = None;
@@ -357,7 +369,7 @@ impl ValueOrRef {
match self { match self {
// Values are pretty straightforward to compile, mostly because we only // Values are pretty straightforward to compile, mostly because we only
// have one type of variable, and it's an integer type. // have one type of variable, and it's an integer type.
ValueOrRef::Value(_, val) => match val { ValueOrRef::Value(_, _, val) => match val {
Value::I8(_, v) => { Value::I8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8)) Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8))
} }
@@ -387,7 +399,7 @@ impl ValueOrRef {
)), )),
}, },
ValueOrRef::Ref(_, name) => { ValueOrRef::Ref(_, _, name) => {
// first we see if this is a local variable (which is nicer, from an // first we see if this is a local variable (which is nicer, from an
// optimization point of view.) // optimization point of view.)
if let Some((local_var, etype)) = local_variables.get(&name) { if let Some((local_var, etype)) = local_variables.get(&name) {

View File

@@ -1,5 +1,4 @@
use crate::backend::Backend; use crate::backend::Backend;
use crate::ir::Program as IR;
use crate::syntax::Program as Syntax; use crate::syntax::Program as Syntax;
use codespan_reporting::{ use codespan_reporting::{
diagnostic::Diagnostic, diagnostic::Diagnostic,
@@ -101,7 +100,7 @@ impl Compiler {
} }
// Now that we've validated it, turn it into IR. // Now that we've validated it, turn it into IR.
let ir = IR::from(syntax); let ir = syntax.type_infer();
// Finally, send all this to Cranelift for conversion into an object file. // Finally, send all this to Cranelift for conversion into an object file.
let mut backend = Backend::object_file(Triple::host())?; let mut backend = Backend::object_file(Triple::host())?;

View File

@@ -14,7 +14,7 @@
//! come to for analysis and optimization work. //! come to for analysis and optimization work.
mod ast; mod ast;
mod eval; mod eval;
mod from_syntax;
mod strings; mod strings;
mod type_infer;
pub use ast::*; pub use ast::*;

View File

@@ -8,7 +8,7 @@ use proptest::{
prelude::Arbitrary, prelude::Arbitrary,
strategy::{BoxedStrategy, Strategy}, strategy::{BoxedStrategy, Strategy},
}; };
use std::fmt; use std::{fmt, str::FromStr};
/// We're going to represent variables as interned strings. /// We're going to represent variables as interned strings.
/// ///
@@ -61,7 +61,7 @@ impl Arbitrary for Program {
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
crate::syntax::Program::arbitrary_with(args) crate::syntax::Program::arbitrary_with(args)
.prop_map(Program::from) .prop_map(syntax::Program::type_infer)
.boxed() .boxed()
} }
} }
@@ -78,8 +78,8 @@ impl Arbitrary for Program {
/// ///
#[derive(Debug)] #[derive(Debug)]
pub enum Statement { pub enum Statement {
Binding(Location, Variable, Expression), Binding(Location, Variable, Type, Expression),
Print(Location, Variable), Print(Location, Type, Variable),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
@@ -89,13 +89,13 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
Statement::Binding(_, var, expr) => allocator Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string()) .text(var.as_ref().to_string())
.append(allocator.space()) .append(allocator.space())
.append(allocator.text("=")) .append(allocator.text("="))
.append(allocator.space()) .append(allocator.space())
.append(expr.pretty(allocator)), .append(expr.pretty(allocator)),
Statement::Print(_, var) => allocator Statement::Print(_, _, var) => allocator
.text("print") .text("print")
.append(allocator.space()) .append(allocator.space())
.append(allocator.text(var.as_ref().to_string())), .append(allocator.text(var.as_ref().to_string())),
@@ -119,7 +119,30 @@ where
pub enum Expression { pub enum Expression {
Atomic(ValueOrRef), Atomic(ValueOrRef),
Cast(Location, Type, ValueOrRef), Cast(Location, Type, ValueOrRef),
Primitive(Location, Primitive, Vec<ValueOrRef>), Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
@@ -135,10 +158,10 @@ where
.append(t.pretty(allocator)) .append(t.pretty(allocator))
.append(allocator.text(">")) .append(allocator.text(">"))
.append(e.pretty(allocator)), .append(e.pretty(allocator)),
Expression::Primitive(_, op, exprs) if exprs.len() == 1 => { Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator)) op.pretty(allocator).append(exprs[0].pretty(allocator))
} }
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => { Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator); let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator); let right = exprs[1].pretty(allocator);
@@ -148,7 +171,7 @@ where
.append(right) .append(right)
.parens() .parens()
} }
Expression::Primitive(_, op, exprs) => { Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
} }
} }
@@ -169,10 +192,10 @@ pub enum Primitive {
Divide, Divide,
} }
impl<'a> TryFrom<&'a str> for Primitive { impl FromStr for Primitive {
type Error = String; type Err = String;
fn try_from(value: &str) -> Result<Self, Self::Error> { fn from_str(value: &str) -> Result<Self, Self::Err> {
match value { match value {
"+" => Ok(Primitive::Plus), "+" => Ok(Primitive::Plus),
"-" => Ok(Primitive::Minus), "-" => Ok(Primitive::Minus),
@@ -203,10 +226,10 @@ where
/// This is the type used to guarantee that we don't nest expressions /// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one /// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference. /// of these, which can only be a constant or a reference.
#[derive(Debug)] #[derive(Clone, Debug)]
pub enum ValueOrRef { pub enum ValueOrRef {
Value(Location, Value), Value(Location, Type, Value),
Ref(Location, ArcIntern<String>), Ref(Location, Type, ArcIntern<String>),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
@@ -216,8 +239,8 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
ValueOrRef::Value(_, v) => v.pretty(allocator), ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, v) => allocator.text(v.as_ref().to_string()), ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
} }
} }
} }
@@ -246,6 +269,22 @@ pub enum Value {
U64(Option<u8>, u64), U64(Option<u8>, u64),
} }
impl Value {
/// Return the type described by this value
pub fn type_of(&self) -> Type {
match self {
Value::I8(_, _) => Type::Primitive(PrimitiveType::I8),
Value::I16(_, _) => Type::Primitive(PrimitiveType::I16),
Value::I32(_, _) => Type::Primitive(PrimitiveType::I32),
Value::I64(_, _) => Type::Primitive(PrimitiveType::I64),
Value::U8(_, _) => Type::Primitive(PrimitiveType::U8),
Value::U16(_, _) => Type::Primitive(PrimitiveType::U16),
Value::U32(_, _) => Type::Primitive(PrimitiveType::U32),
Value::U64(_, _) => Type::Primitive(PrimitiveType::U64),
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where where
A: 'a, A: 'a,
@@ -289,8 +328,9 @@ where
} }
} }
#[derive(Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type { pub enum Type {
Variable(Location, ArcIntern<String>),
Primitive(PrimitiveType), Primitive(PrimitiveType),
} }
@@ -301,6 +341,7 @@ where
{ {
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
Type::Variable(_, x) => allocator.text(x.to_string()),
Type::Primitive(pt) => allocator.text(format!("{}", pt)), Type::Primitive(pt) => allocator.text(format!("{}", pt)),
} }
} }
@@ -309,6 +350,7 @@ where
impl fmt::Display for Type { impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Type::Variable(_, x) => write!(f, "{}", x),
Type::Primitive(pt) => pt.fmt(f), Type::Primitive(pt) => pt.fmt(f),
} }
} }

View File

@@ -14,12 +14,12 @@ impl Program {
for stmt in self.statements.iter() { for stmt in self.statements.iter() {
match stmt { match stmt {
Statement::Binding(_, name, value) => { Statement::Binding(_, name, _, value) => {
let actual_value = value.eval(&env)?; let actual_value = value.eval(&env)?;
env = env.extend(name.clone(), actual_value); env = env.extend(name.clone(), actual_value);
} }
Statement::Print(_, name) => { Statement::Print(_, _, name) => {
let value = env.lookup(name.clone())?; let value = env.lookup(name.clone())?;
let line = format!("{} = {}\n", name, value); let line = format!("{} = {}\n", name, value);
stdout.push_str(&line); stdout.push_str(&line);
@@ -40,12 +40,16 @@ impl Expression {
let value = valref.eval(env)?; let value = valref.eval(env)?;
match t { match t {
Type::Variable(_, _) => unimplemented!("how to cast to a type variable?"),
Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), Type::Primitive(pt) => Ok(pt.safe_cast(&value)?),
} }
} }
Expression::Primitive(_, op, args) => { Expression::Primitive(_, _, op, args) => {
let arg_values = args.iter().map(|x| x.eval(env)).collect::<Result<Vec<Value>, EvalError>>()?; let arg_values = args
.iter()
.map(|x| x.eval(env))
.collect::<Result<Vec<Value>, EvalError>>()?;
// and then finally we call `calculate` to run them. trust me, it's nice // and then finally we call `calculate` to run them. trust me, it's nice
// to not have to deal with all the nonsense hidden under `calculate`. // to not have to deal with all the nonsense hidden under `calculate`.
@@ -63,7 +67,7 @@ impl Expression {
impl ValueOrRef { impl ValueOrRef {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> { fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self { match self {
ValueOrRef::Value(_, v) => match v { ValueOrRef::Value(_, _, v) => match v {
super::Value::I8(_, v) => Ok(Value::I8(*v)), super::Value::I8(_, v) => Ok(Value::I8(*v)),
super::Value::I16(_, v) => Ok(Value::I16(*v)), super::Value::I16(_, v) => Ok(Value::I16(*v)),
super::Value::I32(_, v) => Ok(Value::I32(*v)), super::Value::I32(_, v) => Ok(Value::I32(*v)),
@@ -74,7 +78,7 @@ impl ValueOrRef {
super::Value::U64(_, v) => Ok(Value::U64(*v)), super::Value::U64(_, v) => Ok(Value::U64(*v)),
}, },
ValueOrRef::Ref(_, n) => Ok(env.lookup(n.clone())?), ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?),
} }
} }
} }
@@ -82,7 +86,7 @@ impl ValueOrRef {
#[test] #[test]
fn two_plus_three() { fn two_plus_three() {
let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let ir = Program::from(input); let ir = input.type_infer();
let output = ir.eval().expect("runs successfully"); let output = ir.eval().expect("runs successfully");
assert_eq!("x = 5u64\n", &output); assert_eq!("x = 5u64\n", &output);
} }
@@ -91,7 +95,7 @@ fn two_plus_three() {
fn lotsa_math() { fn lotsa_math() {
let input = let input =
crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let ir = Program::from(input); let ir = input.type_infer();
let output = ir.eval().expect("runs successfully"); let output = ir.eval().expect("runs successfully");
assert_eq!("x = 7u64\n", &output); assert_eq!("x = 7u64\n", &output);
} }

View File

@@ -1,213 +0,0 @@
use internment::ArcIntern;
use std::str::FromStr;
use std::sync::atomic::AtomicUsize;
use crate::eval::PrimitiveType;
use crate::ir::ast as ir;
use crate::syntax;
use super::ValueOrRef;
impl From<syntax::Program> for ir::Program {
/// We implement the top-level conversion of a syntax::Program into an
/// ir::Program using just the standard `From::from`, because we don't
/// need to return any arguments and we shouldn't produce any errors.
/// Technically there's an `unwrap` deep under the hood that we could
/// float out, but the validator really should've made sure that never
/// happens, so we're just going to assume.
fn from(mut value: syntax::Program) -> Self {
let mut statements = Vec::new();
for stmt in value.statements.drain(..) {
statements.append(&mut stmt.simplify());
}
ir::Program { statements }
}
}
impl From<syntax::Statement> for ir::Program {
/// One interesting thing about this conversion is that there isn't
/// a natural translation from syntax::Statement to ir::Statement,
/// because the syntax version can have nested expressions and the
/// IR version can't.
///
/// As a result, we can naturally convert a syntax::Statement into
/// an ir::Program, because we can allow the additional binding
/// sites to be generated, instead. And, bonus, it turns out that
/// this is what we wanted anyways.
fn from(value: syntax::Statement) -> Self {
ir::Program {
statements: value.simplify(),
}
}
}
impl syntax::Statement {
/// Simplify a syntax::Statement into a series of ir::Statements.
///
/// The reason this function is one-to-many is because we may have to
/// introduce new binding sites in order to avoid having nested
/// expressions. Nested expressions, like `(1 + 2) * 3`, are allowed
/// in syntax::Expression but are expressly *not* allowed in
/// ir::Expression. So this pass converts them into bindings, like
/// this:
///
/// x = (1 + 2) * 3;
///
/// ==>
///
/// x:1 = 1 + 2;
/// x:2 = x:1 * 3;
/// x = x:2
///
/// Thus ensuring that things are nice and simple. Note that the
/// binding of `x:2` is not, strictly speaking, necessary, but it
/// makes the code below much easier to read.
fn simplify(self) -> Vec<ir::Statement> {
let mut new_statements = vec![];
match self {
// Print statements we don't have to do much with
syntax::Statement::Print(loc, name) => {
new_statements.push(ir::Statement::Print(loc, ArcIntern::new(name)))
}
// Bindings, however, may involve a single expression turning into
// a series of statements and then an expression.
syntax::Statement::Binding(loc, name, value) => {
let (mut prereqs, new_value) = value.rebind(&name);
new_statements.append(&mut prereqs);
new_statements.push(ir::Statement::Binding(
loc,
ArcIntern::new(name),
new_value.into(),
))
}
}
new_statements
}
}
impl syntax::Expression {
/// This actually does the meat of the simplification work, here, by rebinding
/// any nested expressions into their own variables. We have this return
/// `ValueOrRef` in all cases because it makes for slighly less code; in the
/// case when we actually want an `Expression`, we can just use `into()`.
fn rebind(self, base_name: &str) -> (Vec<ir::Statement>, ir::ValueOrRef) {
match self {
// Values just convert in the obvious way, and require no prereqs
syntax::Expression::Value(loc, val) => (vec![], ValueOrRef::Value(loc, val.into())),
// Similarly, references just convert in the obvious way, and require
// no prereqs
syntax::Expression::Reference(loc, name) => {
(vec![], ValueOrRef::Ref(loc, ArcIntern::new(name)))
}
syntax::Expression::Cast(loc, t, expr) => {
let (mut prereqs, new_expr) = expr.rebind(base_name);
let new_name = gensym(base_name);
prereqs.push(ir::Statement::Binding(
loc.clone(),
new_name.clone(),
ir::Expression::Cast(
loc.clone(),
ir::Type::Primitive(PrimitiveType::from_str(&t).unwrap()),
new_expr,
),
));
(prereqs, ValueOrRef::Ref(loc, new_name))
}
// Primitive expressions are where we do the real work.
syntax::Expression::Primitive(loc, prim, mut expressions) => {
// generate a fresh new name for the binding site we're going to
// introduce, basing the name on wherever we came from; so if this
// expression was bound to `x` originally, it might become `x:23`.
//
// gensym is guaranteed to give us a name that is unused anywhere
// else in the program.
let new_name = gensym(base_name);
let mut prereqs = Vec::new();
let mut new_exprs = Vec::new();
// here we loop through every argument, and recurse on the expressions
// we find. that will give us any new binding sites that *they* introduce,
// and a simple value or reference that we can use in our result.
for expr in expressions.drain(..) {
let (mut cur_prereqs, arg) = expr.rebind(new_name.as_str());
prereqs.append(&mut cur_prereqs);
new_exprs.push(arg);
}
// now we're going to use those new arguments to run the primitive, binding
// the results to the new variable we introduced.
let prim =
ir::Primitive::try_from(prim.as_str()).expect("is valid primitive function");
prereqs.push(ir::Statement::Binding(
loc.clone(),
new_name.clone(),
ir::Expression::Primitive(loc.clone(), prim, new_exprs),
));
// and finally, we can return all the new bindings, and a reference to
// the variable we just introduced to hold the value of the primitive
// invocation.
(prereqs, ValueOrRef::Ref(loc, new_name))
}
}
}
}
impl From<syntax::Value> for ir::Value {
fn from(value: syntax::Value) -> Self {
match value {
syntax::Value::Number(base, ty, val) => match ty {
None => ir::Value::U64(base, val),
Some(syntax::ConstantType::I8) => ir::Value::I8(base, val as i8),
Some(syntax::ConstantType::I16) => ir::Value::I16(base, val as i16),
Some(syntax::ConstantType::I32) => ir::Value::I32(base, val as i32),
Some(syntax::ConstantType::I64) => ir::Value::I64(base, val as i64),
Some(syntax::ConstantType::U8) => ir::Value::U8(base, val as u8),
Some(syntax::ConstantType::U16) => ir::Value::U16(base, val as u16),
Some(syntax::ConstantType::U32) => ir::Value::U32(base, val as u32),
Some(syntax::ConstantType::U64) => ir::Value::U64(base, val),
},
}
}
}
impl From<String> for ir::Primitive {
fn from(value: String) -> Self {
value.try_into().unwrap()
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input: syntax::Program) {
let syntax_result = input.eval();
let ir = ir::Program::from(input);
let ir_result = ir.eval();
assert_eq!(syntax_result, ir_result);
}
}

View File

@@ -21,12 +21,12 @@ impl Program {
impl Statement { impl Statement {
fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) { fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) {
match self { match self {
Statement::Binding(_, name, expr) => { Statement::Binding(_, name, _, expr) => {
string_set.insert(name.clone()); string_set.insert(name.clone());
expr.register_strings(string_set); expr.register_strings(string_set);
} }
Statement::Print(_, name) => { Statement::Print(_, _, name) => {
string_set.insert(name.clone()); string_set.insert(name.clone());
} }
} }

293
src/ir/type_infer.rs Normal file
View File

@@ -0,0 +1,293 @@
use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::atomic::AtomicUsize;
use crate::eval::PrimitiveType;
use crate::ir::ast as ir;
use crate::ir::ast::Type;
use crate::syntax::{self, ConstantType, Location};
enum Constraint {
/// The given type must be printable using the `print` built-in
Printable(Location, Type),
/// The provided numeric value fits in the given constant type
FitsInNumType(Location, ConstantType, u64),
/// The given primitive has the proper arguments types associated with it
ProperPrimitiveArgs(Location, ir::Primitive, Vec<Type>, Type),
/// The given type can be casted to the target type safely
CanCastTo(Location, Type, Type),
}
/// This function takes a syntactic program and converts it into the IR version of the
/// program, with appropriate type variables introduced and their constraints added to
/// the given database.
///
/// If the input function has been validated (which it should be), then this should run
/// into no error conditions. However, if you failed to validate the input, then this
/// function can panic.
fn convert_program(
mut program: syntax::Program,
constraint_db: &mut Vec<Constraint>,
) -> ir::Program {
let mut statements = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
for stmt in program.statements.drain(..) {
statements.append(&mut convert_statement(
stmt,
constraint_db,
&mut renames,
&mut bindings,
));
}
ir::Program { statements }
}
/// This function takes a syntactic statements and converts it into a series of
/// IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_statement(
statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>,
renames: &mut HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> Vec<ir::Statement> {
match statement {
syntax::Statement::Print(loc, name) => {
let iname = ArcIntern::new(name);
let final_name = renames.get(&iname).map(Clone::clone).unwrap_or_else(|| iname.clone());
let varty = bindings
.get(&final_name)
.expect("print variable defined before use")
.clone();
constraint_db.push(Constraint::Printable(loc.clone(), varty.clone()));
vec![ir::Statement::Print(loc, varty, iname)]
}
syntax::Statement::Binding(loc, name, expr) => {
let (mut prereqs, expr, ty) =
convert_expression(expr, constraint_db, renames, bindings);
let iname = ArcIntern::new(name);
let final_name = if bindings.contains_key(&iname) {
let new_name = gensym(iname.as_str());
renames.insert(iname, new_name.clone());
new_name
} else {
iname
};
bindings.insert(final_name.clone(), ty.clone());
prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr));
prereqs
}
}
}
/// This function takes a syntactic expression and converts it into a series
/// of IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_expression(
expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>,
renames: &HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> (Vec<ir::Statement>, ir::Expression, Type) {
match expression {
syntax::Expression::Value(loc, val) => {
let newval = match val {
syntax::Value::Number(base, mctype, value) => {
if let Some(suggested_type) = mctype {
constraint_db.push(Constraint::FitsInNumType(loc.clone(), suggested_type, value));
}
match mctype {
None => ir::Value::U64(base, value),
Some(ConstantType::U8) => ir::Value::U8(base, value as u8),
Some(ConstantType::U16) => ir::Value::U16(base, value as u16),
Some(ConstantType::U32) => ir::Value::U32(base, value as u32),
Some(ConstantType::U64) => ir::Value::U64(base, value),
Some(ConstantType::I8) => ir::Value::I8(base, value as i8),
Some(ConstantType::I16) => ir::Value::I16(base, value as i16),
Some(ConstantType::I32) => ir::Value::I32(base, value as i32),
Some(ConstantType::I64) => ir::Value::I64(base, value as i64),
}
}
};
let valtype = newval.type_of();
(
vec![],
ir::Expression::Atomic(ir::ValueOrRef::Value(loc, valtype.clone(), newval)),
valtype,
)
}
syntax::Expression::Reference(loc, name) => {
let iname = ArcIntern::new(name);
let final_name = renames.get(&iname).cloned().unwrap_or(iname);
let rtype = bindings
.get(&final_name)
.cloned()
.expect("variable bound before use");
let refexp =
ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name));
(vec![], refexp, rtype)
}
syntax::Expression::Cast(loc, target, expr) => {
let (mut stmts, nexpr, etype) =
convert_expression(*expr, constraint_db, renames, bindings);
let val_or_ref = simplify_expr(nexpr, &mut stmts);
let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast");
let target_type = Type::Primitive(target_prim_type);
let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref);
constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone()));
(stmts, res, target_type)
}
syntax::Expression::Primitive(loc, fun, mut args) => {
let primop = ir::Primitive::from_str(&fun).expect("valid primitive");
let mut stmts = vec![];
let mut nargs = vec![];
let mut atypes = vec![];
let ret_type = gentype();
for arg in args.drain(..) {
let (mut astmts, aexp, atype) = convert_expression(arg, constraint_db, renames, bindings);
stmts.append(&mut astmts);
nargs.push(simplify_expr(aexp, &mut stmts));
atypes.push(atype);
}
constraint_db.push(Constraint::ProperPrimitiveArgs(loc.clone(), primop, atypes.clone(), ret_type.clone()));
(stmts, ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs), ret_type)
}
}
}
fn simplify_expr(expr: ir::Expression, stmts: &mut Vec<ir::Statement>) -> ir::ValueOrRef {
match expr {
ir::Expression::Atomic(v_or_ref) => v_or_ref,
expr => {
let etype = expr.type_of().clone();
let loc = expr.location().clone();
let nname = gensym("g");
let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr);
stmts.push(nbinding);
ir::ValueOrRef::Ref(loc, etype, nname)
}
}
}
impl syntax::Program {
/// Infer the types for the syntactic AST, returning either a type-checked program in
/// the IR, or a series of type errors encountered during inference.
///
/// You really should have made sure that this program was validated before running
/// this method, otherwise you may experience panics during operation.
pub fn type_infer(self) -> ir::Program {
let mut constraint_db = vec![];
let program = convert_program(self, &mut constraint_db);
let mut changed_something = true;
// We want to run this inference endlessly, until either we have solved all of our
// constraints or we've gotten stuck somewhere.
while constraint_db.len() > 0 && changed_something {
// Set this to false at the top of the loop. We'll set this to true if we make
// progress in any way further down, but having this here prevents us from going
// into an infinite look when we can't figure stuff out.
changed_something = false;
// This is sort of a double-buffering thing; we're going to rename constraint_db
// and then set it to a new empty vector, which we'll add to as we find new
// constraints or find ourselves unable to solve existing ones.
let mut local_constraints = constraint_db;
constraint_db = vec![];
for constraint in local_constraints.drain(..) {
match constraint {
// Currently, all of our types are printable
Constraint::Printable(_loc, _ty) => {}
Constraint::FitsInNumType(loc, ctype, val) => unimplemented!(),
Constraint::ProperPrimitiveArgs(loc, prim, args, ret) => unimplemented!(),
Constraint::CanCastTo(loc, from_type, to_type) => unimplemented!(),
}
}
}
program
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
/// Generate a fresh new type; this will be a unique new type variable.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
fn gentype() -> Type {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = ArcIntern::new(format!(
"t<{}>",
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
));
Type::Variable(Location::manufactured(), new_name)
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input: syntax::Program) {
let syntax_result = input.eval();
let ir = input.type_infer();
let ir_result = ir.eval();
assert_eq!(syntax_result, ir_result);
}
}

View File

@@ -1,5 +1,4 @@
use crate::backend::{Backend, BackendError}; use crate::backend::{Backend, BackendError};
use crate::ir::Program as IR;
use crate::syntax::{ConstantType, Location, ParserError, Statement}; use crate::syntax::{ConstantType, Location, ParserError, Statement};
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use codespan_reporting::files::SimpleFiles; use codespan_reporting::files::SimpleFiles;
@@ -155,7 +154,10 @@ impl REPL {
return Ok(()); return Ok(());
} }
let ir = IR::from(syntax); let ir = crate::syntax::Program {
statements: vec![syntax],
}
.type_infer();
let name = format!("line{}", line_no); let name = format!("line{}", line_no);
let function_id = self.jitter.compile_function(&name, ir)?; let function_id = self.jitter.compile_function(&name, ir)?;
self.jitter.module.finalize_definitions()?; self.jitter.module.finalize_definitions()?;

View File

@@ -1,9 +1,10 @@
use crate::{syntax::{Expression, Location, Program, Statement}, eval::PrimitiveType}; use crate::{
eval::PrimitiveType,
syntax::{Expression, Location, Program, Statement},
};
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use std::{collections::HashMap, str::FromStr}; use std::{collections::HashMap, str::FromStr};
use super::location;
/// An error we found while validating the input program. /// An error we found while validating the input program.
/// ///
/// These errors indicate that we should stop trying to compile /// These errors indicate that we should stop trying to compile