🤔 Add a type inference engine, along with typed literals. (#4)

The typed literal formatting mirrors that of Rust. If no type can be
inferred for an untagged literal, the type inference engine will warn
the user and then assume that they meant an unsigned 64-bit number.
(This is slightly inconvenient, because there can be cases in which our
Arbitrary instance generates a unary negation, in which case we should
assume that it's a signed 64-bit number instead; we may want to revisit
this later.)

The type inference engine is a standard two-phase one: we first
generate a series of type constraints, and then we solve those
constraints. In this particular implementation, we actually use a third
phase to generate the final AST.

Finally, to increase the amount of testing performed, I've removed the
overflow checking in the evaluator. The only thing we now check for is
division by zero. This does make things a trifle slower in testing, but
hopefully we get more coverage this way.
This commit was merged in pull request #4.
This commit is contained in:
2023-09-19 20:40:05 -07:00
committed by GitHub
parent 1fbfd0c2d2
commit bd3b9af469
44 changed files with 3258 additions and 702 deletions

View File

@@ -1,4 +1,4 @@
use crate::backend::runtime::RuntimeFunctionError;
use crate::{backend::runtime::RuntimeFunctionError, eval::PrimitiveType, ir::Type};
use codespan_reporting::diagnostic::Diagnostic;
use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError};
use cranelift_module::ModuleError;
@@ -39,6 +39,8 @@ pub enum BackendError {
LookupError(#[from] LookupError),
#[error(transparent)]
Write(#[from] cranelift_object::object::write::Error),
#[error("Invalid type cast from {from} to {to}")]
InvalidTypeCast { from: PrimitiveType, to: Type },
}
impl From<BackendError> for Diagnostic<usize> {
@@ -64,6 +66,9 @@ impl From<BackendError> for Diagnostic<usize> {
BackendError::Write(me) => {
Diagnostic::error().with_message(format!("Cranelift object write error: {}", me))
}
BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message(
format!("Internal error trying to cast from {} to {}", from, to),
),
}
}
}
@@ -103,6 +108,17 @@ impl PartialEq for BackendError {
BackendError::Write(b) => a == b,
_ => false,
},
BackendError::InvalidTypeCast {
from: from1,
to: to1,
} => match other {
BackendError::InvalidTypeCast {
from: from2,
to: to2,
} => from1 == from2 && to1 == to2,
_ => false,
},
}
}
}

View File

@@ -1,10 +1,13 @@
use std::path::Path;
use crate::backend::Backend;
use crate::eval::EvalError;
use crate::ir::Program;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
use cranelift_jit::JITModule;
use cranelift_object::ObjectModule;
#[cfg(test)]
use proptest::arbitrary::Arbitrary;
use std::path::Path;
use target_lexicon::Triple;
impl Backend<JITModule> {
@@ -28,7 +31,8 @@ impl Backend<JITModule> {
let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function();
Ok(jitter.output())
let output = jitter.output();
Ok(output)
}
}
@@ -116,7 +120,7 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow.
#[test]
fn static_backend(program: Program) {
fn static_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError;
let basic_result = program.eval();
@@ -127,8 +131,18 @@ proptest::proptest! {
let basic_result = basic_result.map(|x| x.replace('\n', "\r\n"));
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let compiled_result = Backend::<ObjectModule>::eval(program);
assert_eq!(basic_result, compiled_result);
proptest::prop_assert_eq!(basic_result, compiled_result);
}
}
@@ -136,14 +150,24 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow.
#[test]
fn jit_backend(program: Program) {
fn jit_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError;
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let basic_result = program.eval();
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
let compiled_result = Backend::<JITModule>::eval(program);
assert_eq!(basic_result, compiled_result);
proptest::prop_assert_eq!(basic_result, compiled_result);
}
}
}

View File

@@ -1,9 +1,11 @@
use std::collections::HashMap;
use crate::ir::{Expression, Primitive, Program, Statement, Value, ValueOrRef};
use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, Statement, Type, Value, ValueOrRef};
use crate::syntax::ConstantType;
use cranelift_codegen::entity::EntityRef;
use cranelift_codegen::ir::{
entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName,
self, entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName,
};
use cranelift_codegen::isa::CallConv;
use cranelift_codegen::Context;
@@ -41,7 +43,7 @@ impl<M: Module> Backend<M> {
let basic_signature = Signature {
params: vec![],
returns: vec![],
call_conv: CallConv::SystemV,
call_conv: CallConv::triple_default(&self.platform),
};
// this generates the handle for the function that we'll eventually want to
@@ -85,12 +87,12 @@ impl<M: Module> Backend<M> {
// Just like with strings, generating the `GlobalValue`s we need can potentially
// be a little tricky to do on the fly, so we generate the complete list right
// here and then use it later.
let pre_defined_symbols: HashMap<String, GlobalValue> = self
let pre_defined_symbols: HashMap<String, (GlobalValue, ConstantType)> = self
.defined_symbols
.iter()
.map(|(k, v)| {
.map(|(k, (v, t))| {
let local_data = self.module.declare_data_in_func(*v, &mut ctx.func);
(k.clone(), local_data)
(k.clone(), (local_data, *t))
})
.collect();
@@ -123,7 +125,7 @@ impl<M: Module> Backend<M> {
// Print statements are fairly easy to compile: we just lookup the
// output buffer, the address of the string to print, and the value
// of whatever variable we're printing. Then we just call print.
Statement::Print(ann, var) => {
Statement::Print(ann, t, var) => {
// Get the output buffer (or null) from our general compilation context.
let buffer_ptr = self.output_buffer_ptr();
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
@@ -135,31 +137,47 @@ impl<M: Module> Backend<M> {
// Look up the value for the variable. Because this might be a
// global variable (and that requires special logic), we just turn
// this into an `Expression` and re-use the logic in that implementation.
let val = Expression::Reference(ann, var).into_crane(
let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane(
&mut builder,
&variable_table,
&pre_defined_symbols,
)?;
let vtype_repr = builder.ins().iconst(types::I64, vtype as i64);
let casted_val = match vtype {
ConstantType::U64 | ConstantType::I64 => val,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => {
builder.ins().sextend(types::I64, val)
}
ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => {
builder.ins().uextend(types::I64, val)
}
};
// Finally, we can generate the call to print.
builder
.ins()
.call(print_func_ref, &[buffer_ptr, name_ptr, val]);
builder.ins().call(
print_func_ref,
&[buffer_ptr, name_ptr, vtype_repr, casted_val],
);
}
// Variable binding is a little more con
Statement::Binding(_, var_name, value) => {
Statement::Binding(_, var_name, _, value) => {
// Kick off to the `Expression` implementation to see what value we're going
// to bind to this variable.
let val =
let (val, etype) =
value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?;
// Now the question is: is this a local variable, or a global one?
if let Some(global_id) = pre_defined_symbols.get(var_name.as_str()) {
if let Some((global_id, ctype)) = pre_defined_symbols.get(var_name.as_str()) {
// It's a global variable! In this case, we assume that someone has already
// dedicated some space in memory to store this value. We look this location
// up, and then tell Cranelift to store the value there.
let val_ptr = builder.ins().symbol_value(types::I64, *global_id);
assert_eq!(etype, *ctype);
let val_ptr = builder
.ins()
.symbol_value(ir::Type::from(*ctype), *global_id);
builder.ins().store(MemFlags::new(), val, val_ptr, 0);
} else {
// It's a local variable! In this case, we need to allocate a new Cranelift
@@ -171,12 +189,10 @@ impl<M: Module> Backend<M> {
next_var_num += 1;
// We can add the variable directly to our local variable map; it's `Copy`.
variable_table.insert(var_name, var);
variable_table.insert(var_name, (var, etype));
// Now we tell Cranelift about our new variable, which has type I64 because
// everything we have at this point is of type I64. Once it's declare, we
// define it as having the value we computed above.
builder.declare_var(var, types::I64);
// Now we tell Cranelift about our new variable!
builder.declare_var(var, ir::Type::from(etype));
builder.def_var(var, val);
}
}
@@ -195,7 +211,7 @@ impl<M: Module> Backend<M> {
// so we register it using the function ID and our builder context. However, the
// result of this function isn't actually super helpful. So we ignore it, unless
// it's an error.
let _ = self.module.define_function(func_id, &mut ctx)?;
self.module.define_function(func_id, &mut ctx)?;
// done!
Ok(func_id)
@@ -231,54 +247,110 @@ impl Expression {
fn into_crane(
self,
builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>,
global_variables: &HashMap<String, GlobalValue>,
) -> Result<entities::Value, BackendError> {
local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<(entities::Value, ConstantType), BackendError> {
match self {
// Values are pretty straightforward to compile, mostly because we only
// have one type of variable, and it's an integer type.
Expression::Value(_, Value::Number(_, v)) => Ok(builder.ins().iconst(types::I64, v)),
Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables),
Expression::Reference(_, name) => {
// first we see if this is a local variable (which is nicer, from an
// optimization point of view.)
if let Some(local_var) = local_variables.get(&name) {
return Ok(builder.use_var(*local_var));
Expression::Cast(_, target_type, expr) => {
let (val, val_type) =
expr.into_crane(builder, local_variables, global_variables)?;
match (val_type, &target_type) {
(ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)),
(ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => {
Ok((builder.ins().sextend(types::I16, val), ConstantType::I16))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)),
(ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
}
(ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)),
(ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => {
Ok((builder.ins().uextend(types::I16, val), ConstantType::U16))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)),
(ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)),
(ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)),
_ => Err(BackendError::InvalidTypeCast {
from: val_type.into(),
to: target_type,
}),
}
// then we check to see if this is a global reference, which requires us to
// first lookup where the value is stored, and then load it.
if let Some(global_var) = global_variables.get(name.as_ref()) {
let val_ptr = builder.ins().symbol_value(types::I64, *global_var);
return Ok(builder.ins().load(types::I64, MemFlags::new(), val_ptr, 0));
}
// this should never happen, because we should have made sure that there are
// no unbound variables a long time before this. but still ...
Err(BackendError::VariableLookupFailure(name))
}
Expression::Primitive(_, prim, mut vals) => {
// we're going to use `pop`, so we're going to pull and compile the right value ...
let right =
vals.pop()
.unwrap()
.into_crane(builder, local_variables, global_variables)?;
// ... and then the left.
let left =
vals.pop()
.unwrap()
.into_crane(builder, local_variables, global_variables)?;
Expression::Primitive(_, _, prim, mut vals) => {
let mut values = vec![];
let mut first_type = None;
for val in vals.drain(..) {
let (compiled, compiled_type) =
val.into_crane(builder, local_variables, global_variables)?;
if let Some(leftmost_type) = first_type {
assert_eq!(leftmost_type, compiled_type);
} else {
first_type = Some(compiled_type);
}
values.push(compiled);
}
let first_type = first_type.expect("primitive op has at least one argument");
// then we just need to tell Cranelift how to do each of our primitives! Much
// like Statements, above, we probably want to eventually shuffle this off into
// a separate function (maybe something off `Primitive`), but for now it's simple
// enough that we just do the `match` here.
match prim {
Primitive::Plus => Ok(builder.ins().iadd(left, right)),
Primitive::Minus => Ok(builder.ins().isub(left, right)),
Primitive::Times => Ok(builder.ins().imul(left, right)),
Primitive::Divide => Ok(builder.ins().sdiv(left, right)),
Primitive::Plus => Ok((builder.ins().iadd(values[0], values[1]), first_type)),
Primitive::Minus if values.len() == 2 => {
Ok((builder.ins().isub(values[0], values[1]), first_type))
}
Primitive::Minus => Ok((builder.ins().ineg(values[0]), first_type)),
Primitive::Times => Ok((builder.ins().imul(values[0], values[1]), first_type)),
Primitive::Divide if first_type.is_signed() => {
Ok((builder.ins().sdiv(values[0], values[1]), first_type))
}
Primitive::Divide => Ok((builder.ins().udiv(values[0], values[1]), first_type)),
}
}
}
@@ -291,9 +363,66 @@ impl ValueOrRef {
fn into_crane(
self,
builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>,
global_variables: &HashMap<String, GlobalValue>,
) -> Result<entities::Value, BackendError> {
Expression::from(self).into_crane(builder, local_variables, global_variables)
local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<(entities::Value, ConstantType), BackendError> {
match self {
// Values are pretty straightforward to compile, mostly because we only
// have one type of variable, and it's an integer type.
ValueOrRef::Value(_, _, val) => match val {
Value::I8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8))
}
Value::I16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::I16,
)),
Value::I32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::I32,
)),
Value::I64(_, v) => Ok((builder.ins().iconst(types::I64, v), ConstantType::I64)),
Value::U8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::U8))
}
Value::U16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::U16,
)),
Value::U32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::U32,
)),
Value::U64(_, v) => Ok((
builder.ins().iconst(types::I64, v as i64),
ConstantType::U64,
)),
},
ValueOrRef::Ref(_, _, name) => {
// first we see if this is a local variable (which is nicer, from an
// optimization point of view.)
if let Some((local_var, etype)) = local_variables.get(&name) {
return Ok((builder.use_var(*local_var), *etype));
}
// then we check to see if this is a global reference, which requires us to
// first lookup where the value is stored, and then load it.
if let Some((global_var, etype)) = global_variables.get(name.as_ref()) {
let cranelift_type = ir::Type::from(*etype);
let val_ptr = builder.ins().symbol_value(cranelift_type, *global_var);
return Ok((
builder
.ins()
.load(cranelift_type, MemFlags::new(), val_ptr, 0),
*etype,
));
}
// this should never happen, because we should have made sure that there are
// no unbound variables a long time before this. but still ...
Err(BackendError::VariableLookupFailure(name))
}
}
}
}

View File

@@ -8,6 +8,8 @@ use std::fmt::Write;
use target_lexicon::Triple;
use thiserror::Error;
use crate::syntax::ConstantType;
/// An object for querying / using functions built into the runtime.
///
/// Right now, this is a quite a bit of boilerplate for very nebulous
@@ -49,7 +51,7 @@ impl RuntimeFunctions {
"print",
Linkage::Import,
&Signature {
params: vec![string_param, string_param, int64_param],
params: vec![string_param, string_param, int64_param, int64_param],
returns: vec![],
call_conv: CallConv::triple_default(platform),
},
@@ -98,13 +100,30 @@ impl RuntimeFunctions {
// we extend with the output, so that multiple JIT'd `Program`s can run concurrently
// without stomping over each other's output. If `output_buffer` is NULL, we just print
// to stdout.
extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) {
extern "C" fn runtime_print(
output_buffer: *mut String,
name: *const i8,
vtype_repr: i64,
value: i64,
) {
let cstr = unsafe { CStr::from_ptr(name) };
let reconstituted = cstr.to_string_lossy();
let output = match vtype_repr.try_into() {
Ok(ConstantType::I8) => format!("{} = {}i8", reconstituted, value as i8),
Ok(ConstantType::I16) => format!("{} = {}i16", reconstituted, value as i16),
Ok(ConstantType::I32) => format!("{} = {}i32", reconstituted, value as i32),
Ok(ConstantType::I64) => format!("{} = {}i64", reconstituted, value),
Ok(ConstantType::U8) => format!("{} = {}u8", reconstituted, value as u8),
Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16),
Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32),
Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64),
Err(_) => format!("{} = {}<unknown type>", reconstituted, value),
};
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {
writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap();
writeln!(output_buffer, "{}", output).unwrap();
} else {
println!("{} = {}", reconstituted, value);
println!("{}", output);
}
}