🤔 Add a type inference engine, along with typed literals. (#4)

The typed literal formatting mirrors that of Rust. If no type can be
inferred for an untagged literal, the type inference engine will warn
the user and then assume that they meant an unsigned 64-bit number.
(This is slightly inconvenient, because there can be cases in which our
Arbitrary instance may generate a unary negation, in which we should
assume that it's a signed 64-bit number; we may want to revisit this
later.)

The type inference engine is a standard two phase one, in which we first
generate a series of type constraints, and then we solve those
constraints. In this particular implementation, we actually use a third
phase to generate a final AST.

Finally, to increase the amount of testing performed, I've removed the
overflow checking in the evaluator. The only thing we now check for is
division by zero. This does make things a trace slower in testing, but
hopefully we get more coverage this way.
This commit was merged in pull request #4.
This commit is contained in:
2023-09-19 20:40:05 -07:00
committed by GitHub
parent 1fbfd0c2d2
commit bd3b9af469
44 changed files with 3258 additions and 702 deletions

View File

@@ -31,15 +31,15 @@ mod eval;
mod into_crane;
mod runtime;
use std::collections::HashMap;
pub use self::error::BackendError;
pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions};
use crate::syntax::ConstantType;
use cranelift_codegen::settings::Configurable;
use cranelift_codegen::{isa, settings};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{default_libcall_names, DataContext, DataId, FuncId, Linkage, Module};
use cranelift_module::{default_libcall_names, DataDescription, DataId, FuncId, Linkage, Module};
use cranelift_object::{ObjectBuilder, ObjectModule};
use std::collections::HashMap;
use target_lexicon::Triple;
const EMPTY_DATUM: [u8; 8] = [0; 8];
@@ -55,11 +55,12 @@ const EMPTY_DATUM: [u8; 8] = [0; 8];
/// implementations.
pub struct Backend<M: Module> {
pub module: M,
data_ctx: DataContext,
data_ctx: DataDescription,
runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>,
defined_symbols: HashMap<String, DataId>,
defined_symbols: HashMap<String, (DataId, ConstantType)>,
output_buffer: Option<String>,
platform: Triple,
}
impl Backend<JITModule> {
@@ -85,11 +86,12 @@ impl Backend<JITModule> {
Ok(Backend {
module,
data_ctx: DataContext::new(),
data_ctx: DataDescription::new(),
runtime_functions,
defined_strings: HashMap::new(),
defined_symbols: HashMap::new(),
output_buffer,
platform: Triple::host(),
})
}
@@ -123,11 +125,12 @@ impl Backend<ObjectModule> {
Ok(Backend {
module,
data_ctx: DataContext::new(),
data_ctx: DataDescription::new(),
runtime_functions,
defined_strings: HashMap::new(),
defined_symbols: HashMap::new(),
output_buffer: None,
platform,
})
}
@@ -154,7 +157,7 @@ impl<M: Module> Backend<M> {
let global_id = self
.module
.declare_data(&name, Linkage::Local, false, false)?;
let mut data_context = DataContext::new();
let mut data_context = DataDescription::new();
data_context.set_align(8);
data_context.define(s0.into_boxed_str().into_boxed_bytes());
self.module.define_data(global_id, &data_context)?;
@@ -167,14 +170,18 @@ impl<M: Module> Backend<M> {
/// These variables can be shared between functions, and will be exported from the
/// module itself as public data in the case of static compilation. There initial
/// value will be null.
pub fn define_variable(&mut self, name: String) -> Result<DataId, BackendError> {
pub fn define_variable(
&mut self,
name: String,
ctype: ConstantType,
) -> Result<DataId, BackendError> {
self.data_ctx.define(Box::new(EMPTY_DATUM));
let id = self
.module
.declare_data(&name, Linkage::Export, true, false)?;
self.module.define_data(id, &self.data_ctx)?;
self.data_ctx.clear();
self.defined_symbols.insert(name, id);
self.defined_symbols.insert(name, (id, ctype));
Ok(id)
}

View File

@@ -1,4 +1,4 @@
use crate::backend::runtime::RuntimeFunctionError;
use crate::{backend::runtime::RuntimeFunctionError, eval::PrimitiveType, ir::Type};
use codespan_reporting::diagnostic::Diagnostic;
use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError};
use cranelift_module::ModuleError;
@@ -39,6 +39,8 @@ pub enum BackendError {
LookupError(#[from] LookupError),
#[error(transparent)]
Write(#[from] cranelift_object::object::write::Error),
#[error("Invalid type cast from {from} to {to}")]
InvalidTypeCast { from: PrimitiveType, to: Type },
}
impl From<BackendError> for Diagnostic<usize> {
@@ -64,6 +66,9 @@ impl From<BackendError> for Diagnostic<usize> {
BackendError::Write(me) => {
Diagnostic::error().with_message(format!("Cranelift object write error: {}", me))
}
BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message(
format!("Internal error trying to cast from {} to {}", from, to),
),
}
}
}
@@ -103,6 +108,17 @@ impl PartialEq for BackendError {
BackendError::Write(b) => a == b,
_ => false,
},
BackendError::InvalidTypeCast {
from: from1,
to: to1,
} => match other {
BackendError::InvalidTypeCast {
from: from2,
to: to2,
} => from1 == from2 && to1 == to2,
_ => false,
},
}
}
}

View File

@@ -1,10 +1,13 @@
use std::path::Path;
use crate::backend::Backend;
use crate::eval::EvalError;
use crate::ir::Program;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
use cranelift_jit::JITModule;
use cranelift_object::ObjectModule;
#[cfg(test)]
use proptest::arbitrary::Arbitrary;
use std::path::Path;
use target_lexicon::Triple;
impl Backend<JITModule> {
@@ -28,7 +31,8 @@ impl Backend<JITModule> {
let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function();
Ok(jitter.output())
let output = jitter.output();
Ok(output)
}
}
@@ -116,7 +120,7 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow.
#[test]
fn static_backend(program: Program) {
fn static_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError;
let basic_result = program.eval();
@@ -127,8 +131,18 @@ proptest::proptest! {
let basic_result = basic_result.map(|x| x.replace('\n', "\r\n"));
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let compiled_result = Backend::<ObjectModule>::eval(program);
assert_eq!(basic_result, compiled_result);
proptest::prop_assert_eq!(basic_result, compiled_result);
}
}
@@ -136,14 +150,24 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow.
#[test]
fn jit_backend(program: Program) {
fn jit_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::PrimOpError;
// use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator;
// allocator
// .text("---------------")
// .append(allocator.hardline())
// .append(program.pretty(&allocator))
// .1
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works");
let basic_result = program.eval();
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
let compiled_result = Backend::<JITModule>::eval(program);
assert_eq!(basic_result, compiled_result);
proptest::prop_assert_eq!(basic_result, compiled_result);
}
}
}

View File

@@ -1,9 +1,11 @@
use std::collections::HashMap;
use crate::ir::{Expression, Primitive, Program, Statement, Value, ValueOrRef};
use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, Statement, Type, Value, ValueOrRef};
use crate::syntax::ConstantType;
use cranelift_codegen::entity::EntityRef;
use cranelift_codegen::ir::{
entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName,
self, entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName,
};
use cranelift_codegen::isa::CallConv;
use cranelift_codegen::Context;
@@ -41,7 +43,7 @@ impl<M: Module> Backend<M> {
let basic_signature = Signature {
params: vec![],
returns: vec![],
call_conv: CallConv::SystemV,
call_conv: CallConv::triple_default(&self.platform),
};
// this generates the handle for the function that we'll eventually want to
@@ -85,12 +87,12 @@ impl<M: Module> Backend<M> {
// Just like with strings, generating the `GlobalValue`s we need can potentially
// be a little tricky to do on the fly, so we generate the complete list right
// here and then use it later.
let pre_defined_symbols: HashMap<String, GlobalValue> = self
let pre_defined_symbols: HashMap<String, (GlobalValue, ConstantType)> = self
.defined_symbols
.iter()
.map(|(k, v)| {
.map(|(k, (v, t))| {
let local_data = self.module.declare_data_in_func(*v, &mut ctx.func);
(k.clone(), local_data)
(k.clone(), (local_data, *t))
})
.collect();
@@ -123,7 +125,7 @@ impl<M: Module> Backend<M> {
// Print statements are fairly easy to compile: we just lookup the
// output buffer, the address of the string to print, and the value
// of whatever variable we're printing. Then we just call print.
Statement::Print(ann, var) => {
Statement::Print(ann, t, var) => {
// Get the output buffer (or null) from our general compilation context.
let buffer_ptr = self.output_buffer_ptr();
let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64);
@@ -135,31 +137,47 @@ impl<M: Module> Backend<M> {
// Look up the value for the variable. Because this might be a
// global variable (and that requires special logic), we just turn
// this into an `Expression` and re-use the logic in that implementation.
let val = Expression::Reference(ann, var).into_crane(
let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane(
&mut builder,
&variable_table,
&pre_defined_symbols,
)?;
let vtype_repr = builder.ins().iconst(types::I64, vtype as i64);
let casted_val = match vtype {
ConstantType::U64 | ConstantType::I64 => val,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => {
builder.ins().sextend(types::I64, val)
}
ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => {
builder.ins().uextend(types::I64, val)
}
};
// Finally, we can generate the call to print.
builder
.ins()
.call(print_func_ref, &[buffer_ptr, name_ptr, val]);
builder.ins().call(
print_func_ref,
&[buffer_ptr, name_ptr, vtype_repr, casted_val],
);
}
// Variable binding is a little more con
Statement::Binding(_, var_name, value) => {
Statement::Binding(_, var_name, _, value) => {
// Kick off to the `Expression` implementation to see what value we're going
// to bind to this variable.
let val =
let (val, etype) =
value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?;
// Now the question is: is this a local variable, or a global one?
if let Some(global_id) = pre_defined_symbols.get(var_name.as_str()) {
if let Some((global_id, ctype)) = pre_defined_symbols.get(var_name.as_str()) {
// It's a global variable! In this case, we assume that someone has already
// dedicated some space in memory to store this value. We look this location
// up, and then tell Cranelift to store the value there.
let val_ptr = builder.ins().symbol_value(types::I64, *global_id);
assert_eq!(etype, *ctype);
let val_ptr = builder
.ins()
.symbol_value(ir::Type::from(*ctype), *global_id);
builder.ins().store(MemFlags::new(), val, val_ptr, 0);
} else {
// It's a local variable! In this case, we need to allocate a new Cranelift
@@ -171,12 +189,10 @@ impl<M: Module> Backend<M> {
next_var_num += 1;
// We can add the variable directly to our local variable map; it's `Copy`.
variable_table.insert(var_name, var);
variable_table.insert(var_name, (var, etype));
// Now we tell Cranelift about our new variable, which has type I64 because
// everything we have at this point is of type I64. Once it's declare, we
// define it as having the value we computed above.
builder.declare_var(var, types::I64);
// Now we tell Cranelift about our new variable!
builder.declare_var(var, ir::Type::from(etype));
builder.def_var(var, val);
}
}
@@ -195,7 +211,7 @@ impl<M: Module> Backend<M> {
// so we register it using the function ID and our builder context. However, the
// result of this function isn't actually super helpful. So we ignore it, unless
// it's an error.
let _ = self.module.define_function(func_id, &mut ctx)?;
self.module.define_function(func_id, &mut ctx)?;
// done!
Ok(func_id)
@@ -231,54 +247,110 @@ impl Expression {
fn into_crane(
self,
builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>,
global_variables: &HashMap<String, GlobalValue>,
) -> Result<entities::Value, BackendError> {
local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<(entities::Value, ConstantType), BackendError> {
match self {
// Values are pretty straightforward to compile, mostly because we only
// have one type of variable, and it's an integer type.
Expression::Value(_, Value::Number(_, v)) => Ok(builder.ins().iconst(types::I64, v)),
Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables),
Expression::Reference(_, name) => {
// first we see if this is a local variable (which is nicer, from an
// optimization point of view.)
if let Some(local_var) = local_variables.get(&name) {
return Ok(builder.use_var(*local_var));
Expression::Cast(_, target_type, expr) => {
let (val, val_type) =
expr.into_crane(builder, local_variables, global_variables)?;
match (val_type, &target_type) {
(ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)),
(ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => {
Ok((builder.ins().sextend(types::I16, val), ConstantType::I16))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
}
(ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)),
(ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().sextend(types::I32, val), ConstantType::I32))
}
(ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)),
(ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().sextend(types::I64, val), ConstantType::I64))
}
(ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)),
(ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => {
Ok((builder.ins().uextend(types::I16, val), ConstantType::U16))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)),
(ConstantType::U16, Type::Primitive(PrimitiveType::U32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::U32))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)),
(ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::U64))
}
(ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)),
_ => Err(BackendError::InvalidTypeCast {
from: val_type.into(),
to: target_type,
}),
}
// then we check to see if this is a global reference, which requires us to
// first lookup where the value is stored, and then load it.
if let Some(global_var) = global_variables.get(name.as_ref()) {
let val_ptr = builder.ins().symbol_value(types::I64, *global_var);
return Ok(builder.ins().load(types::I64, MemFlags::new(), val_ptr, 0));
}
// this should never happen, because we should have made sure that there are
// no unbound variables a long time before this. but still ...
Err(BackendError::VariableLookupFailure(name))
}
Expression::Primitive(_, prim, mut vals) => {
// we're going to use `pop`, so we're going to pull and compile the right value ...
let right =
vals.pop()
.unwrap()
.into_crane(builder, local_variables, global_variables)?;
// ... and then the left.
let left =
vals.pop()
.unwrap()
.into_crane(builder, local_variables, global_variables)?;
Expression::Primitive(_, _, prim, mut vals) => {
let mut values = vec![];
let mut first_type = None;
for val in vals.drain(..) {
let (compiled, compiled_type) =
val.into_crane(builder, local_variables, global_variables)?;
if let Some(leftmost_type) = first_type {
assert_eq!(leftmost_type, compiled_type);
} else {
first_type = Some(compiled_type);
}
values.push(compiled);
}
let first_type = first_type.expect("primitive op has at least one argument");
// then we just need to tell Cranelift how to do each of our primitives! Much
// like Statements, above, we probably want to eventually shuffle this off into
// a separate function (maybe something off `Primitive`), but for now it's simple
// enough that we just do the `match` here.
match prim {
Primitive::Plus => Ok(builder.ins().iadd(left, right)),
Primitive::Minus => Ok(builder.ins().isub(left, right)),
Primitive::Times => Ok(builder.ins().imul(left, right)),
Primitive::Divide => Ok(builder.ins().sdiv(left, right)),
Primitive::Plus => Ok((builder.ins().iadd(values[0], values[1]), first_type)),
Primitive::Minus if values.len() == 2 => {
Ok((builder.ins().isub(values[0], values[1]), first_type))
}
Primitive::Minus => Ok((builder.ins().ineg(values[0]), first_type)),
Primitive::Times => Ok((builder.ins().imul(values[0], values[1]), first_type)),
Primitive::Divide if first_type.is_signed() => {
Ok((builder.ins().sdiv(values[0], values[1]), first_type))
}
Primitive::Divide => Ok((builder.ins().udiv(values[0], values[1]), first_type)),
}
}
}
@@ -291,9 +363,66 @@ impl ValueOrRef {
fn into_crane(
self,
builder: &mut FunctionBuilder,
local_variables: &HashMap<ArcIntern<String>, Variable>,
global_variables: &HashMap<String, GlobalValue>,
) -> Result<entities::Value, BackendError> {
Expression::from(self).into_crane(builder, local_variables, global_variables)
local_variables: &HashMap<ArcIntern<String>, (Variable, ConstantType)>,
global_variables: &HashMap<String, (GlobalValue, ConstantType)>,
) -> Result<(entities::Value, ConstantType), BackendError> {
match self {
// Values are pretty straightforward to compile, mostly because we only
// have one type of variable, and it's an integer type.
ValueOrRef::Value(_, _, val) => match val {
Value::I8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8))
}
Value::I16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::I16,
)),
Value::I32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::I32,
)),
Value::I64(_, v) => Ok((builder.ins().iconst(types::I64, v), ConstantType::I64)),
Value::U8(_, v) => {
Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::U8))
}
Value::U16(_, v) => Ok((
builder.ins().iconst(types::I16, v as i64),
ConstantType::U16,
)),
Value::U32(_, v) => Ok((
builder.ins().iconst(types::I32, v as i64),
ConstantType::U32,
)),
Value::U64(_, v) => Ok((
builder.ins().iconst(types::I64, v as i64),
ConstantType::U64,
)),
},
ValueOrRef::Ref(_, _, name) => {
// first we see if this is a local variable (which is nicer, from an
// optimization point of view.)
if let Some((local_var, etype)) = local_variables.get(&name) {
return Ok((builder.use_var(*local_var), *etype));
}
// then we check to see if this is a global reference, which requires us to
// first lookup where the value is stored, and then load it.
if let Some((global_var, etype)) = global_variables.get(name.as_ref()) {
let cranelift_type = ir::Type::from(*etype);
let val_ptr = builder.ins().symbol_value(cranelift_type, *global_var);
return Ok((
builder
.ins()
.load(cranelift_type, MemFlags::new(), val_ptr, 0),
*etype,
));
}
// this should never happen, because we should have made sure that there are
// no unbound variables a long time before this. but still ...
Err(BackendError::VariableLookupFailure(name))
}
}
}
}

View File

@@ -8,6 +8,8 @@ use std::fmt::Write;
use target_lexicon::Triple;
use thiserror::Error;
use crate::syntax::ConstantType;
/// An object for querying / using functions built into the runtime.
///
/// Right now, this is a quite a bit of boilerplate for very nebulous
@@ -49,7 +51,7 @@ impl RuntimeFunctions {
"print",
Linkage::Import,
&Signature {
params: vec![string_param, string_param, int64_param],
params: vec![string_param, string_param, int64_param, int64_param],
returns: vec![],
call_conv: CallConv::triple_default(platform),
},
@@ -98,13 +100,30 @@ impl RuntimeFunctions {
// we extend with the output, so that multiple JIT'd `Program`s can run concurrently
// without stomping over each other's output. If `output_buffer` is NULL, we just print
// to stdout.
extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) {
extern "C" fn runtime_print(
output_buffer: *mut String,
name: *const i8,
vtype_repr: i64,
value: i64,
) {
let cstr = unsafe { CStr::from_ptr(name) };
let reconstituted = cstr.to_string_lossy();
let output = match vtype_repr.try_into() {
Ok(ConstantType::I8) => format!("{} = {}i8", reconstituted, value as i8),
Ok(ConstantType::I16) => format!("{} = {}i16", reconstituted, value as i16),
Ok(ConstantType::I32) => format!("{} = {}i32", reconstituted, value as i32),
Ok(ConstantType::I64) => format!("{} = {}i64", reconstituted, value),
Ok(ConstantType::U8) => format!("{} = {}u8", reconstituted, value as u8),
Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16),
Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32),
Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64),
Err(_) => format!("{} = {}<unknown type>", reconstituted, value),
};
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {
writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap();
writeln!(output_buffer, "{}", output).unwrap();
} else {
println!("{} = {}", reconstituted, value);
println!("{}", output);
}
}

View File

@@ -17,7 +17,7 @@ fn main() {
let args = CommandLineArguments::parse();
let mut compiler = ngr::Compiler::default();
let output_file = args.output.unwrap_or("output.o".to_string());
let output_file = args.output.unwrap_or_else(|| "output.o".to_string());
if let Some(bytes) = compiler.compile(&args.file) {
std::fs::write(&output_file, bytes)

View File

@@ -1,6 +1,5 @@
use crate::backend::Backend;
use crate::ir::Program as IR;
use crate::syntax::Program as Syntax;
use crate::{backend::Backend, type_infer::TypeInferenceResult};
use codespan_reporting::{
diagnostic::Diagnostic,
files::SimpleFiles,
@@ -100,8 +99,38 @@ impl Compiler {
return Ok(None);
}
// Now that we've validated it, turn it into IR.
let ir = IR::from(syntax);
// Now that we've validated it, let's do type inference, potentially turning
// into IR while we're at it.
let ir = match syntax.type_infer() {
TypeInferenceResult::Failure {
mut errors,
mut warnings,
} => {
let messages = errors
.drain(..)
.map(Into::into)
.chain(warnings.drain(..).map(Into::into));
for message in messages {
self.emit(message);
}
return Ok(None);
}
TypeInferenceResult::Success {
result,
mut warnings,
} => {
let messages = warnings.drain(..).map(Into::into);
for message in messages {
self.emit(message);
}
result
}
};
// Finally, send all this to Cranelift for conversion into an object file.
let mut backend = Backend::object_file(Triple::host())?;

View File

@@ -35,11 +35,13 @@
//!
mod env;
mod primop;
mod primtype;
mod value;
use cranelift_module::ModuleError;
pub use env::{EvalEnvironment, LookupError};
pub use primop::PrimOpError;
pub use primtype::PrimitiveType;
pub use value::Value;
use crate::backend::BackendError;

View File

@@ -87,9 +87,9 @@ mod tests {
let tester = tester.extend(arced("bar"), 2i64.into());
let tester = tester.extend(arced("goo"), 5i64.into());
assert_eq!(tester.lookup(arced("foo")), Ok(1.into()));
assert_eq!(tester.lookup(arced("bar")), Ok(2.into()));
assert_eq!(tester.lookup(arced("goo")), Ok(5.into()));
assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into()));
assert_eq!(tester.lookup(arced("bar")), Ok(2i64.into()));
assert_eq!(tester.lookup(arced("goo")), Ok(5i64.into()));
assert!(tester.lookup(arced("baz")).is_err());
}
@@ -103,14 +103,14 @@ mod tests {
check_nested(&tester);
assert_eq!(tester.lookup(arced("foo")), Ok(1.into()));
assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into()));
assert!(tester.lookup(arced("bar")).is_err());
}
fn check_nested(env: &EvalEnvironment) {
let nested_env = env.extend(arced("bar"), 2i64.into());
assert_eq!(nested_env.lookup(arced("foo")), Ok(1.into()));
assert_eq!(nested_env.lookup(arced("bar")), Ok(2.into()));
assert_eq!(nested_env.lookup(arced("foo")), Ok(1i64.into()));
assert_eq!(nested_env.lookup(arced("bar")), Ok(2i64.into()));
}
fn arced(s: &str) -> ArcIntern<String> {

View File

@@ -1,3 +1,4 @@
use crate::eval::primtype::PrimitiveType;
use crate::eval::value::Value;
/// Errors that can occur running primitive operations in the evaluators.
@@ -22,6 +23,13 @@ pub enum PrimOpError {
BadArgCount(String, usize),
#[error("Unknown primitive operation {0}")]
UnknownPrimOp(String),
#[error("Unsafe cast from {from} to {to}")]
UnsafeCast {
from: PrimitiveType,
to: PrimitiveType,
},
#[error("Unknown primitive type {0}")]
UnknownPrimType(String),
}
// Implementing primitives in an interpreter like this is *super* tedious,
@@ -37,39 +45,95 @@ pub enum PrimOpError {
macro_rules! run_op {
($op: ident, $left: expr, $right: expr) => {
match $op {
"+" => $left
.checked_add($right)
.ok_or(PrimOpError::MathFailure("+"))
.map(Into::into),
"-" => $left
.checked_sub($right)
.ok_or(PrimOpError::MathFailure("-"))
.map(Into::into),
"*" => $left
.checked_mul($right)
.ok_or(PrimOpError::MathFailure("*"))
.map(Into::into),
"/" => $left
.checked_div($right)
.ok_or(PrimOpError::MathFailure("/"))
.map(Into::into),
"+" => Ok($left.wrapping_add($right).into()),
"-" => Ok($left.wrapping_sub($right).into()),
"*" => Ok($left.wrapping_mul($right).into()),
"/" if $right == 0 => Err(PrimOpError::MathFailure("/")),
"/" => Ok($left.wrapping_div($right).into()),
_ => Err(PrimOpError::UnknownPrimOp($op.to_string())),
}
};
}
impl Value {
fn unary_op(operation: &str, value: &Value) -> Result<Value, PrimOpError> {
match operation {
"-" => match value {
Value::I8(x) => Ok(Value::I8(x.wrapping_neg())),
Value::I16(x) => Ok(Value::I16(x.wrapping_neg())),
Value::I32(x) => Ok(Value::I32(x.wrapping_neg())),
Value::I64(x) => Ok(Value::I64(x.wrapping_neg())),
_ => Err(PrimOpError::BadTypeFor("-", value.clone())),
},
_ => Err(PrimOpError::BadArgCount(operation.to_owned(), 1)),
}
}
fn binary_op(operation: &str, left: &Value, right: &Value) -> Result<Value, PrimOpError> {
match left {
// for now we only have one type, but in the future this is
// going to be very irritating.
Value::I8(x) => match right {
Value::I8(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I16(x) => match right {
Value::I16(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I32(x) => match right {
Value::I32(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::I64(x) => match right {
Value::I64(y) => run_op!(operation, x, *y),
// _ => Err(PrimOpError::TypeMismatch(
// operation.to_string(),
// left.clone(),
// right.clone(),
// )),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U8(x) => match right {
Value::U8(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U16(x) => match right {
Value::U16(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U32(x) => match right {
Value::U32(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::U64(x) => match right {
Value::U64(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
}
}
@@ -83,13 +147,10 @@ impl Value {
/// its worth being careful to make sure that your inputs won't cause either
/// condition.
pub fn calculate(operation: &str, values: Vec<Value>) -> Result<Value, PrimOpError> {
if values.len() == 2 {
Value::binary_op(operation, &values[0], &values[1])
} else {
Err(PrimOpError::BadArgCount(
operation.to_string(),
values.len(),
))
match values.len() {
1 => Value::unary_op(operation, &values[0]),
2 => Value::binary_op(operation, &values[0], &values[1]),
x => Err(PrimOpError::BadArgCount(operation.to_string(), x)),
}
}
}

173
src/eval/primtype.rs Normal file
View File

@@ -0,0 +1,173 @@
use crate::{
eval::{PrimOpError, Value},
syntax::ConstantType,
};
use std::{fmt::Display, str::FromStr};
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PrimitiveType {
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
}
impl Display for PrimitiveType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PrimitiveType::I8 => write!(f, "i8"),
PrimitiveType::I16 => write!(f, "i16"),
PrimitiveType::I32 => write!(f, "i32"),
PrimitiveType::I64 => write!(f, "i64"),
PrimitiveType::U8 => write!(f, "u8"),
PrimitiveType::U16 => write!(f, "u16"),
PrimitiveType::U32 => write!(f, "u32"),
PrimitiveType::U64 => write!(f, "u64"),
}
}
}
impl<'a> From<&'a Value> for PrimitiveType {
fn from(value: &Value) -> Self {
match value {
Value::I8(_) => PrimitiveType::I8,
Value::I16(_) => PrimitiveType::I16,
Value::I32(_) => PrimitiveType::I32,
Value::I64(_) => PrimitiveType::I64,
Value::U8(_) => PrimitiveType::U8,
Value::U16(_) => PrimitiveType::U16,
Value::U32(_) => PrimitiveType::U32,
Value::U64(_) => PrimitiveType::U64,
}
}
}
impl From<ConstantType> for PrimitiveType {
fn from(value: ConstantType) -> Self {
match value {
ConstantType::I8 => PrimitiveType::I8,
ConstantType::I16 => PrimitiveType::I16,
ConstantType::I32 => PrimitiveType::I32,
ConstantType::I64 => PrimitiveType::I64,
ConstantType::U8 => PrimitiveType::U8,
ConstantType::U16 => PrimitiveType::U16,
ConstantType::U32 => PrimitiveType::U32,
ConstantType::U64 => PrimitiveType::U64,
}
}
}
impl FromStr for PrimitiveType {
type Err = PrimOpError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"i8" => Ok(PrimitiveType::I8),
"i16" => Ok(PrimitiveType::I16),
"i32" => Ok(PrimitiveType::I32),
"i64" => Ok(PrimitiveType::I64),
"u8" => Ok(PrimitiveType::U8),
"u16" => Ok(PrimitiveType::U16),
"u32" => Ok(PrimitiveType::U32),
"u64" => Ok(PrimitiveType::U64),
_ => Err(PrimOpError::UnknownPrimType(s.to_string())),
}
}
}
impl PrimitiveType {
/// Return true if this type can be safely cast into the target type.
pub fn can_cast_to(&self, target: &PrimitiveType) -> bool {
match self {
PrimitiveType::U8 => matches!(
target,
PrimitiveType::U8
| PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I16
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U16 => matches!(
target,
PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U32 => matches!(
target,
PrimitiveType::U32 | PrimitiveType::U64 | PrimitiveType::I64
),
PrimitiveType::U64 => target == &PrimitiveType::U64,
PrimitiveType::I8 => matches!(
target,
PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I16 => matches!(
target,
PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I32 => matches!(target, PrimitiveType::I32 | PrimitiveType::I64),
PrimitiveType::I64 => target == &PrimitiveType::I64,
}
}
/// Try to cast the given value to this type, returning the new value.
///
/// Returns an error if the cast is not safe *in* *general*. This means that
/// this function will error even if the number will actually fit in the target
/// type, but it would not be generally safe to cast a member of the given
/// type to the target type. (So, for example, "1i64" is a number that could
/// work as a "u64", but since negative numbers wouldn't work, a cast from
/// "1i64" to "u64" will fail.)
pub fn safe_cast(&self, source: &Value) -> Result<Value, PrimOpError> {
match (self, source) {
(PrimitiveType::U8, Value::U8(x)) => Ok(Value::U8(*x)),
(PrimitiveType::U16, Value::U8(x)) => Ok(Value::U16(*x as u16)),
(PrimitiveType::U16, Value::U16(x)) => Ok(Value::U16(*x)),
(PrimitiveType::U32, Value::U8(x)) => Ok(Value::U32(*x as u32)),
(PrimitiveType::U32, Value::U16(x)) => Ok(Value::U32(*x as u32)),
(PrimitiveType::U32, Value::U32(x)) => Ok(Value::U32(*x)),
(PrimitiveType::U64, Value::U8(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U16(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U32(x)) => Ok(Value::U64(*x as u64)),
(PrimitiveType::U64, Value::U64(x)) => Ok(Value::U64(*x)),
(PrimitiveType::I8, Value::I8(x)) => Ok(Value::I8(*x)),
(PrimitiveType::I16, Value::I8(x)) => Ok(Value::I16(*x as i16)),
(PrimitiveType::I16, Value::I16(x)) => Ok(Value::I16(*x)),
(PrimitiveType::I32, Value::I8(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I32, Value::I16(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I32, Value::I32(x)) => Ok(Value::I32(*x)),
(PrimitiveType::I64, Value::I8(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I16(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)),
_ => Err(PrimOpError::UnsafeCast {
from: source.into(),
to: *self,
}),
}
}
pub fn max_value(&self) -> u64 {
match self {
PrimitiveType::U8 => u8::MAX as u64,
PrimitiveType::U16 => u16::MAX as u64,
PrimitiveType::U32 => u32::MAX as u64,
PrimitiveType::U64 => u64::MAX,
PrimitiveType::I8 => i8::MAX as u64,
PrimitiveType::I16 => i16::MAX as u64,
PrimitiveType::I32 => i32::MAX as u64,
PrimitiveType::I64 => i64::MAX as u64,
}
}
}

View File

@@ -7,19 +7,75 @@ use std::fmt::Display;
/// by type so that we don't mix them up.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
I8(i8),
I16(i16),
I32(i32),
I64(i64),
U8(u8),
U16(u16),
U32(u32),
U64(u64),
}
impl Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Value::I8(x) => write!(f, "{}i8", x),
Value::I16(x) => write!(f, "{}i16", x),
Value::I32(x) => write!(f, "{}i32", x),
Value::I64(x) => write!(f, "{}i64", x),
Value::U8(x) => write!(f, "{}u8", x),
Value::U16(x) => write!(f, "{}u16", x),
Value::U32(x) => write!(f, "{}u32", x),
Value::U64(x) => write!(f, "{}u64", x),
}
}
}
impl From<i8> for Value {
fn from(value: i8) -> Self {
Value::I8(value)
}
}
impl From<i16> for Value {
fn from(value: i16) -> Self {
Value::I16(value)
}
}
impl From<i32> for Value {
fn from(value: i32) -> Self {
Value::I32(value)
}
}
impl From<i64> for Value {
fn from(value: i64) -> Self {
Value::I64(value)
}
}
impl From<u8> for Value {
fn from(value: u8) -> Self {
Value::U8(value)
}
}
impl From<u16> for Value {
fn from(value: u16) -> Self {
Value::U16(value)
}
}
impl From<u32> for Value {
fn from(value: u32) -> Self {
Value::U32(value)
}
}
impl From<u64> for Value {
fn from(value: u64) -> Self {
Value::U64(value)
}
}

6
src/examples.rs Normal file
View File

@@ -0,0 +1,6 @@
use crate::backend::Backend;
use crate::syntax::Program as Syntax;
use codespan_reporting::files::SimpleFiles;
use cranelift_jit::JITModule;
include!(concat!(env!("OUT_DIR"), "/examples.rs"));

View File

@@ -12,9 +12,8 @@
//! validating syntax, and then figuring out how to turn it into Cranelift
//! and object code. After that point, however, this will be the module to
//! come to for analysis and optimization work.
mod ast;
pub mod ast;
mod eval;
mod from_syntax;
mod strings;
pub use ast::*;

View File

@@ -1,10 +1,14 @@
use crate::syntax::Location;
use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use pretty::{BoxAllocator, DocAllocator, Pretty};
use proptest::{
prelude::Arbitrary,
strategy::{BoxedStrategy, Strategy},
};
use std::{fmt, str::FromStr};
/// We're going to represent variables as interned strings.
///
@@ -52,12 +56,15 @@ where
}
impl Arbitrary for Program {
type Parameters = ();
type Parameters = crate::syntax::arbitrary::GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
crate::syntax::Program::arbitrary_with(args)
.prop_map(Program::from)
.prop_map(|x| {
x.type_infer()
.expect("arbitrary_with should generate type-correct programs")
})
.boxed()
}
}
@@ -74,8 +81,8 @@ impl Arbitrary for Program {
///
#[derive(Debug)]
pub enum Statement {
Binding(Location, Variable, Expression),
Print(Location, Variable),
Binding(Location, Variable, Type, Expression),
Print(Location, Type, Variable),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
@@ -85,13 +92,13 @@ where
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Statement::Binding(_, var, expr) => allocator
Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string())
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
Statement::Print(_, var) => allocator
Statement::Print(_, _, var) => allocator
.text("print")
.append(allocator.space())
.append(allocator.text(var.as_ref().to_string())),
@@ -113,9 +120,32 @@ where
/// variable reference.
#[derive(Debug)]
pub enum Expression {
Value(Location, Value),
Reference(Location, Variable),
Primitive(Location, Primitive, Vec<ValueOrRef>),
Atomic(ValueOrRef),
Cast(Location, Type, ValueOrRef),
Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
@@ -125,12 +155,16 @@ where
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Expression::Value(_, val) => val.pretty(allocator),
Expression::Reference(_, var) => allocator.text(var.as_ref().to_string()),
Expression::Primitive(_, op, exprs) if exprs.len() == 1 => {
Expression::Atomic(x) => x.pretty(allocator),
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => {
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
@@ -140,7 +174,7 @@ where
.append(right)
.parens()
}
Expression::Primitive(_, op, exprs) => {
Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
}
}
@@ -161,10 +195,10 @@ pub enum Primitive {
Divide,
}
impl<'a> TryFrom<&'a str> for Primitive {
type Error = String;
impl FromStr for Primitive {
type Err = String;
fn try_from(value: &str) -> Result<Self, Self::Error> {
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value {
"+" => Ok(Primitive::Plus),
"-" => Ok(Primitive::Minus),
@@ -190,15 +224,21 @@ where
}
}
impl fmt::Display for Primitive {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<&Primitive as Pretty<'_, BoxAllocator, ()>>::pretty(self, &BoxAllocator).render_fmt(72, f)
}
}
/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub enum ValueOrRef {
Value(Location, Value),
Ref(Location, ArcIntern<String>),
Value(Location, Type, Value),
Ref(Location, Type, ArcIntern<String>),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
@@ -208,30 +248,50 @@ where
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
ValueOrRef::Value(_, v) => v.pretty(allocator),
ValueOrRef::Ref(_, v) => allocator.text(v.as_ref().to_string()),
ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
}
}
}
impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self {
match value {
ValueOrRef::Value(loc, val) => Expression::Value(loc, val),
ValueOrRef::Ref(loc, var) => Expression::Reference(loc, var),
}
Expression::Atomic(value)
}
}
/// A constant in the IR.
#[derive(Debug)]
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug)]
pub enum Value {
/// A numerical constant.
///
/// The optional argument is the base that was used by the user to input
/// the number. By retaining it, we can ensure that if we need to print the
/// number back out, we can do so in the form that the user entered it.
Number(Option<u8>, i64),
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl Value {
/// Return the type described by this value
pub fn type_of(&self) -> Type {
match self {
Value::I8(_, _) => Type::Primitive(PrimitiveType::I8),
Value::I16(_, _) => Type::Primitive(PrimitiveType::I16),
Value::I32(_, _) => Type::Primitive(PrimitiveType::I32),
Value::I64(_, _) => Type::Primitive(PrimitiveType::I64),
Value::U8(_, _) => Type::Primitive(PrimitiveType::U8),
Value::U16(_, _) => Type::Primitive(PrimitiveType::U16),
Value::U32(_, _) => Type::Primitive(PrimitiveType::U32),
Value::U64(_, _) => Type::Primitive(PrimitiveType::U64),
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
@@ -240,19 +300,64 @@ where
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Value::Number(opt_base, value) => {
let value_str = match opt_base {
None => format!("{}", value),
Some(2) => format!("0b{:b}", value),
Some(8) => format!("0o{:o}", value),
Some(10) => format!("0d{}", value),
Some(16) => format!("0x{:x}", value),
Some(_) => format!("!!{:x}!!", value),
};
let pretty_internal = |opt_base: &Option<u8>, x, t| {
syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
};
allocator.text(value_str)
let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Primitive(PrimitiveType),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Primitive(pt) => pt.fmt(f),
}
}
}

View File

@@ -1,7 +1,7 @@
use crate::eval::{EvalEnvironment, EvalError, Value};
use crate::ir::{Expression, Program, Statement};
use super::{Primitive, ValueOrRef};
use super::{Primitive, Type, ValueOrRef};
impl Program {
/// Evaluate the program, returning either an error or a string containing everything
@@ -14,12 +14,12 @@ impl Program {
for stmt in self.statements.iter() {
match stmt {
Statement::Binding(_, name, value) => {
Statement::Binding(_, name, _, value) => {
let actual_value = value.eval(&env)?;
env = env.extend(name.clone(), actual_value);
}
Statement::Print(_, name) => {
Statement::Print(_, _, name) => {
let value = env.lookup(name.clone())?;
let line = format!("{} = {}\n", name, value);
stdout.push_str(&line);
@@ -34,26 +34,21 @@ impl Program {
impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
Expression::Value(_, v) => match v {
super::Value::Number(_, v) => Ok(Value::I64(*v)),
},
Expression::Atomic(x) => x.eval(env),
Expression::Reference(_, n) => Ok(env.lookup(n.clone())?),
Expression::Cast(_, t, valref) => {
let value = valref.eval(env)?;
Expression::Primitive(_, op, args) => {
let mut arg_values = Vec::with_capacity(args.len());
// we implement primitive operations by first evaluating each of the
// arguments to the function, and then gathering up all the values
// produced.
for arg in args.iter() {
match arg {
ValueOrRef::Ref(_, n) => arg_values.push(env.lookup(n.clone())?),
ValueOrRef::Value(_, super::Value::Number(_, v)) => {
arg_values.push(Value::I64(*v))
}
}
match t {
Type::Primitive(pt) => Ok(pt.safe_cast(&value)?),
}
}
Expression::Primitive(_, _, op, args) => {
let arg_values = args
.iter()
.map(|x| x.eval(env))
.collect::<Result<Vec<Value>, EvalError>>()?;
// and then finally we call `calculate` to run them. trust me, it's nice
// to not have to deal with all the nonsense hidden under `calculate`.
@@ -68,19 +63,38 @@ impl Expression {
}
}
impl ValueOrRef {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
ValueOrRef::Value(_, _, v) => match v {
super::Value::I8(_, v) => Ok(Value::I8(*v)),
super::Value::I16(_, v) => Ok(Value::I16(*v)),
super::Value::I32(_, v) => Ok(Value::I32(*v)),
super::Value::I64(_, v) => Ok(Value::I64(*v)),
super::Value::U8(_, v) => Ok(Value::U8(*v)),
super::Value::U16(_, v) => Ok(Value::U16(*v)),
super::Value::U32(_, v) => Ok(Value::U32(*v)),
super::Value::U64(_, v) => Ok(Value::U64(*v)),
},
ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?),
}
}
}
#[test]
fn two_plus_three() {
let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let ir = Program::from(input);
let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully");
assert_eq!("x = 5i64\n", &output);
assert_eq!("x = 5u64\n", &output);
}
#[test]
fn lotsa_math() {
let input =
crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let ir = Program::from(input);
let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully");
assert_eq!("x = 7i64\n", &output);
assert_eq!("x = 7u64\n", &output);
}

View File

@@ -1,186 +0,0 @@
use internment::ArcIntern;
use std::sync::atomic::AtomicUsize;
use crate::ir::ast as ir;
use crate::syntax;
use super::ValueOrRef;
impl From<syntax::Program> for ir::Program {
/// We implement the top-level conversion of a syntax::Program into an
/// ir::Program using just the standard `From::from`, because we don't
/// need to return any arguments and we shouldn't produce any errors.
/// Technically there's an `unwrap` deep under the hood that we could
/// float out, but the validator really should've made sure that never
/// happens, so we're just going to assume.
fn from(mut value: syntax::Program) -> Self {
let mut statements = Vec::new();
for stmt in value.statements.drain(..) {
statements.append(&mut stmt.simplify());
}
ir::Program { statements }
}
}
impl From<syntax::Statement> for ir::Program {
/// One interesting thing about this conversion is that there isn't
/// a natural translation from syntax::Statement to ir::Statement,
/// because the syntax version can have nested expressions and the
/// IR version can't.
///
/// As a result, we can naturally convert a syntax::Statement into
/// an ir::Program, because we can allow the additional binding
/// sites to be generated, instead. And, bonus, it turns out that
/// this is what we wanted anyways.
fn from(value: syntax::Statement) -> Self {
ir::Program {
statements: value.simplify(),
}
}
}
impl syntax::Statement {
/// Simplify a syntax::Statement into a series of ir::Statements.
///
/// The reason this function is one-to-many is because we may have to
/// introduce new binding sites in order to avoid having nested
/// expressions. Nested expressions, like `(1 + 2) * 3`, are allowed
/// in syntax::Expression but are expressly *not* allowed in
/// ir::Expression. So this pass converts them into bindings, like
/// this:
///
/// x = (1 + 2) * 3;
///
/// ==>
///
/// x:1 = 1 + 2;
/// x:2 = x:1 * 3;
/// x = x:2
///
/// Thus ensuring that things are nice and simple. Note that the
/// binding of `x:2` is not, strictly speaking, necessary, but it
/// makes the code below much easier to read.
fn simplify(self) -> Vec<ir::Statement> {
let mut new_statements = vec![];
match self {
// Print statements we don't have to do much with
syntax::Statement::Print(loc, name) => {
new_statements.push(ir::Statement::Print(loc, ArcIntern::new(name)))
}
// Bindings, however, may involve a single expression turning into
// a series of statements and then an expression.
syntax::Statement::Binding(loc, name, value) => {
let (mut prereqs, new_value) = value.rebind(&name);
new_statements.append(&mut prereqs);
new_statements.push(ir::Statement::Binding(
loc,
ArcIntern::new(name),
new_value.into(),
))
}
}
new_statements
}
}
impl syntax::Expression {
/// This actually does the meat of the simplification work, here, by rebinding
/// any nested expressions into their own variables. We have this return
/// `ValueOrRef` in all cases because it makes for slighly less code; in the
/// case when we actually want an `Expression`, we can just use `into()`.
fn rebind(self, base_name: &str) -> (Vec<ir::Statement>, ir::ValueOrRef) {
match self {
// Values just convert in the obvious way, and require no prereqs
syntax::Expression::Value(loc, val) => (vec![], ValueOrRef::Value(loc, val.into())),
// Similarly, references just convert in the obvious way, and require
// no prereqs
syntax::Expression::Reference(loc, name) => {
(vec![], ValueOrRef::Ref(loc, ArcIntern::new(name)))
}
// Primitive expressions are where we do the real work.
syntax::Expression::Primitive(loc, prim, mut expressions) => {
// generate a fresh new name for the binding site we're going to
// introduce, basing the name on wherever we came from; so if this
// expression was bound to `x` originally, it might become `x:23`.
//
// gensym is guaranteed to give us a name that is unused anywhere
// else in the program.
let new_name = gensym(base_name);
let mut prereqs = Vec::new();
let mut new_exprs = Vec::new();
// here we loop through every argument, and recurse on the expressions
// we find. that will give us any new binding sites that *they* introduce,
// and a simple value or reference that we can use in our result.
for expr in expressions.drain(..) {
let (mut cur_prereqs, arg) = expr.rebind(new_name.as_str());
prereqs.append(&mut cur_prereqs);
new_exprs.push(arg);
}
// now we're going to use those new arguments to run the primitive, binding
// the results to the new variable we introduced.
let prim =
ir::Primitive::try_from(prim.as_str()).expect("is valid primitive function");
prereqs.push(ir::Statement::Binding(
loc.clone(),
new_name.clone(),
ir::Expression::Primitive(loc.clone(), prim, new_exprs),
));
// and finally, we can return all the new bindings, and a reference to
// the variable we just introduced to hold the value of the primitive
// invocation.
(prereqs, ValueOrRef::Ref(loc, new_name))
}
}
}
}
impl From<syntax::Value> for ir::Value {
fn from(value: syntax::Value) -> Self {
match value {
syntax::Value::Number(base, val) => ir::Value::Number(base, val),
}
}
}
impl From<String> for ir::Primitive {
fn from(value: String) -> Self {
value.try_into().unwrap()
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input: syntax::Program) {
let syntax_result = input.eval();
let ir = ir::Program::from(input);
let ir_result = ir.eval();
assert_eq!(syntax_result, ir_result);
}
}

View File

@@ -21,12 +21,12 @@ impl Program {
impl Statement {
fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) {
match self {
Statement::Binding(_, name, expr) => {
Statement::Binding(_, name, _, expr) => {
string_set.insert(name.clone());
expr.register_strings(string_set);
}
Statement::Print(_, name) => {
Statement::Print(_, _, name) => {
string_set.insert(name.clone());
}
}

View File

@@ -63,8 +63,11 @@
//!
pub mod backend;
pub mod eval;
#[cfg(test)]
mod examples;
pub mod ir;
pub mod syntax;
pub mod type_infer;
/// Implementation module for the high-level compiler.
mod compiler;

View File

@@ -1,6 +1,6 @@
use crate::backend::{Backend, BackendError};
use crate::ir::Program as IR;
use crate::syntax::{Location, ParserError, Statement};
use crate::syntax::{ConstantType, Location, ParserError, Statement};
use crate::type_infer::TypeInferenceResult;
use codespan_reporting::diagnostic::Diagnostic;
use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::{self, Config};
@@ -129,17 +129,32 @@ impl REPL {
.source();
let syntax = Statement::parse(entry, source)?;
// if this is a variable binding, and we've never defined this variable before,
// we should tell cranelift about it. this is optimistic; if we fail to compile,
// then we won't use this definition until someone tries again.
if let Statement::Binding(_, ref name, _) = syntax {
if !self.variable_binding_sites.contains_key(name.as_str()) {
self.jitter.define_string(name)?;
self.jitter.define_variable(name.clone())?;
let program = match syntax {
Statement::Binding(loc, name, expr) => {
// if this is a variable binding, and we've never defined this variable before,
// we should tell cranelift about it. this is optimistic; if we fail to compile,
// then we won't use this definition until someone tries again.
if !self.variable_binding_sites.contains_key(&name.name) {
self.jitter.define_string(&name.name)?;
self.jitter
.define_variable(name.to_string(), ConstantType::U64)?;
}
crate::syntax::Program {
statements: vec![
Statement::Binding(loc.clone(), name.clone(), expr),
Statement::Print(loc, name),
],
}
}
nonbinding => crate::syntax::Program {
statements: vec![nonbinding],
},
};
let (mut errors, mut warnings) = syntax.validate(&mut self.variable_binding_sites);
let (mut errors, mut warnings) =
program.validate_with_bindings(&mut self.variable_binding_sites);
let stop = !errors.is_empty();
let messages = errors
.drain(..)
@@ -154,13 +169,39 @@ impl REPL {
return Ok(());
}
let ir = IR::from(syntax);
let name = format!("line{}", line_no);
let function_id = self.jitter.compile_function(&name, ir)?;
self.jitter.module.finalize_definitions()?;
let compiled_bytes = self.jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function();
Ok(())
match program.type_infer() {
TypeInferenceResult::Failure {
mut errors,
mut warnings,
} => {
let messages = errors
.drain(..)
.map(Into::into)
.chain(warnings.drain(..).map(Into::into));
for message in messages {
self.emit_diagnostic(message)?;
}
Ok(())
}
TypeInferenceResult::Success {
result,
mut warnings,
} => {
for message in warnings.drain(..).map(Into::into) {
self.emit_diagnostic(message)?;
}
let name = format!("line{}", line_no);
let function_id = self.jitter.compile_function(&name, result)?;
self.jitter.module.finalize_definitions()?;
let compiled_bytes = self.jitter.bytes(function_id);
let compiled_function =
unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function();
Ok(())
}
}
}
}

View File

@@ -27,7 +27,7 @@ use codespan_reporting::{diagnostic::Diagnostic, files::SimpleFiles};
use lalrpop_util::lalrpop_mod;
use logos::Logos;
mod arbitrary;
pub mod arbitrary;
mod ast;
mod eval;
mod location;
@@ -40,6 +40,8 @@ lalrpop_mod!(
mod pretty;
mod validate;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
pub use crate::syntax::ast::*;
pub use crate::syntax::location::Location;
pub use crate::syntax::parser::{ProgramParser, StatementParser};
@@ -48,7 +50,7 @@ pub use crate::syntax::tokens::{LexerError, Token};
use ::pretty::{Arena, Pretty};
use lalrpop_util::ParseError;
#[cfg(test)]
use proptest::{prop_assert, prop_assert_eq};
use proptest::{arbitrary::Arbitrary, prop_assert, prop_assert_eq};
#[cfg(test)]
use std::str::FromStr;
use thiserror::Error;
@@ -73,12 +75,12 @@ pub enum ParserError {
/// Raised when we're parsing the file, and run into a token in a
/// place we weren't expecting it.
#[error("Unrecognized token")]
UnrecognizedToken(Location, Location, Token, Vec<String>),
UnrecognizedToken(Location, Token, Vec<String>),
/// Raised when we were expecting the end of the file, but instead
/// got another token.
#[error("Extra token")]
ExtraToken(Location, Token, Location),
ExtraToken(Location, Token),
/// Raised when the lexer just had some sort of internal problem
/// and just gave up.
@@ -106,30 +108,24 @@ impl ParserError {
fn convert(file_idx: usize, err: ParseError<usize, Token, LexerError>) -> Self {
match err {
ParseError::InvalidToken { location } => {
ParserError::InvalidToken(Location::new(file_idx, location))
}
ParseError::UnrecognizedEof { location, expected } => {
ParserError::UnrecognizedEOF(Location::new(file_idx, location), expected)
ParserError::InvalidToken(Location::new(file_idx, location..location + 1))
}
ParseError::UnrecognizedEof { location, expected } => ParserError::UnrecognizedEOF(
Location::new(file_idx, location..location + 1),
expected,
),
ParseError::UnrecognizedToken {
token: (start, token, end),
expected,
} => ParserError::UnrecognizedToken(
Location::new(file_idx, start),
Location::new(file_idx, end),
token,
expected,
),
} => {
ParserError::UnrecognizedToken(Location::new(file_idx, start..end), token, expected)
}
ParseError::ExtraToken {
token: (start, token, end),
} => ParserError::ExtraToken(
Location::new(file_idx, start),
token,
Location::new(file_idx, end),
),
} => ParserError::ExtraToken(Location::new(file_idx, start..end), token),
ParseError::User { error } => match error {
LexerError::LexFailure(offset) => {
ParserError::LexFailure(Location::new(file_idx, offset))
ParserError::LexFailure(Location::new(file_idx, offset..offset + 1))
}
},
}
@@ -180,37 +176,25 @@ impl<'a> From<&'a ParserError> for Diagnostic<usize> {
),
// encountered a token where it shouldn't be
ParserError::UnrecognizedToken(start, end, token, expected) => {
ParserError::UnrecognizedToken(loc, token, expected) => {
let expected_str =
format!("unexpected token {}{}", token, display_expected(expected));
let unexpected_str = format!("unexpected token {}", token);
let labels = start.range_label(end);
Diagnostic::error()
.with_labels(
labels
.into_iter()
.map(|l| l.with_message(unexpected_str.clone()))
.collect(),
)
.with_message(expected_str)
.with_labels(vec![loc.primary_label().with_message(unexpected_str)])
}
// I think we get this when we get a token, but were expected EOF
ParserError::ExtraToken(start, token, end) => {
ParserError::ExtraToken(loc, token) => {
let expected_str =
format!("unexpected token {} after the expected end of file", token);
let unexpected_str = format!("unexpected token {}", token);
let labels = start.range_label(end);
Diagnostic::error()
.with_labels(
labels
.into_iter()
.map(|l| l.with_message(unexpected_str.clone()))
.collect(),
)
.with_message(expected_str)
.with_labels(vec![loc.primary_label().with_message(unexpected_str)])
}
// simple lexer errors
@@ -293,24 +277,27 @@ fn order_of_operations() {
Program::from_str(muladd1).unwrap(),
Program {
statements: vec![Statement::Binding(
Location::new(testfile, 0),
"x".to_string(),
Location::new(testfile, 0..1),
Name::manufactured("x"),
Expression::Primitive(
Location::new(testfile, 6),
Location::new(testfile, 6..7),
"+".to_string(),
vec![
Expression::Value(Location::new(testfile, 4), Value::Number(None, 1)),
Expression::Value(
Location::new(testfile, 4..5),
Value::Number(None, None, 1),
),
Expression::Primitive(
Location::new(testfile, 10),
Location::new(testfile, 10..11),
"*".to_string(),
vec![
Expression::Value(
Location::new(testfile, 8),
Value::Number(None, 2),
Location::new(testfile, 8..9),
Value::Number(None, None, 2),
),
Expression::Value(
Location::new(testfile, 12),
Value::Number(None, 3),
Location::new(testfile, 12..13),
Value::Number(None, None, 3),
),
]
)
@@ -350,8 +337,8 @@ proptest::proptest! {
}
#[test]
fn generated_run_or_overflow(program: Program) {
fn generated_run_or_overflow(program in Program::arbitrary_with(GenerationEnvironment::new(false))) {
use crate::eval::{EvalError, PrimOpError};
assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))))
prop_assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))));
}
}

View File

@@ -1,136 +1,189 @@
use std::collections::HashSet;
use crate::syntax::ast::{Expression, Program, Statement, Value};
use crate::syntax::ast::{ConstantType, Expression, Name, Program, Statement, Value};
use crate::syntax::location::Location;
use proptest::sample::select;
use proptest::{
prelude::{Arbitrary, BoxedStrategy, Strategy},
strategy::{Just, Union},
};
use std::collections::HashMap;
use std::ops::Range;
const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*";
#[derive(Debug)]
struct Name(String);
impl ConstantType {
fn get_operators(&self) -> &'static [(&'static str, usize)] {
match self {
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 => {
&[("+", 2), ("-", 1), ("-", 2), ("*", 2), ("/", 2)]
}
ConstantType::U8 | ConstantType::U16 | ConstantType::U32 | ConstantType::U64 => {
&[("+", 2), ("-", 2), ("*", 2), ("/", 2)]
}
}
}
}
#[derive(Clone)]
pub struct GenerationEnvironment {
allow_inference: bool,
block_length: Range<usize>,
bindings: HashMap<Name, ConstantType>,
return_type: ConstantType,
}
impl Default for GenerationEnvironment {
fn default() -> Self {
GenerationEnvironment {
allow_inference: true,
block_length: 2..10,
bindings: HashMap::new(),
return_type: ConstantType::U64,
}
}
}
impl GenerationEnvironment {
pub fn new(allow_inference: bool) -> Self {
GenerationEnvironment {
allow_inference,
..Default::default()
}
}
}
impl Arbitrary for Program {
type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
proptest::collection::vec(
ProgramStatementInfo::arbitrary(),
genenv.block_length.clone(),
)
.prop_flat_map(move |mut items| {
let mut statements = Vec::new();
let mut genenv = genenv.clone();
for psi in items.drain(..) {
if genenv.bindings.is_empty() || psi.should_be_binding {
genenv.return_type = psi.binding_type;
let expr = Expression::arbitrary_with(genenv.clone());
genenv.bindings.insert(psi.name.clone(), psi.binding_type);
statements.push(
expr.prop_map(move |expr| {
Statement::Binding(Location::manufactured(), psi.name.clone(), expr)
})
.boxed(),
);
} else {
let printers = genenv.bindings.keys().map(|n| {
Just(Statement::Print(
Location::manufactured(),
Name::manufactured(n),
))
});
statements.push(Union::new(printers).boxed());
}
}
statements
.prop_map(|statements| Program { statements })
.boxed()
})
.boxed()
}
}
impl Arbitrary for Name {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
VALID_VARIABLE_NAMES.prop_map(Name).boxed()
VALID_VARIABLE_NAMES.prop_map(Name::manufactured).boxed()
}
}
impl Arbitrary for Program {
#[derive(Debug)]
struct ProgramStatementInfo {
should_be_binding: bool,
name: Name,
binding_type: ConstantType,
}
impl Arbitrary for ProgramStatementInfo {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
let optionals = Vec::<Option<Name>>::arbitrary();
optionals
.prop_flat_map(|mut possible_names| {
let mut statements = Vec::new();
let mut defined_variables: HashSet<String> = HashSet::new();
for possible_name in possible_names.drain(..) {
match possible_name {
None if defined_variables.is_empty() => continue,
None => statements.push(
Union::new(defined_variables.iter().map(|name| {
Just(Statement::Print(Location::manufactured(), name.to_string()))
}))
.boxed(),
),
Some(new_name) => {
let closures_name = new_name.0.clone();
let retval =
Expression::arbitrary_with(Some(defined_variables.clone()))
.prop_map(move |exp| {
Statement::Binding(
Location::manufactured(),
closures_name.clone(),
exp,
)
})
.boxed();
defined_variables.insert(new_name.0);
statements.push(retval);
}
}
}
statements
})
.prop_map(|statements| Program { statements })
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
(
Union::new(vec![Just(true), Just(true), Just(false)]),
Name::arbitrary(),
ConstantType::arbitrary(),
)
.prop_map(
|(should_be_binding, name, binding_type)| ProgramStatementInfo {
should_be_binding,
name,
binding_type,
},
)
.boxed()
}
}
impl Arbitrary for Statement {
type Parameters = Option<HashSet<String>>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let duplicated_args = args.clone();
let defined_variables = args.unwrap_or_default();
let binding_strategy = (
VALID_VARIABLE_NAMES,
Expression::arbitrary_with(duplicated_args),
)
.prop_map(|(name, exp)| Statement::Binding(Location::manufactured(), name, exp))
.boxed();
if defined_variables.is_empty() {
binding_strategy
} else {
let print_strategy = Union::new(
defined_variables
.iter()
.map(|x| Just(Statement::Print(Location::manufactured(), x.to_string()))),
)
.boxed();
Union::new([binding_strategy, print_strategy]).boxed()
}
}
}
impl Arbitrary for Expression {
type Parameters = Option<HashSet<String>>;
type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let defined_variables = args.unwrap_or_default();
let value_strategy = Value::arbitrary()
.prop_map(move |x| Expression::Value(Location::manufactured(), x))
fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
// Value(Location, Value). These are the easiest variations to create, because we can always
// create one.
let value_strategy = Value::arbitrary_with(genenv.clone())
.prop_map(|x| Expression::Value(Location::manufactured(), x))
.boxed();
let leaf_strategy = if defined_variables.is_empty() {
// Reference(Location, String), These are slightly trickier, because we can end up in a situation
// where either no variables are defined, or where none of the defined variables have a type we
// can work with. So what we're going to do is combine this one with the previous one as a "leaf
// strategy" -- our non-recursive items -- if we can, or just set that to be the value strategy
// if we can't actually create an references.
let mut bound_variables_of_type = genenv
.bindings
.iter()
.filter(|(_, v)| genenv.return_type == **v)
.map(|(n, _)| n)
.collect::<Vec<_>>();
let leaf_strategy = if bound_variables_of_type.is_empty() {
value_strategy
} else {
let reference_strategy = Union::new(defined_variables.iter().map(|x| {
Just(Expression::Reference(
Location::manufactured(),
x.to_owned(),
))
}))
.boxed();
Union::new([value_strategy, reference_strategy]).boxed()
let mut strats = bound_variables_of_type
.drain(..)
.map(|x| {
Just(Expression::Reference(
Location::manufactured(),
x.name.clone(),
))
.boxed()
})
.collect::<Vec<_>>();
strats.push(value_strategy);
Union::new(strats).boxed()
};
// now we generate our recursive types, given our leaf strategy
leaf_strategy
.prop_recursive(3, 64, 2, move |inner| {
.prop_recursive(3, 10, 2, move |strat| {
(
select(super::BINARY_OPERATORS),
proptest::collection::vec(inner, 2),
select(genenv.return_type.get_operators()),
strat.clone(),
strat,
)
.prop_map(move |(operator, exprs)| {
Expression::Primitive(Location::manufactured(), operator.to_string(), exprs)
.prop_map(|((oper, count), left, right)| {
let mut args = vec![left, right];
while args.len() > count {
args.pop();
}
Expression::Primitive(Location::manufactured(), oper.to_string(), args)
})
})
.boxed()
@@ -138,22 +191,57 @@ impl Arbitrary for Expression {
}
impl Arbitrary for Value {
type Parameters = ();
type Parameters = GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
let base_strategy = Union::new([
fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy {
let printed_base_strategy = Union::new([
Just(None::<u8>),
Just(Some(2)),
Just(Some(8)),
Just(Some(10)),
Just(Some(16)),
]);
let value_strategy = u64::arbitrary();
let value_strategy = i64::arbitrary();
(base_strategy, value_strategy)
.prop_map(move |(base, value)| Value::Number(base, value))
(printed_base_strategy, bool::arbitrary(), value_strategy)
.prop_map(move |(base, declare_type, value)| {
let converted_value = match genenv.return_type {
ConstantType::I8 => value % (i8::MAX as u64),
ConstantType::U8 => value % (u8::MAX as u64),
ConstantType::I16 => value % (i16::MAX as u64),
ConstantType::U16 => value % (u16::MAX as u64),
ConstantType::I32 => value % (i32::MAX as u64),
ConstantType::U32 => value % (u32::MAX as u64),
ConstantType::I64 => value % (i64::MAX as u64),
ConstantType::U64 => value,
};
let ty = if declare_type || !genenv.allow_inference {
Some(genenv.return_type)
} else {
None
};
Value::Number(base, ty, converted_value)
})
.boxed()
}
}
impl Arbitrary for ConstantType {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
Union::new([
Just(ConstantType::I8),
Just(ConstantType::I16),
Just(ConstantType::I32),
Just(ConstantType::I64),
Just(ConstantType::U8),
Just(ConstantType::U16),
Just(ConstantType::U32),
Just(ConstantType::U64),
])
.boxed()
}
}

View File

@@ -1,7 +1,10 @@
use crate::syntax::Location;
use std::fmt;
use std::hash::Hash;
/// The set of valid binary operators.
pub static BINARY_OPERATORS: &[&str] = &["+", "-", "*", "/"];
use internment::ArcIntern;
pub use crate::syntax::tokens::ConstantType;
use crate::syntax::Location;
/// A structure represented a parsed program.
///
@@ -16,6 +19,56 @@ pub struct Program {
pub statements: Vec<Statement>,
}
/// A Name.
///
/// This is basically a string, but annotated with the place the string
/// is in the source file.
#[derive(Clone, Debug)]
pub struct Name {
pub name: String,
pub location: Location,
}
impl Name {
pub fn new<S: ToString>(n: S, location: Location) -> Name {
Name {
name: n.to_string(),
location,
}
}
pub fn manufactured<S: ToString>(n: S) -> Name {
Name {
name: n.to_string(),
location: Location::manufactured(),
}
}
pub fn intern(self) -> ArcIntern<String> {
ArcIntern::new(self.name)
}
}
impl PartialEq for Name {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
impl Eq for Name {}
impl Hash for Name {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state)
}
}
impl fmt::Display for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.name.fmt(f)
}
}
/// A parsed statement.
///
/// Statements are guaranteed to be syntactically valid, but may be
@@ -29,8 +82,8 @@ pub struct Program {
/// thing, not if they are the exact same statement.
#[derive(Clone, Debug)]
pub enum Statement {
Binding(Location, String, Expression),
Print(Location, String),
Binding(Location, Name, Expression),
Print(Location, Name),
}
impl PartialEq for Statement {
@@ -58,6 +111,7 @@ impl PartialEq for Statement {
pub enum Expression {
Value(Location, Value),
Reference(Location, String),
Cast(Location, String, Box<Expression>),
Primitive(Location, String, Vec<Expression>),
}
@@ -72,6 +126,10 @@ impl PartialEq for Expression {
Expression::Reference(_, var2) => var1 == var2,
_ => false,
},
Expression::Cast(_, t1, e1) => match other {
Expression::Cast(_, t2, e2) => t1 == t2 && e1 == e2,
_ => false,
},
Expression::Primitive(_, prim1, args1) => match other {
Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2,
_ => false,
@@ -83,6 +141,12 @@ impl PartialEq for Expression {
/// A value from the source syntax
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Value {
/// The value of the number, and an optional base that it was written in
Number(Option<u8>, i64),
/// The value of the number, an optional base that it was written in, and any
/// type information provided.
///
/// u64 is chosen because it should be big enough to carry the amount of
/// information we need, and technically we interpret -4 as the primitive unary
/// operation "-" on the number 4. We'll translate this into a type-specific
/// number at a later time.
Number(Option<u8>, Option<ConstantType>, u64),
}

View File

@@ -1,7 +1,8 @@
use internment::ArcIntern;
use crate::eval::{EvalEnvironment, EvalError, Value};
use crate::syntax::{Expression, Program, Statement};
use crate::eval::{EvalEnvironment, EvalError, PrimitiveType, Value};
use crate::syntax::{ConstantType, Expression, Program, Statement};
use std::str::FromStr;
impl Program {
/// Evaluate the program, returning either an error or what it prints out when run.
@@ -24,11 +25,11 @@ impl Program {
match stmt {
Statement::Binding(_, name, value) => {
let actual_value = value.eval(&env)?;
env = env.extend(ArcIntern::new(name.clone()), actual_value);
env = env.extend(name.clone().intern(), actual_value);
}
Statement::Print(_, name) => {
let value = env.lookup(ArcIntern::new(name.clone()))?;
let value = env.lookup(name.clone().intern())?;
let line = format!("{} = {}\n", name, value);
stdout.push_str(&line);
}
@@ -43,11 +44,28 @@ impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
Expression::Value(_, v) => match v {
super::Value::Number(_, v) => Ok(Value::I64(*v)),
super::Value::Number(_, ty, v) => match ty {
None => Ok(Value::U64(*v)),
// FIXME: make these types validate their input size
Some(ConstantType::I8) => Ok(Value::I8(*v as i8)),
Some(ConstantType::I16) => Ok(Value::I16(*v as i16)),
Some(ConstantType::I32) => Ok(Value::I32(*v as i32)),
Some(ConstantType::I64) => Ok(Value::I64(*v as i64)),
Some(ConstantType::U8) => Ok(Value::U8(*v as u8)),
Some(ConstantType::U16) => Ok(Value::U16(*v as u16)),
Some(ConstantType::U32) => Ok(Value::U32(*v as u32)),
Some(ConstantType::U64) => Ok(Value::U64(*v)),
},
},
Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?),
Expression::Cast(_, target, expr) => {
let target_type = PrimitiveType::from_str(target)?;
let value = expr.eval(env)?;
Ok(target_type.safe_cast(&value)?)
}
Expression::Primitive(_, op, args) => {
let mut arg_values = Vec::with_capacity(args.len());
@@ -66,12 +84,12 @@ impl Expression {
fn two_plus_three() {
let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let output = input.eval().expect("runs successfully");
assert_eq!("x = 5i64\n", &output);
assert_eq!("x = 5u64\n", &output);
}
#[test]
fn lotsa_math() {
let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let output = input.eval().expect("runs successfully");
assert_eq!("x = 7i64\n", &output);
assert_eq!("x = 7u64\n", &output);
}

View File

@@ -1,13 +1,15 @@
use std::ops::Range;
use codespan_reporting::diagnostic::{Diagnostic, Label};
/// A source location, for use in pointing users towards warnings and errors.
///
/// Internally, locations are very tied to the `codespan_reporting` library,
/// and the primary use of them is to serve as anchors within that library.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Location {
file_idx: usize,
offset: usize,
location: Range<usize>,
}
impl Location {
@@ -17,8 +19,8 @@ impl Location {
/// The file index is based on the file database being used. See the
/// `codespan_reporting::files::SimpleFiles::add` function, which is
/// normally where we get this index.
pub fn new(file_idx: usize, offset: usize) -> Self {
Location { file_idx, offset }
pub fn new(file_idx: usize, location: Range<usize>) -> Self {
Location { file_idx, location }
}
/// Generate a `Location` for a completely manufactured bit of code.
@@ -30,7 +32,7 @@ impl Location {
pub fn manufactured() -> Self {
Location {
file_idx: 0,
offset: 0,
location: 0..0,
}
}
@@ -47,7 +49,7 @@ impl Location {
/// actually happened), but you'd probably want to make the first location
/// the secondary label to help users find it.
pub fn primary_label(&self) -> Label<usize> {
Label::primary(self.file_idx, self.offset..self.offset)
Label::primary(self.file_idx, self.location.clone())
}
/// Generate a secondary label for a [`Diagnostic`], based on this source
@@ -64,35 +66,7 @@ impl Location {
/// probably want to make the first location the secondary label to help
/// users find it.
pub fn secondary_label(&self) -> Label<usize> {
Label::secondary(self.file_idx, self.offset..self.offset)
}
/// Given this location and another, generate a primary label that
/// specifies the area between those two locations.
///
/// See [`Self::primary_label`] for some discussion of primary versus
/// secondary labels. If the two locations are the same, this method does
/// the exact same thing as [`Self::primary_label`]. If this item was
/// generated by [`Self::manufactured`], it will act as if you'd called
/// `primary_label` on the argument. Otherwise, it will generate the obvious
/// span.
///
/// This function will return `None` only in the case that you provide
/// labels from two different files, which it cannot sensibly handle.
pub fn range_label(&self, end: &Location) -> Option<Label<usize>> {
if self.file_idx == 0 {
return Some(end.primary_label());
}
if self.file_idx != end.file_idx {
return None;
}
if self.offset > end.offset {
Some(Label::primary(self.file_idx, end.offset..self.offset))
} else {
Some(Label::primary(self.file_idx, self.offset..end.offset))
}
Label::secondary(self.file_idx, self.location.clone())
}
/// Return an error diagnostic centered at this location.
@@ -102,10 +76,7 @@ impl Location {
/// this particular location. You'll need to extend it with actually useful
/// information, like what kind of error it is.
pub fn error(&self) -> Diagnostic<usize> {
Diagnostic::error().with_labels(vec![Label::primary(
self.file_idx,
self.offset..self.offset,
)])
Diagnostic::error().with_labels(vec![Label::primary(self.file_idx, self.location.clone())])
}
/// Return an error diagnostic centered at this location, with the given message.
@@ -115,10 +86,34 @@ impl Location {
/// even more information to ut, using [`Diagnostic::with_labels`],
/// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`].
pub fn labelled_error(&self, msg: &str) -> Diagnostic<usize> {
Diagnostic::error().with_labels(vec![Label::primary(
self.file_idx,
self.offset..self.offset,
)
.with_message(msg)])
Diagnostic::error().with_labels(vec![
Label::primary(self.file_idx, self.location.clone()).with_message(msg)
])
}
/// Merge two locations into a single location spanning the whole range between
/// them.
///
/// This function returns None if the locations are from different files; this
/// can happen if one of the locations is manufactured, for example.
pub fn merge(&self, other: &Self) -> Option<Self> {
if self.file_idx != other.file_idx {
None
} else {
let start = if self.location.start <= other.location.start {
self.location.start
} else {
other.location.start
};
let end = if self.location.end >= other.location.end {
self.location.end
} else {
other.location.end
};
Some(Location {
file_idx: self.file_idx,
location: start..end,
})
}
}
}

View File

@@ -9,8 +9,8 @@
//! eventually want to leave lalrpop behind.)
//!
use crate::syntax::{LexerError, Location};
use crate::syntax::ast::{Program,Statement,Expression,Value};
use crate::syntax::tokens::Token;
use crate::syntax::ast::{Program,Statement,Expression,Value,Name};
use crate::syntax::tokens::{ConstantType, Token};
use internment::ArcIntern;
// one cool thing about lalrpop: we can pass arguments. in this case, the
@@ -32,6 +32,8 @@ extern {
";" => Token::Semi,
"(" => Token::LeftParen,
")" => Token::RightParen,
"<" => Token::LessThan,
">" => Token::GreaterThan,
"print" => Token::Print,
@@ -44,7 +46,7 @@ extern {
// to name and use "their value", you get their source location.
// For these, we want "their value" to be their actual contents,
// which is why we put their types in angle brackets.
"<num>" => Token::Number((<Option<u8>>,<i64>)),
"<num>" => Token::Number((<Option<u8>>,<Option<ConstantType>>,<u64>)),
"<var>" => Token::Variable(<ArcIntern<String>>),
}
}
@@ -89,10 +91,19 @@ pub Statement: Statement = {
// A statement can be a variable binding. Note, here, that we use this
// funny @L thing to get the source location before the variable, so that
// we can say that this statement spans across everything.
<l:@L> <v:"<var>"> "=" <e:Expression> ";" => Statement::Binding(Location::new(file_idx, l), v.to_string(), e),
<ls: @L> <v:"<var>"> <var_end: @L> "=" <e:Expression> ";" <le: @L> =>
Statement::Binding(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, ls..var_end)),
e,
),
// Alternatively, a statement can just be a print statement.
"print" <l:@L> <v:"<var>"> ";" => Statement::Print(Location::new(file_idx, l), v.to_string()),
<ls: @L> "print" <name_start: @L> <v:"<var>"> <name_end: @L> ";" <le: @L> =>
Statement::Print(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, name_start..name_end)),
),
}
// Expressions! Expressions are a little fiddly, because we're going to
@@ -124,15 +135,27 @@ Expression: Expression = {
// we group addition and subtraction under the heading "additive"
AdditiveExpression: Expression = {
<e1:AdditiveExpression> <l:@L> "+" <e2:MultiplicativeExpression> => Expression::Primitive(Location::new(file_idx, l), "+".to_string(), vec![e1, e2]),
<e1:AdditiveExpression> <l:@L> "-" <e2:MultiplicativeExpression> => Expression::Primitive(Location::new(file_idx, l), "-".to_string(), vec![e1, e2]),
<ls: @L> <e1:AdditiveExpression> <l: @L> "+" <e2:MultiplicativeExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "+".to_string(), vec![e1, e2]),
<ls: @L> <e1:AdditiveExpression> <l: @L> "-" <e2:MultiplicativeExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "-".to_string(), vec![e1, e2]),
MultiplicativeExpression,
}
// similarly, we group multiplication and division under "multiplicative"
MultiplicativeExpression: Expression = {
<e1:MultiplicativeExpression> <l:@L> "*" <e2:AtomicExpression> => Expression::Primitive(Location::new(file_idx, l), "*".to_string(), vec![e1, e2]),
<e1:MultiplicativeExpression> <l:@L> "/" <e2:AtomicExpression> => Expression::Primitive(Location::new(file_idx, l), "/".to_string(), vec![e1, e2]),
<ls: @L> <e1:MultiplicativeExpression> <l: @L> "*" <e2:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "*".to_string(), vec![e1, e2]),
<ls: @L> <e1:MultiplicativeExpression> <l: @L> "/" <e2:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "/".to_string(), vec![e1, e2]),
UnaryExpression,
}
UnaryExpression: Expression = {
<l: @L> "-" <e:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, l..le), "-".to_string(), vec![e]),
<l: @L> "<" <v:"<var>"> ">" <e:UnaryExpression> <le: @L> =>
Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)),
AtomicExpression,
}
@@ -140,22 +163,9 @@ MultiplicativeExpression: Expression = {
// they cannot be further divided into parts
AtomicExpression: Expression = {
// just a variable reference
<l:@L> <v:"<var>"> => Expression::Reference(Location::new(file_idx, l), v.to_string()),
<l: @L> <v:"<var>"> <end: @L> => Expression::Reference(Location::new(file_idx, l..end), v.to_string()),
// just a number
<l:@L> <n:"<num>"> => {
let val = Value::Number(n.0, n.1);
Expression::Value(Location::new(file_idx, l), val)
},
// a tricky case: also just a number, but using a negative sign. an
// alternative way to do this -- and we may do this eventually -- is
// to implement a unary negation expression. this has the odd effect
// that the user never actually writes down a negative number; they just
// write positive numbers which are immediately sent to a negation
// primitive!
<l:@L> "-" <n:"<num>"> => {
let val = Value::Number(n.0, -n.1);
Expression::Value(Location::new(file_idx, l), val)
},
<l: @L> <n:"<num>"> <end: @L> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)),
// finally, let people parenthesize expressions and get back to a
// lower precedence
"(" <e:Expression> ")" => e,

View File

@@ -1,6 +1,8 @@
use crate::syntax::ast::{Expression, Program, Statement, Value, BINARY_OPERATORS};
use crate::syntax::ast::{Expression, Program, Statement, Value};
use pretty::{DocAllocator, DocBuilder, Pretty};
use super::ConstantType;
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
A: 'a,
@@ -50,14 +52,14 @@ where
match self {
Expression::Value(_, val) => val.pretty(allocator),
Expression::Reference(_, var) => allocator.text(var.to_string()),
Expression::Primitive(_, op, exprs) if BINARY_OPERATORS.contains(&op.as_ref()) => {
assert_eq!(
exprs.len(),
2,
"Found binary operator with {} components?",
exprs.len()
);
Expression::Cast(_, t, e) => allocator
.text(t.clone())
.angles()
.append(e.pretty(allocator)),
Expression::Primitive(_, op, exprs) if exprs.len() == 1 => allocator
.text(op.to_string())
.append(exprs[0].pretty(allocator)),
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
@@ -84,15 +86,14 @@ where
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
match self {
Value::Number(opt_base, value) => {
let sign = if *value < 0 { "-" } else { "" };
Value::Number(opt_base, ty, value) => {
let value_str = match opt_base {
None => format!("{}", value),
Some(2) => format!("{}0b{:b}", sign, value.abs()),
Some(8) => format!("{}0o{:o}", sign, value.abs()),
Some(10) => format!("{}0d{}", sign, value.abs()),
Some(16) => format!("{}0x{:x}", sign, value.abs()),
Some(_) => format!("!!{}{:x}!!", sign, value.abs()),
None => format!("{}{}", value, type_suffix(ty)),
Some(2) => format!("0b{:b}{}", value, type_suffix(ty)),
Some(8) => format!("0o{:o}{}", value, type_suffix(ty)),
Some(10) => format!("0d{}{}", value, type_suffix(ty)),
Some(16) => format!("0x{:x}{}", value, type_suffix(ty)),
Some(_) => format!("!!{:x}{}!!", value, type_suffix(ty)),
};
allocator.text(value_str)
@@ -101,6 +102,20 @@ where
}
}
fn type_suffix(x: &Option<ConstantType>) -> &'static str {
match x {
None => "",
Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32",
Some(ConstantType::I64) => "i64",
Some(ConstantType::U8) => "u8",
Some(ConstantType::U16) => "u16",
Some(ConstantType::U32) => "u32",
Some(ConstantType::U64) => "u64",
}
}
#[derive(Clone, Copy)]
struct CommaSep {}

View File

@@ -40,6 +40,12 @@ pub enum Token {
#[token(")")]
RightParen,
#[token("<")]
LessThan,
#[token(">")]
GreaterThan,
// Next we take of any reserved words; I always like to put
// these before we start recognizing more complicated regular
// expressions. I don't think it matters, but it works for me.
@@ -53,13 +59,14 @@ pub enum Token {
/// Numbers capture both the value we read from the input,
/// converted to an `i64`, as well as the base the user used
/// to write the number, if they did so.
#[regex(r"0b[01]+", |v| parse_number(Some(2), v))]
#[regex(r"0o[0-7]+", |v| parse_number(Some(8), v))]
#[regex(r"0d[0-9]+", |v| parse_number(Some(10), v))]
#[regex(r"0x[0-9a-fA-F]+", |v| parse_number(Some(16), v))]
#[regex(r"[0-9]+", |v| parse_number(None, v))]
Number((Option<u8>, i64)),
/// to write the number and/or the type the user specified,
/// if they did either.
#[regex(r"0b[01]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(2), v))]
#[regex(r"0o[0-7]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(8), v))]
#[regex(r"0d[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(10), v))]
#[regex(r"0x[0-9a-fA-F]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(16), v))]
#[regex(r"[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(None, v))]
Number((Option<u8>, Option<ConstantType>, u64)),
// Variables; this is a very standard, simple set of characters
// for variables, but feel free to experiment with more complicated
@@ -88,15 +95,29 @@ impl fmt::Display for Token {
Token::Semi => write!(f, "';'"),
Token::LeftParen => write!(f, "'('"),
Token::RightParen => write!(f, "')'"),
Token::LessThan => write!(f, "<"),
Token::GreaterThan => write!(f, ">"),
Token::Print => write!(f, "'print'"),
Token::Operator(c) => write!(f, "'{}'", c),
Token::Number((None, v)) => write!(f, "'{}'", v),
Token::Number((Some(2), v)) => write!(f, "'0b{:b}'", v),
Token::Number((Some(8), v)) => write!(f, "'0o{:o}'", v),
Token::Number((Some(10), v)) => write!(f, "'{}'", v),
Token::Number((Some(16), v)) => write!(f, "'0x{:x}'", v),
Token::Number((Some(b), v)) => {
write!(f, "Invalidly-based-number<base={},val={}>", b, v)
Token::Number((None, otype, v)) => write!(f, "'{}{}'", v, display_optional_type(otype)),
Token::Number((Some(2), otype, v)) => {
write!(f, "'0b{:b}{}'", v, display_optional_type(otype))
}
Token::Number((Some(8), otype, v)) => {
write!(f, "'0o{:o}{}'", v, display_optional_type(otype))
}
Token::Number((Some(10), otype, v)) => {
write!(f, "'{}{}'", v, display_optional_type(otype))
}
Token::Number((Some(16), otype, v)) => {
write!(f, "'0x{:x}{}'", v, display_optional_type(otype))
}
Token::Number((Some(b), opt_type, v)) => {
write!(
f,
"Invalidly-based-number<base={},val={},opt_type={:?}>",
b, v, opt_type
)
}
Token::Variable(s) => write!(f, "'{}'", s),
Token::Error => write!(f, "<error>"),
@@ -122,6 +143,125 @@ impl Token {
}
}
#[repr(i64)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantType {
U8 = 10,
U16 = 11,
U32 = 12,
U64 = 13,
I8 = 20,
I16 = 21,
I32 = 22,
I64 = 23,
}
impl From<ConstantType> for cranelift_codegen::ir::Type {
fn from(value: ConstantType) -> Self {
match value {
ConstantType::I8 | ConstantType::U8 => cranelift_codegen::ir::types::I8,
ConstantType::I16 | ConstantType::U16 => cranelift_codegen::ir::types::I16,
ConstantType::I32 | ConstantType::U32 => cranelift_codegen::ir::types::I32,
ConstantType::I64 | ConstantType::U64 => cranelift_codegen::ir::types::I64,
}
}
}
impl ConstantType {
/// Returns true if the given type is (a) numeric and (b) signed;
pub fn is_signed(&self) -> bool {
matches!(
self,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64
)
}
/// Return the set of types that can be safely casted into this type.
pub fn safe_casts_to(self) -> Vec<ConstantType> {
match self {
ConstantType::I8 => vec![ConstantType::I8],
ConstantType::I16 => vec![ConstantType::I16, ConstantType::I8, ConstantType::U8],
ConstantType::I32 => vec![
ConstantType::I32,
ConstantType::I16,
ConstantType::I8,
ConstantType::U16,
ConstantType::U8,
],
ConstantType::I64 => vec![
ConstantType::I64,
ConstantType::I32,
ConstantType::I16,
ConstantType::I8,
ConstantType::U32,
ConstantType::U16,
ConstantType::U8,
],
ConstantType::U8 => vec![ConstantType::U8],
ConstantType::U16 => vec![ConstantType::U16, ConstantType::U8],
ConstantType::U32 => vec![ConstantType::U32, ConstantType::U16, ConstantType::U8],
ConstantType::U64 => vec![
ConstantType::U64,
ConstantType::U32,
ConstantType::U16,
ConstantType::U8,
],
}
}
/// Return the set of all currently-available constant types
pub fn all_types() -> Vec<Self> {
vec![
ConstantType::U8,
ConstantType::U16,
ConstantType::U32,
ConstantType::U64,
ConstantType::I8,
ConstantType::I16,
ConstantType::I32,
ConstantType::I64,
]
}
/// Return the name of the given type, as a string
pub fn name(&self) -> String {
match self {
ConstantType::I8 => "i8".to_string(),
ConstantType::I16 => "i16".to_string(),
ConstantType::I32 => "i32".to_string(),
ConstantType::I64 => "i64".to_string(),
ConstantType::U8 => "u8".to_string(),
ConstantType::U16 => "u16".to_string(),
ConstantType::U32 => "u32".to_string(),
ConstantType::U64 => "u64".to_string(),
}
}
}
#[derive(Debug, Error, PartialEq)]
pub enum InvalidConstantType {
#[error("Unrecognized constant {0} for constant type")]
Value(i64),
}
impl TryFrom<i64> for ConstantType {
type Error = InvalidConstantType;
fn try_from(value: i64) -> Result<Self, Self::Error> {
match value {
10 => Ok(ConstantType::U8),
11 => Ok(ConstantType::U16),
12 => Ok(ConstantType::U32),
13 => Ok(ConstantType::U64),
20 => Ok(ConstantType::I8),
21 => Ok(ConstantType::I16),
22 => Ok(ConstantType::I32),
23 => Ok(ConstantType::I64),
_ => Err(InvalidConstantType::Value(value)),
}
}
}
/// Parse a number in the given base, return a pair of the base and the
/// parsed number. This is just a helper used for all of the number
/// regular expression cases, which kicks off to the obvious Rust
@@ -129,24 +269,66 @@ impl Token {
fn parse_number(
base: Option<u8>,
value: &Lexer<Token>,
) -> Result<(Option<u8>, i64), ParseIntError> {
) -> Result<(Option<u8>, Option<ConstantType>, u64), ParseIntError> {
let (radix, strval) = match base {
None => (10, value.slice()),
Some(radix) => (radix, &value.slice()[2..]),
};
let intval = i64::from_str_radix(strval, radix as u32)?;
Ok((base, intval))
let (declared_type, strval) = if let Some(strval) = strval.strip_suffix("u8") {
(Some(ConstantType::U8), strval)
} else if let Some(strval) = strval.strip_suffix("u16") {
(Some(ConstantType::U16), strval)
} else if let Some(strval) = strval.strip_suffix("u32") {
(Some(ConstantType::U32), strval)
} else if let Some(strval) = strval.strip_suffix("u64") {
(Some(ConstantType::U64), strval)
} else if let Some(strval) = strval.strip_suffix("i8") {
(Some(ConstantType::I8), strval)
} else if let Some(strval) = strval.strip_suffix("i16") {
(Some(ConstantType::I16), strval)
} else if let Some(strval) = strval.strip_suffix("i32") {
(Some(ConstantType::I32), strval)
} else if let Some(strval) = strval.strip_suffix("i64") {
(Some(ConstantType::I64), strval)
} else {
(None, strval)
};
let intval = u64::from_str_radix(strval, radix as u32)?;
Ok((base, declared_type, intval))
}
fn display_optional_type(otype: &Option<ConstantType>) -> &'static str {
match otype {
None => "",
Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32",
Some(ConstantType::I64) => "i64",
Some(ConstantType::U8) => "u8",
Some(ConstantType::U16) => "u16",
Some(ConstantType::U32) => "u32",
Some(ConstantType::U64) => "u64",
}
}
#[test]
fn lex_numbers() {
let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc // 9");
assert_eq!(lex0.next(), Some(Token::Number((None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(2), 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(8), 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(10), 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(16), 12))));
let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc 12u8 0xci64// 9");
assert_eq!(lex0.next(), Some(Token::Number((None, None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(2), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(8), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(10), None, 12))));
assert_eq!(lex0.next(), Some(Token::Number((Some(16), None, 12))));
assert_eq!(
lex0.next(),
Some(Token::Number((None, Some(ConstantType::U8), 12)))
);
assert_eq!(
lex0.next(),
Some(Token::Number((Some(16), Some(ConstantType::I64), 12)))
);
assert_eq!(lex0.next(), None);
}
@@ -168,6 +350,31 @@ fn lexer_spans() {
assert_eq!(lex0.next(), Some((Token::Equals, 2..3)));
assert_eq!(lex0.next(), Some((Token::var("x"), 4..5)));
assert_eq!(lex0.next(), Some((Token::Operator('+'), 6..7)));
assert_eq!(lex0.next(), Some((Token::Number((None, 1)), 8..9)));
assert_eq!(lex0.next(), Some((Token::Number((None, None, 1)), 8..9)));
assert_eq!(lex0.next(), None);
}
#[test]
fn further_spans() {
let mut lex0 = Token::lexer("x = 2i64 + 2i64;\ny = -x;\nprint y;").spanned();
assert_eq!(lex0.next(), Some((Token::var("x"), 0..1)));
assert_eq!(lex0.next(), Some((Token::Equals, 2..3)));
assert_eq!(
lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 4..8))
);
assert_eq!(lex0.next(), Some((Token::Operator('+'), 9..10)));
assert_eq!(
lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 11..15))
);
assert_eq!(lex0.next(), Some((Token::Semi, 15..16)));
assert_eq!(lex0.next(), Some((Token::var("y"), 17..18)));
assert_eq!(lex0.next(), Some((Token::Equals, 19..20)));
assert_eq!(lex0.next(), Some((Token::Operator('-'), 21..22)));
assert_eq!(lex0.next(), Some((Token::var("x"), 22..23)));
assert_eq!(lex0.next(), Some((Token::Semi, 23..24)));
assert_eq!(lex0.next(), Some((Token::Print, 25..30)));
assert_eq!(lex0.next(), Some((Token::var("y"), 31..32)));
assert_eq!(lex0.next(), Some((Token::Semi, 32..33)));
}

View File

@@ -1,6 +1,9 @@
use crate::syntax::{Expression, Location, Program, Statement};
use crate::{
eval::PrimitiveType,
syntax::{Expression, Location, Program, Statement},
};
use codespan_reporting::diagnostic::Diagnostic;
use std::collections::HashMap;
use std::{collections::HashMap, str::FromStr};
/// An error we found while validating the input program.
///
@@ -11,6 +14,7 @@ use std::collections::HashMap;
/// and using [`codespan_reporting`] to present them to the user.
pub enum Error {
UnboundVariable(Location, String),
UnknownType(Location, String),
}
impl From<Error> for Diagnostic<usize> {
@@ -19,6 +23,10 @@ impl From<Error> for Diagnostic<usize> {
Error::UnboundVariable(location, name) => location
.labelled_error("unbound here")
.with_message(format!("Unbound variable '{}'", name)),
Error::UnknownType(location, name) => location
.labelled_error("type referenced here")
.with_message(format!("Unknown type '{}'", name)),
}
}
}
@@ -57,12 +65,24 @@ impl Program {
/// example, and generates warnings for things that are inadvisable but not
/// actually a problem.
pub fn validate(&self) -> (Vec<Error>, Vec<Warning>) {
let mut bound_variables = HashMap::new();
self.validate_with_bindings(&mut bound_variables)
}
/// Validate that the program makes semantic sense, not just syntactic sense.
///
/// This checks for things like references to variables that don't exist, for
/// example, and generates warnings for things that are inadvisable but not
/// actually a problem.
pub fn validate_with_bindings(
&self,
bound_variables: &mut HashMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) {
let mut errors = vec![];
let mut warnings = vec![];
let mut bound_variables = HashMap::new();
for stmt in self.statements.iter() {
let (mut new_errors, mut new_warnings) = stmt.validate(&mut bound_variables);
let (mut new_errors, mut new_warnings) = stmt.validate(bound_variables);
errors.append(&mut new_errors);
warnings.append(&mut new_warnings);
}
@@ -81,7 +101,7 @@ impl Statement {
/// occurs. We use a `HashMap` to map these bound locations to the locations
/// where their bound, because these locations are handy when generating errors
/// and warnings.
pub fn validate(
fn validate(
&self,
bound_variables: &mut HashMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) {
@@ -97,20 +117,20 @@ impl Statement {
errors.append(&mut exp_errors);
warnings.append(&mut exp_warnings);
if let Some(original_binding_site) = bound_variables.get(var) {
if let Some(original_binding_site) = bound_variables.get(&var.name) {
warnings.push(Warning::ShadowedVariable(
original_binding_site.clone(),
loc.clone(),
var.clone(),
var.to_string(),
));
} else {
bound_variables.insert(var.clone(), loc.clone());
bound_variables.insert(var.to_string(), loc.clone());
}
}
Statement::Print(_, var) if bound_variables.contains_key(var) => {}
Statement::Print(_, var) if bound_variables.contains_key(&var.name) => {}
Statement::Print(loc, var) => {
errors.push(Error::UnboundVariable(loc.clone(), var.clone()))
errors.push(Error::UnboundVariable(loc.clone(), var.to_string()))
}
}
@@ -127,6 +147,15 @@ impl Expression {
vec![Error::UnboundVariable(loc.clone(), var.clone())],
vec![],
),
Expression::Cast(location, t, expr) => {
let (mut errs, warns) = expr.validate(variable_map);
if PrimitiveType::from_str(t).is_err() {
errs.push(Error::UnknownType(location.clone(), t.clone()))
}
(errs, warns)
}
Expression::Primitive(_, _, args) => {
let mut errors = vec![];
let mut warnings = vec![];
@@ -142,3 +171,19 @@ impl Expression {
}
}
}
#[test]
fn cast_checks_are_reasonable() {
let good_stmt = Statement::parse(0, "x = <u16>4u8;").expect("valid test case");
let (good_errs, good_warns) = good_stmt.validate(&mut HashMap::new());
assert!(good_errs.is_empty());
assert!(good_warns.is_empty());
let bad_stmt = Statement::parse(0, "x = <apple>4u8;").expect("valid test case");
let (bad_errs, bad_warns) = bad_stmt.validate(&mut HashMap::new());
assert!(bad_warns.is_empty());
assert_eq!(bad_errs.len(), 1);
assert!(matches!(bad_errs[0], Error::UnknownType(_, ref x) if x == "apple"));
}

52
src/type_infer.rs Normal file
View File

@@ -0,0 +1,52 @@
//! A type inference pass for NGR
//!
//! The type checker implemented here is a relatively straightforward one, designed to be
//! fairly easy to understand rather than super fast. So don't be expecting the fastest
//! type checker in the West, here.
//!
//! The actual type checker operates in three phases. In the first phase, we translate
//! the syntax AST into something that's close to the final IR. During the process, we
//! generate a list of type constraints to solve. In the second phase, we try to solve
//! all the constraints we've generated. If that's successful, in the final phase, we
//! do the final conversion to the IR AST, filling in any type information we've learned
//! along the way.
mod ast;
mod convert;
mod finalize;
mod solve;
use self::convert::convert_program;
use self::finalize::finalize_program;
use self::solve::solve_constraints;
pub use self::solve::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning};
use crate::ir::ast as ir;
use crate::syntax;
#[cfg(test)]
use crate::syntax::arbitrary::GenerationEnvironment;
#[cfg(test)]
use proptest::prelude::Arbitrary;
impl syntax::Program {
/// Infer the types for the syntactic AST, returning either a type-checked program in
/// the IR, or a series of type errors encountered during inference.
///
/// You really should have made sure that this program was validated before running
/// this method, otherwise you may experience panics during operation.
pub fn type_infer(self) -> TypeInferenceResult<ir::Program> {
let mut constraint_db = vec![];
let program = convert_program(self, &mut constraint_db);
let inference_result = solve_constraints(constraint_db);
inference_result.map(|resolutions| finalize_program(program, &resolutions))
}
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input in syntax::Program::arbitrary_with(GenerationEnvironment::new(false))) {
let syntax_result = input.eval();
let ir = input.type_infer().expect("arbitrary should generate type-safe programs");
let ir_result = ir.eval();
proptest::prop_assert_eq!(syntax_result, ir_result);
}
}

336
src/type_infer/ast.rs Normal file
View File

@@ -0,0 +1,336 @@
pub use crate::ir::ast::Primitive;
/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going
/// to want to use while we're doing type inference, but don't want to keep around
/// afterwards. These are:
///
/// * A notion of a type variable
/// * An unknown numeric constant form
///
use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use std::fmt;
use std::sync::atomic::AtomicUsize;
/// We're going to represent variables as interned strings.
///
/// These should be fast enough for comparison that it's OK, since it's going to end up
/// being pretty much the pointer to the string.
type Variable = ArcIntern<String>;
/// The representation of a program within our IR. For now, this is exactly one file.
///
/// In addition, for the moment there's not really much of interest to hold here besides
/// the list of statements read from the file. Order is important. In the future, you
/// could imagine caching analysis information in this structure.
///
/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used
/// to print the structure whenever possible, especially if you value your or your
/// user's time. The latter is useful for testing that conversions of `Program` retain
/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be
/// syntactically valid, although they may contain runtime issue like over- or underflow.
#[derive(Debug)]
pub struct Program {
// For now, a program is just a vector of statements. In the future, we'll probably
// extend this to include a bunch of other information, but for now: just a list.
pub(crate) statements: Vec<Statement>,
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let mut result = allocator.nil();
for stmt in self.statements.iter() {
// there's probably a better way to do this, rather than constantly
// adding to the end, but this works.
result = result
.append(stmt.pretty(allocator))
.append(allocator.text(";"))
.append(allocator.hardline());
}
result
}
}
/// The representation of a statement in the language.
///
/// For now, this is either a binding site (`x = 4`) or a print statement
/// (`print x`). Someday, though, more!
///
/// As with `Program`, this type implements [`Pretty`], which should
/// be used to display the structure whenever possible. It does not
/// implement [`Arbitrary`], though, mostly because it's slightly
/// complicated to do so.
///
#[derive(Debug)]
pub enum Statement {
Binding(Location, Variable, Type, Expression),
Print(Location, Type, Variable),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string())
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
Statement::Print(_, _, var) => allocator
.text("print")
.append(allocator.space())
.append(allocator.text(var.as_ref().to_string())),
}
}
}
/// The representation of an expression.
///
/// Note that expressions, like everything else in this syntax tree,
/// supports [`Pretty`], and it's strongly encouraged that you use
/// that trait/module when printing these structures.
///
/// Also, Expressions at this point in the compiler are explicitly
/// defined so that they are *not* recursive. By this point, if an
/// expression requires some other data (like, for example, invoking
/// a primitive), any subexpressions have been bound to variables so
/// that the referenced data will always either be a constant or a
/// variable reference.
#[derive(Debug, PartialEq)]
pub enum Expression {
Atomic(ValueOrRef),
Cast(Location, Type, ValueOrRef),
Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Expression::Atomic(x) => x.pretty(allocator),
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
left.append(allocator.space())
.append(op.pretty(allocator))
.append(allocator.space())
.append(right)
.parens()
}
Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
}
}
}
}
/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrRef {
Value(Location, Type, Value),
Ref(Location, Type, ArcIntern<String>),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
}
}
}
impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self {
Expression::Atomic(value)
}
}
/// A constant in the IR.
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Unknown(Option<u8>, u64),
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let pretty_internal = |opt_base: &Option<u8>, x, t| {
syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
};
let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::Unknown(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::U64)
}
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Variable(Location, ArcIntern<String>),
Primitive(PrimitiveType),
}
impl Type {
pub fn is_concrete(&self) -> bool {
!matches!(self, Type::Variable(_, _))
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Variable(_, x) => allocator.text(x.to_string()),
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Variable(_, x) => write!(f, "{}", x),
Type::Primitive(pt) => pt.fmt(f),
}
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
/// Generate a fresh new type; this will be a unique new type variable.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gentype() -> Type {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let name = ArcIntern::new(format!(
"t<{}>",
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
));
Type::Variable(Location::manufactured(), name)
}

378
src/type_infer/convert.rs Normal file
View File

@@ -0,0 +1,378 @@
use super::ast as ir;
use super::ast::Type;
use crate::eval::PrimitiveType;
use crate::syntax::{self, ConstantType};
use crate::type_infer::solve::Constraint;
use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;
/// This function takes a syntactic program and converts it into the IR version of the
/// program, with appropriate type variables introduced and their constraints added to
/// the given database.
///
/// If the input function has been validated (which it should be), then this should run
/// into no error conditions. However, if you failed to validate the input, then this
/// function can panic.
pub fn convert_program(
mut program: syntax::Program,
constraint_db: &mut Vec<Constraint>,
) -> ir::Program {
let mut statements = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
for stmt in program.statements.drain(..) {
statements.append(&mut convert_statement(
stmt,
constraint_db,
&mut renames,
&mut bindings,
));
}
ir::Program { statements }
}
/// This function takes a syntactic statements and converts it into a series of
/// IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_statement(
statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>,
renames: &mut HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> Vec<ir::Statement> {
match statement {
syntax::Statement::Print(loc, name) => {
let iname = ArcIntern::new(name.to_string());
let final_name = renames
.get(&iname)
.map(Clone::clone)
.unwrap_or_else(|| iname.clone());
let varty = bindings
.get(&final_name)
.expect("print variable defined before use")
.clone();
constraint_db.push(Constraint::Printable(loc.clone(), varty.clone()));
vec![ir::Statement::Print(loc, varty, iname)]
}
syntax::Statement::Binding(loc, name, expr) => {
let (mut prereqs, expr, ty) =
convert_expression(expr, constraint_db, renames, bindings);
let iname = ArcIntern::new(name.to_string());
let final_name = if bindings.contains_key(&iname) {
let new_name = ir::gensym(iname.as_str());
renames.insert(iname, new_name.clone());
new_name
} else {
iname
};
bindings.insert(final_name.clone(), ty.clone());
prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr));
prereqs
}
}
}
/// This function takes a syntactic expression and converts it into a series
/// of IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_expression(
expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>,
renames: &HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> (Vec<ir::Statement>, ir::Expression, Type) {
match expression {
syntax::Expression::Value(loc, val) => match val {
syntax::Value::Number(base, mctype, value) => {
let (newval, newtype) = match mctype {
None => {
let newtype = ir::gentype();
let newval = ir::Value::Unknown(base, value);
constraint_db.push(Constraint::ConstantNumericType(
loc.clone(),
newtype.clone(),
));
(newval, newtype)
}
Some(ConstantType::U8) => (
ir::Value::U8(base, value as u8),
ir::Type::Primitive(PrimitiveType::U8),
),
Some(ConstantType::U16) => (
ir::Value::U16(base, value as u16),
ir::Type::Primitive(PrimitiveType::U16),
),
Some(ConstantType::U32) => (
ir::Value::U32(base, value as u32),
ir::Type::Primitive(PrimitiveType::U32),
),
Some(ConstantType::U64) => (
ir::Value::U64(base, value),
ir::Type::Primitive(PrimitiveType::U64),
),
Some(ConstantType::I8) => (
ir::Value::I8(base, value as i8),
ir::Type::Primitive(PrimitiveType::I8),
),
Some(ConstantType::I16) => (
ir::Value::I16(base, value as i16),
ir::Type::Primitive(PrimitiveType::I16),
),
Some(ConstantType::I32) => (
ir::Value::I32(base, value as i32),
ir::Type::Primitive(PrimitiveType::I32),
),
Some(ConstantType::I64) => (
ir::Value::I64(base, value as i64),
ir::Type::Primitive(PrimitiveType::I64),
),
};
constraint_db.push(Constraint::FitsInNumType(
loc.clone(),
newtype.clone(),
value,
));
(
vec![],
ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)),
newtype,
)
}
},
syntax::Expression::Reference(loc, name) => {
let iname = ArcIntern::new(name);
let final_name = renames.get(&iname).cloned().unwrap_or(iname);
let rtype = bindings
.get(&final_name)
.cloned()
.expect("variable bound before use");
let refexp =
ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name));
(vec![], refexp, rtype)
}
syntax::Expression::Cast(loc, target, expr) => {
let (mut stmts, nexpr, etype) =
convert_expression(*expr, constraint_db, renames, bindings);
let val_or_ref = simplify_expr(nexpr, &mut stmts);
let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast");
let target_type = Type::Primitive(target_prim_type);
let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref);
constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone()));
(stmts, res, target_type)
}
syntax::Expression::Primitive(loc, fun, mut args) => {
let primop = ir::Primitive::from_str(&fun).expect("valid primitive");
let mut stmts = vec![];
let mut nargs = vec![];
let mut atypes = vec![];
let ret_type = ir::gentype();
for arg in args.drain(..) {
let (mut astmts, aexp, atype) =
convert_expression(arg, constraint_db, renames, bindings);
stmts.append(&mut astmts);
nargs.push(simplify_expr(aexp, &mut stmts));
atypes.push(atype);
}
constraint_db.push(Constraint::ProperPrimitiveArgs(
loc.clone(),
primop,
atypes.clone(),
ret_type.clone(),
));
(
stmts,
ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs),
ret_type,
)
}
}
}
fn simplify_expr(expr: ir::Expression, stmts: &mut Vec<ir::Statement>) -> ir::ValueOrRef {
match expr {
ir::Expression::Atomic(v_or_ref) => v_or_ref,
expr => {
let etype = expr.type_of().clone();
let loc = expr.location().clone();
let nname = ir::gensym("g");
let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr);
stmts.push(nbinding);
ir::ValueOrRef::Ref(loc, etype, nname)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::syntax::Location;
fn one() -> syntax::Expression {
syntax::Expression::Value(
Location::manufactured(),
syntax::Value::Number(None, None, 1),
)
}
fn vec_contains<T, F: Fn(&T) -> bool>(x: &[T], f: F) -> bool {
for x in x.iter() {
if f(x) {
return true;
}
}
false
}
fn infer_expression(
x: syntax::Expression,
) -> (ir::Expression, Vec<ir::Statement>, Vec<Constraint>, Type) {
let mut constraints = Vec::new();
let renames = HashMap::new();
let mut bindings = HashMap::new();
let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings);
(expr, stmts, constraints, ty)
}
fn infer_statement(x: syntax::Statement) -> (Vec<ir::Statement>, Vec<Constraint>) {
let mut constraints = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings);
(res, constraints)
}
#[test]
fn constant_one() {
let (expr, stmts, constraints, ty) = infer_expression(one());
assert!(stmts.is_empty());
assert!(matches!(
expr,
ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1)))
));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty)
));
}
#[test]
fn one_plus_one() {
let opo = syntax::Expression::Primitive(
Location::manufactured(),
"+".to_string(),
vec![one(), one()],
);
let (expr, stmts, constraints, ty) = infer_expression(opo);
assert!(stmts.is_empty());
assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty)
));
}
#[test]
fn one_plus_one_plus_one() {
let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse");
let (stmts, constraints) = infer_statement(stmt);
assert_eq!(stmts.len(), 2);
let ir::Statement::Binding(
_args,
name1,
temp_ty1,
ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1),
) = stmts.get(0).expect("item two")
else {
panic!("Failed to match first statement");
};
let ir::Statement::Binding(
_args,
name2,
temp_ty2,
ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2),
) = stmts.get(1).expect("item two")
else {
panic!("Failed to match second statement");
};
let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] =
&primargs1[..]
else {
panic!("Failed to match first arguments");
};
let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] =
&primargs2[..]
else {
panic!("Failed to match first arguments");
};
assert_ne!(name1, name2);
assert_ne!(temp_ty1, temp_ty2);
assert_ne!(primty1, primty2);
assert_eq!(name1, left2name);
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty)
));
for (i, s) in stmts.iter().enumerate() {
println!("{}: {:?}", i, s);
}
for (i, c) in constraints.iter().enumerate() {
println!("{}: {:?}", i, c);
}
}
}

187
src/type_infer/finalize.rs Normal file
View File

@@ -0,0 +1,187 @@
use super::{ast as input, solve::TypeResolutions};
use crate::{eval::PrimitiveType, ir as output};
pub fn finalize_program(
mut program: input::Program,
resolutions: &TypeResolutions,
) -> output::Program {
output::Program {
statements: program
.statements
.drain(..)
.map(|x| finalize_statement(x, resolutions))
.collect(),
}
}
fn finalize_statement(
statement: input::Statement,
resolutions: &TypeResolutions,
) -> output::Statement {
match statement {
input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding(
loc,
var,
finalize_type(ty, resolutions),
finalize_expression(expr, resolutions),
),
input::Statement::Print(loc, ty, var) => {
output::Statement::Print(loc, finalize_type(ty, resolutions), var)
}
}
}
fn finalize_expression(
expression: input::Expression,
resolutions: &TypeResolutions,
) -> output::Expression {
match expression {
input::Expression::Atomic(val_or_ref) => {
output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions))
}
input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast(
loc,
finalize_type(target, resolutions),
finalize_val_or_ref(val_or_ref, resolutions),
),
input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive(
loc,
finalize_type(ty, resolutions),
prim,
args.drain(..)
.map(|x| finalize_val_or_ref(x, resolutions))
.collect(),
),
}
}
fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type {
match ty {
input::Type::Primitive(x) => output::Type::Primitive(x),
input::Type::Variable(_, tvar) => match resolutions.get(&tvar) {
None => panic!("Did not resolve type for type variable {}", tvar),
Some(pt) => output::Type::Primitive(*pt),
},
}
}
fn finalize_val_or_ref(
valref: input::ValueOrRef,
resolutions: &TypeResolutions,
) -> output::ValueOrRef {
match valref {
input::ValueOrRef::Ref(loc, ty, var) => {
output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var)
}
input::ValueOrRef::Value(loc, ty, val) => {
let new_type = finalize_type(ty, resolutions);
match val {
input::Value::Unknown(base, value) => match new_type {
output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U8(base, value as u8),
),
output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U16(base, value as u16),
),
output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U32(base, value as u32),
),
output::Type::Primitive(PrimitiveType::U64) => {
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I8(base, value as i8),
),
output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I16(base, value as i16),
),
output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I32(base, value as i32),
),
output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I64(base, value as i64),
),
},
input::Value::U8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value))
}
input::Value::U16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value))
}
input::Value::U32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value))
}
input::Value::U64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
input::Value::I8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value))
}
input::Value::I16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value))
}
input::Value::I32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value))
}
input::Value::I64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value))
}
}
}
}
}

542
src/type_infer/solve.rs Normal file
View File

@@ -0,0 +1,542 @@
use super::ast as ir;
use super::ast::Type;
use crate::{eval::PrimitiveType, syntax::Location};
use codespan_reporting::diagnostic::Diagnostic;
use internment::ArcIntern;
use std::{collections::HashMap, fmt};
/// A type inference constraint that we're going to need to solve.
#[derive(Debug)]
pub enum Constraint {
/// The given type must be printable using the `print` built-in
Printable(Location, Type),
/// The provided numeric value fits in the given constant type
FitsInNumType(Location, Type, u64),
/// The given primitive has the proper arguments types associated with it
ProperPrimitiveArgs(Location, ir::Primitive, Vec<Type>, Type),
/// The given type can be casted to the target type safely
CanCastTo(Location, Type, Type),
/// The given type must be some numeric type, but this is not a constant
/// value, so don't try to default it if we can't figure it out
NumericType(Location, Type),
/// The given type is attached to a constant and must be some numeric type.
/// If we can't figure it out, we should warn the user and then just use a
/// default.
ConstantNumericType(Location, Type),
/// The two types should be equivalent
Equivalent(Location, Type, Type),
}
impl fmt::Display for Constraint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty),
Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty),
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => {
write!(f, "PRIM {} {} -> {}", op, args[0], ret)
}
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => {
write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret)
}
Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret),
Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2),
Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty),
Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty),
Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2),
}
}
}
pub type TypeResolutions = HashMap<ArcIntern<String>, PrimitiveType>;
/// The results of type inference; like [`Result`], but with a bit more information.
///
/// This result is parameterized, because sometimes it's handy to return slightly
/// different things; there's a [`TypeInferenceResult::map`] function for performing
/// those sorts of conversions.
pub enum TypeInferenceResult<Result> {
Success {
result: Result,
warnings: Vec<TypeInferenceWarning>,
},
Failure {
errors: Vec<TypeInferenceError>,
warnings: Vec<TypeInferenceWarning>,
},
}
impl<R> TypeInferenceResult<R> {
// If this was a successful type inference, run the function over the result to
// create a new result.
//
// This is the moral equivalent of [`Result::map`], but for type inference results.
pub fn map<U, F>(self, f: F) -> TypeInferenceResult<U>
where
F: FnOnce(R) -> U,
{
match self {
TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success {
result: f(result),
warnings,
},
TypeInferenceResult::Failure { errors, warnings } => {
TypeInferenceResult::Failure { errors, warnings }
}
}
}
// Return the final result, or panic if it's not a success
pub fn expect(self, msg: &str) -> R {
match self {
TypeInferenceResult::Success { result, .. } => result,
TypeInferenceResult::Failure { .. } => {
panic!("tried to get value from failed type inference: {}", msg)
}
}
}
}
/// The various kinds of errors that can occur while doing type inference.
pub enum TypeInferenceError {
/// The user provide a constant that is too large for its inferred type.
ConstantTooLarge(Location, PrimitiveType, u64),
/// The two types needed to be equivalent, but weren't.
NotEquivalent(Location, PrimitiveType, PrimitiveType),
/// We cannot safely cast the first type to the second type.
CannotSafelyCast(Location, PrimitiveType, PrimitiveType),
/// The primitive invocation provided the wrong number of arguments.
WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize),
/// We had a constraint we just couldn't solve.
CouldNotSolve(Constraint),
}
impl From<TypeInferenceError> for Diagnostic<usize> {
fn from(value: TypeInferenceError) -> Self {
match value {
TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc
.labelled_error("constant too large for type")
.with_message(format!(
"Type {} has a max value of {}, which is smaller than {}",
primty,
primty.max_value(),
value
)),
TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc
.labelled_error("type inference error")
.with_message(format!("Expected type {}, received type {}", ty1, ty2)),
TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc
.labelled_error("unsafe type cast")
.with_message(format!("Cannot safely cast {} to {}", ty1, ty2)),
TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc
.labelled_error("wrong number of arguments")
.with_message(format!(
"expected {} for {}, received {}",
if lower == upper && lower > 1 {
format!("{} arguments", lower)
} else if lower == upper {
format!("{} argument", lower)
} else {
format!("{}-{} arguments", lower, upper)
},
prim,
observed
)),
TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if it was safe to cast from {} to {:#?}",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if {} and {:#?} were equivalent",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not determine if {} could fit in {}",
val, ty
))
}
TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if {} was a numeric type", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) =>
panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty),
TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if type {} was printable", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not tell if primitive {} received the proper argument types",
prim
))
}
}
}
}
/// Warnings that we might want to tell the user about.
///
/// These are fine, probably, but could indicate some behavior the user might not
/// expect, and so they might want to do something about them.
pub enum TypeInferenceWarning {
DefaultedTo(Location, Type),
}
impl From<TypeInferenceWarning> for Diagnostic<usize> {
fn from(value: TypeInferenceWarning) -> Self {
match value {
TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning()
.with_labels(vec![loc.primary_label().with_message("unknown type")])
.with_message(format!("Defaulted unknown type to {}", ty)),
}
}
}
/// Solve all the constraints in the provided database.
///
/// This process can take a bit, so you might not want to do it multiple times. Basically,
/// it's going to grind on these constraints until either it figures them out, or it stops
/// making progress. I haven't done the math on the constraints to even figure out if this
/// is guaranteed to halt, though, let alone terminate in some reasonable amount of time.
///
/// The return value is a type inference result, which pairs some warnings with either a
/// successful set of type resolutions (mappings from type variables to their values), or
/// a series of inference errors.
pub fn solve_constraints(
mut constraint_db: Vec<Constraint>,
) -> TypeInferenceResult<TypeResolutions> {
let mut errors = vec![];
let mut warnings = vec![];
let mut resolutions = HashMap::new();
let mut changed_something = true;
// We want to run this inference endlessly, until either we have solved all of our
// constraints. Internal to the loop, we have a check that will make sure that we
// do (eventually) stop.
while changed_something && !constraint_db.is_empty() {
// Set this to false at the top of the loop. We'll set this to true if we make
// progress in any way further down, but having this here prevents us from going
// into an infinite look when we can't figure stuff out.
changed_something = false;
// This is sort of a double-buffering thing; we're going to rename constraint_db
// and then set it to a new empty vector, which we'll add to as we find new
// constraints or find ourselves unable to solve existing ones.
let mut local_constraints = constraint_db;
constraint_db = vec![];
// OK. First thing we're going to do is run through all of our constraints,
// and see if we can solve any, or reduce them to theoretically more simple
// constraints.
for constraint in local_constraints.drain(..) {
match constraint {
// Currently, all of our types are printable
Constraint::Printable(_loc, _ty) => changed_something = true,
// Case #1: We have two primitive types. If they're equal, we've discharged this
// constraint! We can just continue. If they're not equal, add an error and then
// see what else we come up with.
Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => {
if t1 != t2 {
errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2));
}
changed_something = true;
}
// Case #2: One of the two constraints is a primitive, and the other is a variable.
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name))
| Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => {
match resolutions.get(&name) {
None => {
resolutions.insert(name, t);
}
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)),
}
changed_something = true;
}
// Case #3: They're both variables. In which case, we'll have to do much the same
// check, but now on their resolutions.
Constraint::Equivalent(
ref loc,
Type::Variable(_, ref name1),
Type::Variable(_, ref name2),
) => match (resolutions.get(name1), resolutions.get(name2)) {
(None, None) => {
constraint_db.push(constraint);
}
(Some(pt), None) => {
resolutions.insert(name2.clone(), *pt);
changed_something = true;
}
(None, Some(pt)) => {
resolutions.insert(name1.clone(), *pt);
changed_something = true;
}
(Some(pt1), Some(pt2)) if pt1 == pt2 => {
changed_something = true;
}
(Some(pt1), Some(pt2)) => {
errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2));
changed_something = true;
}
},
// Make sure that the provided number fits within the provided constant type. For the
// moment, we're going to call an error here a failure, although this could be a
// warning in the future.
Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => {
if ctype.max_value() < val {
errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val));
}
changed_something = true;
}
// If we have a non-constant type, then let's see if we can advance this to a constant
// type
Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Variable(vloc, var),
val,
)),
Some(nt) => {
constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Primitive(*nt),
val,
));
changed_something = true;
}
}
}
// If the left type in a "can cast to" check is a variable, let's see if we can advance
// it into something more tangible
Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
Type::Variable(vloc, var),
to_type,
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
Type::Primitive(*nt),
to_type,
));
changed_something = true;
}
}
}
// If the right type in a "can cast to" check is a variable, same deal
Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Variable(vloc, var),
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Primitive(*nt),
));
changed_something = true;
}
}
}
// If both of them are types, then we can actually do the test. yay!
Constraint::CanCastTo(
loc,
Type::Primitive(from_type),
Type::Primitive(to_type),
) => {
if !from_type.can_cast_to(&to_type) {
errors.push(TypeInferenceError::CannotSafelyCast(
loc, from_type, to_type,
));
}
changed_something = true;
}
// As per usual, if we're trying to test if a type variable is numeric, first
// we try to advance it to a primitive
Constraint::NumericType(loc, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db
.push(Constraint::NumericType(loc, Type::Variable(vloc, var))),
Some(nt) => {
constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt)));
changed_something = true;
}
}
}
// Of course, if we get to a primitive type, then it's true, because all of our
// primitive types are numbers
Constraint::NumericType(_, Type::Primitive(_)) => {
changed_something = true;
}
// As per usual, if we're trying to test if a type variable is numeric, first
// we try to advance it to a primitive
Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::ConstantNumericType(
loc,
Type::Variable(vloc, var),
)),
Some(nt) => {
constraint_db
.push(Constraint::ConstantNumericType(loc, Type::Primitive(*nt)));
changed_something = true;
}
}
}
// Of course, if we get to a primitive type, then it's true, because all of our
// primitive types are numbers
Constraint::ConstantNumericType(_, Type::Primitive(_)) => {
changed_something = true;
}
// OK, this one could be a little tricky if we tried to do it all at once, but
// instead what we're going to do here is just use this constraint to generate
// a bunch more constraints, and then go have the engine solve those. The only
// real errors we're going to come up with here are "arity errors"; errors we
// find by discovering that the number of arguments provided doesn't make sense
// given the primitive being used.
Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim {
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide
if args.len() != 2 =>
{
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
2,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc, left, ret));
changed_something = true;
}
ir::Primitive::Minus if args.is_empty() || args.len() > 2 => {
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
1,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Minus if args.len() == 1 => {
let arg = args.pop().expect("1 > 0");
constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(loc, arg, ret));
changed_something = true;
}
ir::Primitive::Minus => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc.clone(), left, ret));
changed_something = true;
}
},
}
}
// If that didn't actually come up with anything, and we just recycled all the constraints
// back into the database unchanged, then let's take a look for cases in which we just
// wanted something we didn't know to be a number. Basically, those are cases where the
// user just wrote a number, but didn't tell us what type it was, and there isn't enough
// information in the context to tell us. If that happens, we'll just set that type to
// be u64, and warn the user that we did so.
if !changed_something && !constraint_db.is_empty() {
local_constraints = constraint_db;
constraint_db = vec![];
for constraint in local_constraints.drain(..) {
match constraint {
Constraint::ConstantNumericType(loc, t @ Type::Variable(_, _)) => {
let resty = Type::Primitive(PrimitiveType::U64);
constraint_db.push(Constraint::Equivalent(
loc.clone(),
t,
Type::Primitive(PrimitiveType::U64),
));
warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty));
changed_something = true;
}
_ => constraint_db.push(constraint),
}
}
}
}
// OK, we left our loop. Which means that either we solved everything, or we didn't.
// If we didn't, turn the unsolved constraints into type inference errors, and add
// them to our error list.
let mut unsolved_constraint_errors = constraint_db
.drain(..)
.map(TypeInferenceError::CouldNotSolve)
.collect();
errors.append(&mut unsolved_constraint_errors);
// How'd we do?
if errors.is_empty() {
TypeInferenceResult::Success {
result: resolutions,
warnings,
}
} else {
TypeInferenceResult::Failure { errors, warnings }
}
}