diff --git a/.gitignore b/.gitignore index 55de121..90624f7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ Cargo.lock **/*.o test +test.exe +test.ilk +test.pdb *.dSYM .vscode proptest-regressions/ \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 0451d2e..9531890 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,12 +12,12 @@ path = "src/lib.rs" clap = { version = "^3.0.14", features = ["derive"] } codespan = "0.11.1" codespan-reporting = "0.11.1" -cranelift-codegen = "0.94.0" -cranelift-jit = "0.94.0" -cranelift-frontend = "0.94.0" -cranelift-module = "0.94.0" -cranelift-native = "0.94.0" -cranelift-object = "0.94.0" +cranelift-codegen = "0.99.2" +cranelift-jit = "0.99.2" +cranelift-frontend = "0.99.2" +cranelift-module = "0.99.2" +cranelift-native = "0.99.2" +cranelift-object = "0.99.2" internment = { version = "0.7.0", default-features = false, features = ["arc"] } lalrpop-util = "^0.20.0" lazy_static = "^1.4.0" diff --git a/build.rs b/build.rs index 23c7d3f..2f4266e 100644 --- a/build.rs +++ b/build.rs @@ -1,5 +1,89 @@ extern crate lalrpop; +use std::env; +use std::fs::{self, File}; +use std::io::Write; +use std::path::{Path, PathBuf}; + fn main() { lalrpop::process_root().unwrap(); + if let Err(e) = generate_example_tests() { + eprintln!("Failure building example tests: {}", e); + std::process::exit(3); + } +} + +fn generate_example_tests() -> std::io::Result<()> { + let out_dir = env::var_os("OUT_DIR").expect("OUT_DIR"); + let dest_path = Path::new(&out_dir).join("examples.rs"); + let mut output = File::create(dest_path)?; + + generate_tests(&mut output, PathBuf::from("examples"))?; + + println!("cargo:rerun-if-changed=build.rs"); + // The generated tests mirror the `examples` directory, so re-run when it changes. + println!("cargo:rerun-if-changed=examples"); + Ok(()) +} + +fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> { + for entry in fs::read_dir(path_so_far)? 
{ + let entry = entry?; + let file_name = entry + .file_name() + .into_string() + .expect("reasonable file name"); + let name = file_name.trim_end_matches(".ngr"); + + if entry.file_type()?.is_dir() { + writeln!(f, "mod {} {{", name)?; + writeln!(f, " use super::*;")?; + generate_tests(f, entry.path())?; + writeln!(f, "}}")?; + } else { + writeln!(f, "#[test]")?; + writeln!(f, "fn {}() {{", name)?; + writeln!(f, " let mut file_database = SimpleFiles::new();")?; + writeln!( + f, + " let syntax = Syntax::parse_file(&mut file_database, {:?});", + entry.path().display() + )?; + if entry.path().to_string_lossy().contains("broken") { + writeln!(f, " if syntax.is_err() {{")?; + writeln!(f, " return;")?; + writeln!(f, " }}")?; + writeln!(f, " let (errors, _) = syntax.unwrap().validate();")?; + writeln!( + f, + " assert_ne!(errors.len(), 0, \"should have seen an error\");" + )?; + } else { + // NOTE: Since the advent of defaulting rules and type checking, we + // can't guarantee that syntax.eval() will return the same result as + // ir.eval() or backend::eval(). We must do type checking to force + // constants into the right types, first. So this now checks only that + // the result of ir.eval() and backend::eval() are the same. 
+ writeln!( + f, + " let syntax = syntax.expect(\"file should have parsed\");" + )?; + writeln!(f, " let (errors, _) = syntax.validate();")?; + writeln!( + f, + " assert_eq!(errors.len(), 0, \"file should have no validation errors\");" + )?; + writeln!( + f, + " let ir = syntax.type_infer().expect(\"example is typed correctly\");" + )?; + writeln!(f, " let ir_result = ir.eval();")?; + writeln!(f, " let compiled_result = Backend::::eval(ir);")?; + writeln!(f, " assert_eq!(ir_result, compiled_result);")?; + } + writeln!(f, "}}")?; + } + } + + Ok(()) } diff --git a/examples/basic/broken1.ngr b/examples/basic/broken1.ngr new file mode 100644 index 0000000..775e24a --- /dev/null +++ b/examples/basic/broken1.ngr @@ -0,0 +1,4 @@ +x = 5; +x = 4*x + 3; +print x; +print y; diff --git a/examples/basic/cast1.ngr b/examples/basic/cast1.ngr new file mode 100644 index 0000000..0dc2488 --- /dev/null +++ b/examples/basic/cast1.ngr @@ -0,0 +1,7 @@ +x8 = 5u8; +x16 = x8 + 1u16; +print x16; +x32 = x16 + 1u32; +print x32; +x64 = x32 + 1u64; +print x64; \ No newline at end of file diff --git a/examples/basic/cast2.ngr b/examples/basic/cast2.ngr new file mode 100644 index 0000000..4c61255 --- /dev/null +++ b/examples/basic/cast2.ngr @@ -0,0 +1,7 @@ +x8 = 5i8; +x16 = x8 - 1i16; +print x16; +x32 = x16 - 1i32; +print x32; +x64 = x32 - 1i64; +print x64; \ No newline at end of file diff --git a/examples/basic/negation.ngr b/examples/basic/negation.ngr new file mode 100644 index 0000000..18fd866 --- /dev/null +++ b/examples/basic/negation.ngr @@ -0,0 +1,3 @@ +x = 2i64 + 2i64; +y = -x; +print y; \ No newline at end of file diff --git a/examples/basic/test2.ngr b/examples/basic/test2.ngr index 775e24a..0c9dd6d 100644 --- a/examples/basic/test2.ngr +++ b/examples/basic/test2.ngr @@ -1,4 +1,4 @@ -x = 5; -x = 4*x + 3; -print x; -print y; +// this test is useful for making sure we don't accidentally +// use a signed divide operation anywhere important +a = 96u8 / 160u8; +print a; \ No newline at end 
of file diff --git a/examples/basic/type_checker1.ngr b/examples/basic/type_checker1.ngr new file mode 100644 index 0000000..aa0a44a --- /dev/null +++ b/examples/basic/type_checker1.ngr @@ -0,0 +1,6 @@ +x = 1 + 1u16; +print x; +y = 1u16 + 1; +print y; +z = 1 + 1 + 1; +print z; \ No newline at end of file diff --git a/runtime/rts.c b/runtime/rts.c index 4cf41e5..a999bae 100644 --- a/runtime/rts.c +++ b/runtime/rts.c @@ -2,12 +2,37 @@ #include #include -void print(char *_ignore, char *variable_name, int64_t value) { - printf("%s = %" PRId64 "i64\n", variable_name, value); -} - -void caller() { - print(NULL, "x", 4); +void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) { + switch(vtype) { + case /* U8 = */ 10: + printf("%s = %" PRIu8 "u8\n", variable_name, (uint8_t)value); + break; + case /* U16 = */ 11: + printf("%s = %" PRIu16 "u16\n", variable_name, (uint16_t)value); + break; + case /* U32 = */ 12: + printf("%s = %" PRIu32 "u32\n", variable_name, (uint32_t)value); + break; + case /* U64 = */ 13: + printf("%s = %" PRIu64 "u64\n", variable_name, (uint64_t)value); + break; + case /* I8 = */ 20: + printf("%s = %" PRIi8 "i8\n", variable_name, (int8_t)value); + break; + case /* I16 = */ 21: + printf("%s = %" PRIi16 "i16\n", variable_name, (int16_t)value); + break; + case /* I32 = */ 22: + printf("%s = %" PRIi32 "i32\n", variable_name, (int32_t)value); + break; + case /* I64 = */ 23: + printf("%s = %" PRIi64 "i64\n", variable_name, value); + break; + default: + /* unknown type tag: print the raw value untyped instead of silently printing nothing */ + printf("%s = %" PRId64 "\n", variable_name, value); + break; + } } extern void gogogo(); diff --git a/src/backend.rs b/src/backend.rs index b6b8808..4ac8e83 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -31,15 +31,15 @@ mod eval; mod into_crane; mod runtime; -use std::collections::HashMap; - pub use self::error::BackendError; pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions}; +use crate::syntax::ConstantType; use cranelift_codegen::settings::Configurable; use cranelift_codegen::{isa, settings}; use cranelift_jit::{JITBuilder, JITModule}; -use 
cranelift_module::{default_libcall_names, DataContext, DataId, FuncId, Linkage, Module}; +use cranelift_module::{default_libcall_names, DataDescription, DataId, FuncId, Linkage, Module}; use cranelift_object::{ObjectBuilder, ObjectModule}; +use std::collections::HashMap; use target_lexicon::Triple; const EMPTY_DATUM: [u8; 8] = [0; 8]; @@ -55,11 +55,12 @@ const EMPTY_DATUM: [u8; 8] = [0; 8]; /// implementations. pub struct Backend { pub module: M, - data_ctx: DataContext, + data_ctx: DataDescription, runtime_functions: RuntimeFunctions, defined_strings: HashMap, - defined_symbols: HashMap, + defined_symbols: HashMap, output_buffer: Option, + platform: Triple, } impl Backend { @@ -85,11 +86,12 @@ impl Backend { Ok(Backend { module, - data_ctx: DataContext::new(), + data_ctx: DataDescription::new(), runtime_functions, defined_strings: HashMap::new(), defined_symbols: HashMap::new(), output_buffer, + platform: Triple::host(), }) } @@ -123,11 +125,12 @@ impl Backend { Ok(Backend { module, - data_ctx: DataContext::new(), + data_ctx: DataDescription::new(), runtime_functions, defined_strings: HashMap::new(), defined_symbols: HashMap::new(), output_buffer: None, + platform, }) } @@ -154,7 +157,7 @@ impl Backend { let global_id = self .module .declare_data(&name, Linkage::Local, false, false)?; - let mut data_context = DataContext::new(); + let mut data_context = DataDescription::new(); data_context.set_align(8); data_context.define(s0.into_boxed_str().into_boxed_bytes()); self.module.define_data(global_id, &data_context)?; @@ -167,14 +170,18 @@ impl Backend { /// These variables can be shared between functions, and will be exported from the /// module itself as public data in the case of static compilation. There initial /// value will be null. 
- pub fn define_variable(&mut self, name: String) -> Result { + pub fn define_variable( + &mut self, + name: String, + ctype: ConstantType, + ) -> Result { self.data_ctx.define(Box::new(EMPTY_DATUM)); let id = self .module .declare_data(&name, Linkage::Export, true, false)?; self.module.define_data(id, &self.data_ctx)?; self.data_ctx.clear(); - self.defined_symbols.insert(name, id); + self.defined_symbols.insert(name, (id, ctype)); Ok(id) } diff --git a/src/backend/error.rs b/src/backend/error.rs index caa9e59..7509da3 100644 --- a/src/backend/error.rs +++ b/src/backend/error.rs @@ -1,4 +1,4 @@ -use crate::backend::runtime::RuntimeFunctionError; +use crate::{backend::runtime::RuntimeFunctionError, eval::PrimitiveType, ir::Type}; use codespan_reporting::diagnostic::Diagnostic; use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError}; use cranelift_module::ModuleError; @@ -39,6 +39,8 @@ pub enum BackendError { LookupError(#[from] LookupError), #[error(transparent)] Write(#[from] cranelift_object::object::write::Error), + #[error("Invalid type cast from {from} to {to}")] + InvalidTypeCast { from: PrimitiveType, to: Type }, } impl From for Diagnostic { @@ -64,6 +66,9 @@ impl From for Diagnostic { BackendError::Write(me) => { Diagnostic::error().with_message(format!("Cranelift object write error: {}", me)) } + BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message( + format!("Internal error trying to cast from {} to {}", from, to), + ), } } } @@ -103,6 +108,17 @@ impl PartialEq for BackendError { BackendError::Write(b) => a == b, _ => false, }, + + BackendError::InvalidTypeCast { + from: from1, + to: to1, + } => match other { + BackendError::InvalidTypeCast { + from: from2, + to: to2, + } => from1 == from2 && to1 == to2, + _ => false, + }, } } } diff --git a/src/backend/eval.rs b/src/backend/eval.rs index e9c88f1..78ed7a5 100644 --- a/src/backend/eval.rs +++ b/src/backend/eval.rs @@ -1,10 +1,13 @@ -use std::path::Path; - use 
crate::backend::Backend; use crate::eval::EvalError; use crate::ir::Program; +#[cfg(test)] +use crate::syntax::arbitrary::GenerationEnvironment; use cranelift_jit::JITModule; use cranelift_object::ObjectModule; +#[cfg(test)] +use proptest::arbitrary::Arbitrary; +use std::path::Path; use target_lexicon::Triple; impl Backend { @@ -28,7 +31,8 @@ impl Backend { let compiled_bytes = jitter.bytes(function_id); let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; compiled_function(); - Ok(jitter.output()) + let output = jitter.output(); + Ok(output) } } @@ -116,7 +120,7 @@ proptest::proptest! { // without error, assuming any possible input ... well, any possible input that // doesn't involve overflow or underflow. #[test] - fn static_backend(program: Program) { + fn static_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { use crate::eval::PrimOpError; let basic_result = program.eval(); @@ -127,8 +131,18 @@ proptest::proptest! { let basic_result = basic_result.map(|x| x.replace('\n', "\r\n")); if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { +// use pretty::{DocAllocator, Pretty}; +// let allocator = pretty::BoxAllocator; +// allocator +// .text("---------------") +// .append(allocator.hardline()) +// .append(program.pretty(&allocator)) +// .1 +// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) +// .expect("rendering works"); + let compiled_result = Backend::::eval(program); - assert_eq!(basic_result, compiled_result); + proptest::prop_assert_eq!(basic_result, compiled_result); } } @@ -136,14 +150,24 @@ proptest::proptest! { // without error, assuming any possible input ... well, any possible input that // doesn't involve overflow or underflow. 
#[test] - fn jit_backend(program: Program) { + fn jit_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { use crate::eval::PrimOpError; +// use pretty::{DocAllocator, Pretty}; +// let allocator = pretty::BoxAllocator; +// allocator +// .text("---------------") +// .append(allocator.hardline()) +// .append(program.pretty(&allocator)) +// .1 +// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) +// .expect("rendering works"); + let basic_result = program.eval(); if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { let compiled_result = Backend::::eval(program); - assert_eq!(basic_result, compiled_result); + proptest::prop_assert_eq!(basic_result, compiled_result); } } } diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 5965a9f..e398e3d 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -1,9 +1,11 @@ use std::collections::HashMap; -use crate::ir::{Expression, Primitive, Program, Statement, Value, ValueOrRef}; +use crate::eval::PrimitiveType; +use crate::ir::{Expression, Primitive, Program, Statement, Type, Value, ValueOrRef}; +use crate::syntax::ConstantType; use cranelift_codegen::entity::EntityRef; use cranelift_codegen::ir::{ - entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName, + self, entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName, }; use cranelift_codegen::isa::CallConv; use cranelift_codegen::Context; @@ -41,7 +43,7 @@ impl Backend { let basic_signature = Signature { params: vec![], returns: vec![], - call_conv: CallConv::SystemV, + call_conv: CallConv::triple_default(&self.platform), }; // this generates the handle for the function that we'll eventually want to @@ -85,12 +87,12 @@ impl Backend { // Just like with strings, generating the `GlobalValue`s we need can potentially // be a little tricky to do on the fly, so we generate the 
complete list right // here and then use it later. - let pre_defined_symbols: HashMap = self + let pre_defined_symbols: HashMap = self .defined_symbols .iter() - .map(|(k, v)| { + .map(|(k, (v, t))| { let local_data = self.module.declare_data_in_func(*v, &mut ctx.func); - (k.clone(), local_data) + (k.clone(), (local_data, *t)) }) .collect(); @@ -123,7 +125,7 @@ impl Backend { // Print statements are fairly easy to compile: we just lookup the // output buffer, the address of the string to print, and the value // of whatever variable we're printing. Then we just call print. - Statement::Print(ann, var) => { + Statement::Print(ann, t, var) => { // Get the output buffer (or null) from our general compilation context. let buffer_ptr = self.output_buffer_ptr(); let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); @@ -135,31 +137,47 @@ impl Backend { // Look up the value for the variable. Because this might be a // global variable (and that requires special logic), we just turn // this into an `Expression` and re-use the logic in that implementation. - let val = Expression::Reference(ann, var).into_crane( + let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane( &mut builder, &variable_table, &pre_defined_symbols, )?; + let vtype_repr = builder.ins().iconst(types::I64, vtype as i64); + + let casted_val = match vtype { + ConstantType::U64 | ConstantType::I64 => val, + ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => { + builder.ins().sextend(types::I64, val) + } + ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => { + builder.ins().uextend(types::I64, val) + } + }; + // Finally, we can generate the call to print. 
- builder - .ins() - .call(print_func_ref, &[buffer_ptr, name_ptr, val]); + builder.ins().call( + print_func_ref, + &[buffer_ptr, name_ptr, vtype_repr, casted_val], + ); } // Variable binding is a little more con - Statement::Binding(_, var_name, value) => { + Statement::Binding(_, var_name, _, value) => { // Kick off to the `Expression` implementation to see what value we're going // to bind to this variable. - let val = + let (val, etype) = value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?; // Now the question is: is this a local variable, or a global one? - if let Some(global_id) = pre_defined_symbols.get(var_name.as_str()) { + if let Some((global_id, ctype)) = pre_defined_symbols.get(var_name.as_str()) { // It's a global variable! In this case, we assume that someone has already // dedicated some space in memory to store this value. We look this location // up, and then tell Cranelift to store the value there. - let val_ptr = builder.ins().symbol_value(types::I64, *global_id); + assert_eq!(etype, *ctype); + // `symbol_value` produces the global's address, so it takes the pointer + // type (I64, as before this change), not the type of the stored value. + let val_ptr = builder.ins().symbol_value(types::I64, *global_id); builder.ins().store(MemFlags::new(), val, val_ptr, 0); } else { // It's a local variable! In this case, we need to allocate a new Cranelift @@ -171,12 +189,10 @@ impl Backend { next_var_num += 1; // We can add the variable directly to our local variable map; it's `Copy`. - variable_table.insert(var_name, var); + variable_table.insert(var_name, (var, etype)); - // Now we tell Cranelift about our new variable, which has type I64 because - // everything we have at this point is of type I64. Once it's declare, we - // define it as having the value we computed above. - builder.declare_var(var, types::I64); + // Now we tell Cranelift about our new variable! + builder.declare_var(var, ir::Type::from(etype)); builder.def_var(var, val); } } @@ -195,7 +211,7 @@ impl Backend { // so we register it using the function ID and our builder context. 
However, the // result of this function isn't actually super helpful. So we ignore it, unless // it's an error. - let _ = self.module.define_function(func_id, &mut ctx)?; + self.module.define_function(func_id, &mut ctx)?; // done! Ok(func_id) @@ -231,54 +247,110 @@ impl Expression { fn into_crane( self, builder: &mut FunctionBuilder, - local_variables: &HashMap, Variable>, - global_variables: &HashMap, - ) -> Result { + local_variables: &HashMap, (Variable, ConstantType)>, + global_variables: &HashMap, + ) -> Result<(entities::Value, ConstantType), BackendError> { match self { - // Values are pretty straightforward to compile, mostly because we only - // have one type of variable, and it's an integer type. - Expression::Value(_, Value::Number(_, v)) => Ok(builder.ins().iconst(types::I64, v)), + Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables), - Expression::Reference(_, name) => { - // first we see if this is a local variable (which is nicer, from an - // optimization point of view.) 
- if let Some(local_var) = local_variables.get(&name) { - return Ok(builder.use_var(*local_var)); + Expression::Cast(_, target_type, expr) => { + let (val, val_type) = + expr.into_crane(builder, local_variables, global_variables)?; + + match (val_type, &target_type) { + (ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)), + (ConstantType::I8, Type::Primitive(PrimitiveType::I16)) => { + Ok((builder.ins().sextend(types::I16, val), ConstantType::I16)) + } + (ConstantType::I8, Type::Primitive(PrimitiveType::I32)) => { + Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)) + } + (ConstantType::I8, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)) + } + + (ConstantType::I16, Type::Primitive(PrimitiveType::I16)) => Ok((val, val_type)), + (ConstantType::I16, Type::Primitive(PrimitiveType::I32)) => { + Ok((builder.ins().sextend(types::I32, val), ConstantType::I32)) + } + (ConstantType::I16, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)) + } + + (ConstantType::I32, Type::Primitive(PrimitiveType::I32)) => Ok((val, val_type)), + (ConstantType::I32, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().sextend(types::I64, val), ConstantType::I64)) + } + + (ConstantType::I64, Type::Primitive(PrimitiveType::I64)) => Ok((val, val_type)), + + (ConstantType::U8, Type::Primitive(PrimitiveType::U8)) => Ok((val, val_type)), + (ConstantType::U8, Type::Primitive(PrimitiveType::U16)) => { + Ok((builder.ins().uextend(types::I16, val), ConstantType::U16)) + } + (ConstantType::U8, Type::Primitive(PrimitiveType::U32)) => { + Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)) + } + (ConstantType::U8, Type::Primitive(PrimitiveType::U64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)) + } + + (ConstantType::U16, Type::Primitive(PrimitiveType::U16)) => Ok((val, val_type)), + (ConstantType::U16, 
Type::Primitive(PrimitiveType::U32)) => { + Ok((builder.ins().uextend(types::I32, val), ConstantType::U32)) + } + (ConstantType::U16, Type::Primitive(PrimitiveType::U64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)) + } + + (ConstantType::U32, Type::Primitive(PrimitiveType::U32)) => Ok((val, val_type)), + (ConstantType::U32, Type::Primitive(PrimitiveType::U64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::U64)) + } + + (ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)), + + _ => Err(BackendError::InvalidTypeCast { + from: val_type.into(), + to: target_type, + }), } - - // then we check to see if this is a global reference, which requires us to - // first lookup where the value is stored, and then load it. - if let Some(global_var) = global_variables.get(name.as_ref()) { - let val_ptr = builder.ins().symbol_value(types::I64, *global_var); - return Ok(builder.ins().load(types::I64, MemFlags::new(), val_ptr, 0)); - } - - // this should never happen, because we should have made sure that there are - // no unbound variables a long time before this. but still ... - Err(BackendError::VariableLookupFailure(name)) } - Expression::Primitive(_, prim, mut vals) => { - // we're going to use `pop`, so we're going to pull and compile the right value ... - let right = - vals.pop() - .unwrap() - .into_crane(builder, local_variables, global_variables)?; - // ... and then the left. - let left = - vals.pop() - .unwrap() - .into_crane(builder, local_variables, global_variables)?; + Expression::Primitive(_, _, prim, mut vals) => { + let mut values = vec![]; + let mut first_type = None; + + for val in vals.drain(..) 
{ + let (compiled, compiled_type) = + val.into_crane(builder, local_variables, global_variables)?; + + if let Some(leftmost_type) = first_type { + assert_eq!(leftmost_type, compiled_type); + } else { + first_type = Some(compiled_type); + } + + values.push(compiled); + } + + let first_type = first_type.expect("primitive op has at least one argument"); // then we just need to tell Cranelift how to do each of our primitives! Much // like Statements, above, we probably want to eventually shuffle this off into // a separate function (maybe something off `Primitive`), but for now it's simple // enough that we just do the `match` here. match prim { - Primitive::Plus => Ok(builder.ins().iadd(left, right)), - Primitive::Minus => Ok(builder.ins().isub(left, right)), - Primitive::Times => Ok(builder.ins().imul(left, right)), - Primitive::Divide => Ok(builder.ins().sdiv(left, right)), + Primitive::Plus => Ok((builder.ins().iadd(values[0], values[1]), first_type)), + Primitive::Minus if values.len() == 2 => { + Ok((builder.ins().isub(values[0], values[1]), first_type)) + } + Primitive::Minus => Ok((builder.ins().ineg(values[0]), first_type)), + Primitive::Times => Ok((builder.ins().imul(values[0], values[1]), first_type)), + Primitive::Divide if first_type.is_signed() => { + Ok((builder.ins().sdiv(values[0], values[1]), first_type)) + } + Primitive::Divide => Ok((builder.ins().udiv(values[0], values[1]), first_type)), } } } @@ -291,9 +363,66 @@ impl ValueOrRef { fn into_crane( self, builder: &mut FunctionBuilder, - local_variables: &HashMap, Variable>, - global_variables: &HashMap, - ) -> Result { - Expression::from(self).into_crane(builder, local_variables, global_variables) + local_variables: &HashMap, (Variable, ConstantType)>, + global_variables: &HashMap, + ) -> Result<(entities::Value, ConstantType), BackendError> { + match self { + // Values are pretty straightforward to compile, mostly because we only + // have one type of variable, and it's an integer type. 
+ ValueOrRef::Value(_, _, val) => match val { + Value::I8(_, v) => { + Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8)) + } + Value::I16(_, v) => Ok(( + builder.ins().iconst(types::I16, v as i64), + ConstantType::I16, + )), + Value::I32(_, v) => Ok(( + builder.ins().iconst(types::I32, v as i64), + ConstantType::I32, + )), + Value::I64(_, v) => Ok((builder.ins().iconst(types::I64, v), ConstantType::I64)), + Value::U8(_, v) => { + Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::U8)) + } + Value::U16(_, v) => Ok(( + builder.ins().iconst(types::I16, v as i64), + ConstantType::U16, + )), + Value::U32(_, v) => Ok(( + builder.ins().iconst(types::I32, v as i64), + ConstantType::U32, + )), + Value::U64(_, v) => Ok(( + builder.ins().iconst(types::I64, v as i64), + ConstantType::U64, + )), + }, + + ValueOrRef::Ref(_, _, name) => { + // first we see if this is a local variable (which is nicer, from an + // optimization point of view.) + if let Some((local_var, etype)) = local_variables.get(&name) { + return Ok((builder.use_var(*local_var), *etype)); + } + + // then we check to see if this is a global reference. `symbol_value` yields the + // variable's address (pointer-typed, so I64), from which we load its typed value. + if let Some((global_var, etype)) = global_variables.get(name.as_ref()) { + let cranelift_type = ir::Type::from(*etype); + let val_ptr = builder.ins().symbol_value(types::I64, *global_var); + return Ok(( + builder + .ins() + .load(cranelift_type, MemFlags::new(), val_ptr, 0), + *etype, + )); + } + + // this should never happen, because we should have made sure that there are + // no unbound variables a long time before this. but still ... 
+ Err(BackendError::VariableLookupFailure(name)) + } + } } } diff --git a/src/backend/runtime.rs b/src/backend/runtime.rs index a03acf7..40907aa 100644 --- a/src/backend/runtime.rs +++ b/src/backend/runtime.rs @@ -8,6 +8,8 @@ use std::fmt::Write; use target_lexicon::Triple; use thiserror::Error; +use crate::syntax::ConstantType; + /// An object for querying / using functions built into the runtime. /// /// Right now, this is a quite a bit of boilerplate for very nebulous @@ -49,7 +51,7 @@ impl RuntimeFunctions { "print", Linkage::Import, &Signature { - params: vec![string_param, string_param, int64_param], + params: vec![string_param, string_param, int64_param, int64_param], returns: vec![], call_conv: CallConv::triple_default(platform), }, @@ -98,13 +100,31 @@ impl RuntimeFunctions { // we extend with the output, so that multiple JIT'd `Program`s can run concurrently // without stomping over each other's output. If `output_buffer` is NULL, we just print // to stdout. -extern "C" fn runtime_print(output_buffer: *mut String, name: *const i8, value: i64) { +extern "C" fn runtime_print( + output_buffer: *mut String, + // `c_char` rather than `i8`: C's `char` is unsigned on some targets (e.g. aarch64 Linux). + name: *const std::os::raw::c_char, + vtype_repr: i64, + value: i64, +) { let cstr = unsafe { CStr::from_ptr(name) }; let reconstituted = cstr.to_string_lossy(); + let output = match vtype_repr.try_into() { + Ok(ConstantType::I8) => format!("{} = {}i8", reconstituted, value as i8), + Ok(ConstantType::I16) => format!("{} = {}i16", reconstituted, value as i16), + Ok(ConstantType::I32) => format!("{} = {}i32", reconstituted, value as i32), + Ok(ConstantType::I64) => format!("{} = {}i64", reconstituted, value), + Ok(ConstantType::U8) => format!("{} = {}u8", reconstituted, value as u8), + Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16), + Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32), + Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64), + Err(_) => format!("{} = {}", reconstituted, value), + }; 
+ if let Some(output_buffer) = unsafe { output_buffer.as_mut() } { - writeln!(output_buffer, "{} = {}i64", reconstituted, value).unwrap(); + writeln!(output_buffer, "{}", output).unwrap(); } else { - println!("{} = {}", reconstituted, value); + println!("{}", output); } } diff --git a/src/bin/ngrc.rs b/src/bin/ngrc.rs index 821b0e2..e486ec0 100644 --- a/src/bin/ngrc.rs +++ b/src/bin/ngrc.rs @@ -17,7 +17,7 @@ fn main() { let args = CommandLineArguments::parse(); let mut compiler = ngr::Compiler::default(); - let output_file = args.output.unwrap_or("output.o".to_string()); + let output_file = args.output.unwrap_or_else(|| "output.o".to_string()); if let Some(bytes) = compiler.compile(&args.file) { std::fs::write(&output_file, bytes) diff --git a/src/compiler.rs b/src/compiler.rs index 41cc037..b94a200 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,6 +1,5 @@ -use crate::backend::Backend; -use crate::ir::Program as IR; use crate::syntax::Program as Syntax; +use crate::{backend::Backend, type_infer::TypeInferenceResult}; use codespan_reporting::{ diagnostic::Diagnostic, files::SimpleFiles, @@ -100,8 +99,38 @@ impl Compiler { return Ok(None); } - // Now that we've validated it, turn it into IR. - let ir = IR::from(syntax); + // Now that we've validated it, let's do type inference, potentially turning + // into IR while we're at it. + let ir = match syntax.type_infer() { + TypeInferenceResult::Failure { + mut errors, + mut warnings, + } => { + let messages = errors + .drain(..) + .map(Into::into) + .chain(warnings.drain(..).map(Into::into)); + + for message in messages { + self.emit(message); + } + + return Ok(None); + } + + TypeInferenceResult::Success { + result, + mut warnings, + } => { + let messages = warnings.drain(..).map(Into::into); + + for message in messages { + self.emit(message); + } + + result + } + }; // Finally, send all this to Cranelift for conversion into an object file. 
let mut backend = Backend::object_file(Triple::host())?; diff --git a/src/eval.rs b/src/eval.rs index cf77f74..a593a63 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -35,11 +35,13 @@ //! mod env; mod primop; +mod primtype; mod value; use cranelift_module::ModuleError; pub use env::{EvalEnvironment, LookupError}; pub use primop::PrimOpError; +pub use primtype::PrimitiveType; pub use value::Value; use crate::backend::BackendError; diff --git a/src/eval/env.rs b/src/eval/env.rs index a1a0320..551570c 100644 --- a/src/eval/env.rs +++ b/src/eval/env.rs @@ -87,9 +87,9 @@ mod tests { let tester = tester.extend(arced("bar"), 2i64.into()); let tester = tester.extend(arced("goo"), 5i64.into()); - assert_eq!(tester.lookup(arced("foo")), Ok(1.into())); - assert_eq!(tester.lookup(arced("bar")), Ok(2.into())); - assert_eq!(tester.lookup(arced("goo")), Ok(5.into())); + assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into())); + assert_eq!(tester.lookup(arced("bar")), Ok(2i64.into())); + assert_eq!(tester.lookup(arced("goo")), Ok(5i64.into())); assert!(tester.lookup(arced("baz")).is_err()); } @@ -103,14 +103,14 @@ mod tests { check_nested(&tester); - assert_eq!(tester.lookup(arced("foo")), Ok(1.into())); + assert_eq!(tester.lookup(arced("foo")), Ok(1i64.into())); assert!(tester.lookup(arced("bar")).is_err()); } fn check_nested(env: &EvalEnvironment) { let nested_env = env.extend(arced("bar"), 2i64.into()); - assert_eq!(nested_env.lookup(arced("foo")), Ok(1.into())); - assert_eq!(nested_env.lookup(arced("bar")), Ok(2.into())); + assert_eq!(nested_env.lookup(arced("foo")), Ok(1i64.into())); + assert_eq!(nested_env.lookup(arced("bar")), Ok(2i64.into())); } fn arced(s: &str) -> ArcIntern { diff --git a/src/eval/primop.rs b/src/eval/primop.rs index 49c014c..01a8dd1 100644 --- a/src/eval/primop.rs +++ b/src/eval/primop.rs @@ -1,3 +1,4 @@ +use crate::eval::primtype::PrimitiveType; use crate::eval::value::Value; /// Errors that can occur running primitive operations in the evaluators. 
@@ -22,6 +23,13 @@ pub enum PrimOpError { BadArgCount(String, usize), #[error("Unknown primitive operation {0}")] UnknownPrimOp(String), + #[error("Unsafe cast from {from} to {to}")] + UnsafeCast { + from: PrimitiveType, + to: PrimitiveType, + }, + #[error("Unknown primitive type {0}")] + UnknownPrimType(String), } // Implementing primitives in an interpreter like this is *super* tedious, @@ -37,39 +45,95 @@ pub enum PrimOpError { macro_rules! run_op { ($op: ident, $left: expr, $right: expr) => { match $op { - "+" => $left - .checked_add($right) - .ok_or(PrimOpError::MathFailure("+")) - .map(Into::into), - "-" => $left - .checked_sub($right) - .ok_or(PrimOpError::MathFailure("-")) - .map(Into::into), - "*" => $left - .checked_mul($right) - .ok_or(PrimOpError::MathFailure("*")) - .map(Into::into), - "/" => $left - .checked_div($right) - .ok_or(PrimOpError::MathFailure("/")) - .map(Into::into), + "+" => Ok($left.wrapping_add($right).into()), + "-" => Ok($left.wrapping_sub($right).into()), + "*" => Ok($left.wrapping_mul($right).into()), + "/" if $right == 0 => Err(PrimOpError::MathFailure("/")), + "/" => Ok($left.wrapping_div($right).into()), _ => Err(PrimOpError::UnknownPrimOp($op.to_string())), } }; } impl Value { + fn unary_op(operation: &str, value: &Value) -> Result { + match operation { + "-" => match value { + Value::I8(x) => Ok(Value::I8(x.wrapping_neg())), + Value::I16(x) => Ok(Value::I16(x.wrapping_neg())), + Value::I32(x) => Ok(Value::I32(x.wrapping_neg())), + Value::I64(x) => Ok(Value::I64(x.wrapping_neg())), + _ => Err(PrimOpError::BadTypeFor("-", value.clone())), + }, + _ => Err(PrimOpError::BadArgCount(operation.to_owned(), 1)), + } + } + fn binary_op(operation: &str, left: &Value, right: &Value) -> Result { match left { - // for now we only have one type, but in the future this is - // going to be very irritating. 
+ Value::I8(x) => match right { + Value::I8(y) => run_op!(operation, x, *y), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), + }, + Value::I16(x) => match right { + Value::I16(y) => run_op!(operation, x, *y), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), + }, + Value::I32(x) => match right { + Value::I32(y) => run_op!(operation, x, *y), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), + }, Value::I64(x) => match right { Value::I64(y) => run_op!(operation, x, *y), - // _ => Err(PrimOpError::TypeMismatch( - // operation.to_string(), - // left.clone(), - // right.clone(), - // )), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), + }, + Value::U8(x) => match right { + Value::U8(y) => run_op!(operation, x, *y), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), + }, + Value::U16(x) => match right { + Value::U16(y) => run_op!(operation, x, *y), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), + }, + Value::U32(x) => match right { + Value::U32(y) => run_op!(operation, x, *y), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), + }, + Value::U64(x) => match right { + Value::U64(y) => run_op!(operation, x, *y), + _ => Err(PrimOpError::TypeMismatch( + operation.to_string(), + left.clone(), + right.clone(), + )), }, } } @@ -83,13 +147,10 @@ impl Value { /// its worth being careful to make sure that your inputs won't cause either /// condition. 
pub fn calculate(operation: &str, values: Vec) -> Result { - if values.len() == 2 { - Value::binary_op(operation, &values[0], &values[1]) - } else { - Err(PrimOpError::BadArgCount( - operation.to_string(), - values.len(), - )) + match values.len() { + 1 => Value::unary_op(operation, &values[0]), + 2 => Value::binary_op(operation, &values[0], &values[1]), + x => Err(PrimOpError::BadArgCount(operation.to_string(), x)), } } } diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs new file mode 100644 index 0000000..7690137 --- /dev/null +++ b/src/eval/primtype.rs @@ -0,0 +1,173 @@ +use crate::{ + eval::{PrimOpError, Value}, + syntax::ConstantType, +}; +use std::{fmt::Display, str::FromStr}; + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum PrimitiveType { + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, +} + +impl Display for PrimitiveType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PrimitiveType::I8 => write!(f, "i8"), + PrimitiveType::I16 => write!(f, "i16"), + PrimitiveType::I32 => write!(f, "i32"), + PrimitiveType::I64 => write!(f, "i64"), + PrimitiveType::U8 => write!(f, "u8"), + PrimitiveType::U16 => write!(f, "u16"), + PrimitiveType::U32 => write!(f, "u32"), + PrimitiveType::U64 => write!(f, "u64"), + } + } +} + +impl<'a> From<&'a Value> for PrimitiveType { + fn from(value: &Value) -> Self { + match value { + Value::I8(_) => PrimitiveType::I8, + Value::I16(_) => PrimitiveType::I16, + Value::I32(_) => PrimitiveType::I32, + Value::I64(_) => PrimitiveType::I64, + Value::U8(_) => PrimitiveType::U8, + Value::U16(_) => PrimitiveType::U16, + Value::U32(_) => PrimitiveType::U32, + Value::U64(_) => PrimitiveType::U64, + } + } +} + +impl From for PrimitiveType { + fn from(value: ConstantType) -> Self { + match value { + ConstantType::I8 => PrimitiveType::I8, + ConstantType::I16 => PrimitiveType::I16, + ConstantType::I32 => PrimitiveType::I32, + ConstantType::I64 => PrimitiveType::I64, + ConstantType::U8 
=> PrimitiveType::U8, + ConstantType::U16 => PrimitiveType::U16, + ConstantType::U32 => PrimitiveType::U32, + ConstantType::U64 => PrimitiveType::U64, + } + } +} + +impl FromStr for PrimitiveType { + type Err = PrimOpError; + + fn from_str(s: &str) -> Result { + match s { + "i8" => Ok(PrimitiveType::I8), + "i16" => Ok(PrimitiveType::I16), + "i32" => Ok(PrimitiveType::I32), + "i64" => Ok(PrimitiveType::I64), + "u8" => Ok(PrimitiveType::U8), + "u16" => Ok(PrimitiveType::U16), + "u32" => Ok(PrimitiveType::U32), + "u64" => Ok(PrimitiveType::U64), + _ => Err(PrimOpError::UnknownPrimType(s.to_string())), + } + } +} + +impl PrimitiveType { + /// Return true if this type can be safely cast into the target type. + pub fn can_cast_to(&self, target: &PrimitiveType) -> bool { + match self { + PrimitiveType::U8 => matches!( + target, + PrimitiveType::U8 + | PrimitiveType::U16 + | PrimitiveType::U32 + | PrimitiveType::U64 + | PrimitiveType::I16 + | PrimitiveType::I32 + | PrimitiveType::I64 + ), + PrimitiveType::U16 => matches!( + target, + PrimitiveType::U16 + | PrimitiveType::U32 + | PrimitiveType::U64 + | PrimitiveType::I32 + | PrimitiveType::I64 + ), + PrimitiveType::U32 => matches!( + target, + PrimitiveType::U32 | PrimitiveType::U64 | PrimitiveType::I64 + ), + PrimitiveType::U64 => target == &PrimitiveType::U64, + PrimitiveType::I8 => matches!( + target, + PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64 + ), + PrimitiveType::I16 => matches!( + target, + PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64 + ), + PrimitiveType::I32 => matches!(target, PrimitiveType::I32 | PrimitiveType::I64), + PrimitiveType::I64 => target == &PrimitiveType::I64, + } + } + + /// Try to cast the given value to this type, returning the new value. + /// + /// Returns an error if the cast is not safe *in* *general*. 
This means that + /// this function will error even if the number will actually fit in the target + /// type, but it would not be generally safe to cast a member of the given + /// type to the target type. (So, for example, "1i64" is a number that could + /// work as a "u64", but since negative numbers wouldn't work, a cast from + /// "1i64" to "u64" will fail.) + pub fn safe_cast(&self, source: &Value) -> Result { + match (self, source) { + (PrimitiveType::U8, Value::U8(x)) => Ok(Value::U8(*x)), + (PrimitiveType::U16, Value::U8(x)) => Ok(Value::U16(*x as u16)), + (PrimitiveType::U16, Value::U16(x)) => Ok(Value::U16(*x)), + (PrimitiveType::U32, Value::U8(x)) => Ok(Value::U32(*x as u32)), + (PrimitiveType::U32, Value::U16(x)) => Ok(Value::U32(*x as u32)), + (PrimitiveType::U32, Value::U32(x)) => Ok(Value::U32(*x)), + (PrimitiveType::U64, Value::U8(x)) => Ok(Value::U64(*x as u64)), + (PrimitiveType::U64, Value::U16(x)) => Ok(Value::U64(*x as u64)), + (PrimitiveType::U64, Value::U32(x)) => Ok(Value::U64(*x as u64)), + (PrimitiveType::U64, Value::U64(x)) => Ok(Value::U64(*x)), + + (PrimitiveType::I8, Value::I8(x)) => Ok(Value::I8(*x)), + (PrimitiveType::I16, Value::I8(x)) => Ok(Value::I16(*x as i16)), + (PrimitiveType::I16, Value::I16(x)) => Ok(Value::I16(*x)), + (PrimitiveType::I32, Value::I8(x)) => Ok(Value::I32(*x as i32)), + (PrimitiveType::I32, Value::I16(x)) => Ok(Value::I32(*x as i32)), + (PrimitiveType::I32, Value::I32(x)) => Ok(Value::I32(*x)), + (PrimitiveType::I64, Value::I8(x)) => Ok(Value::I64(*x as i64)), + (PrimitiveType::I64, Value::I16(x)) => Ok(Value::I64(*x as i64)), + (PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)), + (PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)), + + _ => Err(PrimOpError::UnsafeCast { + from: source.into(), + to: *self, + }), + } + } + + pub fn max_value(&self) -> u64 { + match self { + PrimitiveType::U8 => u8::MAX as u64, + PrimitiveType::U16 => u16::MAX as u64, + PrimitiveType::U32 => u32::MAX as u64, 
+ PrimitiveType::U64 => u64::MAX, + PrimitiveType::I8 => i8::MAX as u64, + PrimitiveType::I16 => i16::MAX as u64, + PrimitiveType::I32 => i32::MAX as u64, + PrimitiveType::I64 => i64::MAX as u64, + } + } +} diff --git a/src/eval/value.rs b/src/eval/value.rs index ba0b0bd..12f3746 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -7,19 +7,75 @@ use std::fmt::Display; /// by type so that we don't mix them up. #[derive(Clone, Debug, PartialEq)] pub enum Value { + I8(i8), + I16(i16), + I32(i32), I64(i64), + U8(u8), + U16(u16), + U32(u32), + U64(u64), } impl Display for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Value::I8(x) => write!(f, "{}i8", x), + Value::I16(x) => write!(f, "{}i16", x), + Value::I32(x) => write!(f, "{}i32", x), Value::I64(x) => write!(f, "{}i64", x), + Value::U8(x) => write!(f, "{}u8", x), + Value::U16(x) => write!(f, "{}u16", x), + Value::U32(x) => write!(f, "{}u32", x), + Value::U64(x) => write!(f, "{}u64", x), } } } +impl From for Value { + fn from(value: i8) -> Self { + Value::I8(value) + } +} + +impl From for Value { + fn from(value: i16) -> Self { + Value::I16(value) + } +} + +impl From for Value { + fn from(value: i32) -> Self { + Value::I32(value) + } +} + impl From for Value { fn from(value: i64) -> Self { Value::I64(value) } } + +impl From for Value { + fn from(value: u8) -> Self { + Value::U8(value) + } +} + +impl From for Value { + fn from(value: u16) -> Self { + Value::U16(value) + } +} + +impl From for Value { + fn from(value: u32) -> Self { + Value::U32(value) + } +} + +impl From for Value { + fn from(value: u64) -> Self { + Value::U64(value) + } +} diff --git a/src/examples.rs b/src/examples.rs new file mode 100644 index 0000000..82a864f --- /dev/null +++ b/src/examples.rs @@ -0,0 +1,6 @@ +use crate::backend::Backend; +use crate::syntax::Program as Syntax; +use codespan_reporting::files::SimpleFiles; +use cranelift_jit::JITModule; + +include!(concat!(env!("OUT_DIR"), 
"/examples.rs")); diff --git a/src/ir.rs b/src/ir.rs index 88454e4..af38b02 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -12,9 +12,8 @@ //! validating syntax, and then figuring out how to turn it into Cranelift //! and object code. After that point, however, this will be the module to //! come to for analysis and optimization work. -mod ast; +pub mod ast; mod eval; -mod from_syntax; mod strings; pub use ast::*; diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 3d8446d..478f393 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -1,10 +1,14 @@ -use crate::syntax::Location; +use crate::{ + eval::PrimitiveType, + syntax::{self, ConstantType, Location}, +}; use internment::ArcIntern; -use pretty::{DocAllocator, Pretty}; +use pretty::{BoxAllocator, DocAllocator, Pretty}; use proptest::{ prelude::Arbitrary, strategy::{BoxedStrategy, Strategy}, }; +use std::{fmt, str::FromStr}; /// We're going to represent variables as interned strings. /// @@ -52,12 +56,15 @@ where } impl Arbitrary for Program { - type Parameters = (); + type Parameters = crate::syntax::arbitrary::GenerationEnvironment; type Strategy = BoxedStrategy; fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { crate::syntax::Program::arbitrary_with(args) - .prop_map(Program::from) + .prop_map(|x| { + x.type_infer() + .expect("arbitrary_with should generate type-correct programs") + }) .boxed() } } @@ -74,8 +81,8 @@ impl Arbitrary for Program { /// #[derive(Debug)] pub enum Statement { - Binding(Location, Variable, Expression), - Print(Location, Variable), + Binding(Location, Variable, Type, Expression), + Print(Location, Type, Variable), } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement @@ -85,13 +92,13 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - Statement::Binding(_, var, expr) => allocator + Statement::Binding(_, var, _, expr) => allocator .text(var.as_ref().to_string()) .append(allocator.space()) .append(allocator.text("=")) .append(allocator.space()) 
.append(expr.pretty(allocator)), - Statement::Print(_, var) => allocator + Statement::Print(_, _, var) => allocator .text("print") .append(allocator.space()) .append(allocator.text(var.as_ref().to_string())), @@ -113,9 +120,32 @@ where /// variable reference. #[derive(Debug)] pub enum Expression { - Value(Location, Value), - Reference(Location, Variable), - Primitive(Location, Primitive, Vec), + Atomic(ValueOrRef), + Cast(Location, Type, ValueOrRef), + Primitive(Location, Type, Primitive, Vec), +} + +impl Expression { + /// Return a reference to the type of the expression, as inferred or recently + /// computed. + pub fn type_of(&self) -> &Type { + match self { + Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t, + Expression::Atomic(ValueOrRef::Value(_, t, _)) => t, + Expression::Cast(_, t, _) => t, + Expression::Primitive(_, t, _, _) => t, + } + } + + /// Return a reference to the location associated with the expression. + pub fn location(&self) -> &Location { + match self { + Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l, + Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, + Expression::Cast(l, _, _) => l, + Expression::Primitive(l, _, _, _) => l, + } + } } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression @@ -125,12 +155,16 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - Expression::Value(_, val) => val.pretty(allocator), - Expression::Reference(_, var) => allocator.text(var.as_ref().to_string()), - Expression::Primitive(_, op, exprs) if exprs.len() == 1 => { + Expression::Atomic(x) => x.pretty(allocator), + Expression::Cast(_, t, e) => allocator + .text("<") + .append(t.pretty(allocator)) + .append(allocator.text(">")) + .append(e.pretty(allocator)), + Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => { op.pretty(allocator).append(exprs[0].pretty(allocator)) } - Expression::Primitive(_, op, exprs) if exprs.len() == 2 => { + Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => { let 
left = exprs[0].pretty(allocator); let right = exprs[1].pretty(allocator); @@ -140,7 +174,7 @@ where .append(right) .parens() } - Expression::Primitive(_, op, exprs) => { + Expression::Primitive(_, _, op, exprs) => { allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) } } @@ -161,10 +195,10 @@ pub enum Primitive { Divide, } -impl<'a> TryFrom<&'a str> for Primitive { - type Error = String; +impl FromStr for Primitive { + type Err = String; - fn try_from(value: &str) -> Result { + fn from_str(value: &str) -> Result { match value { "+" => Ok(Primitive::Plus), "-" => Ok(Primitive::Minus), @@ -190,15 +224,21 @@ where } } +impl fmt::Display for Primitive { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <&Primitive as Pretty<'_, BoxAllocator, ()>>::pretty(self, &BoxAllocator).render_fmt(72, f) + } +} + /// An expression that is always either a value or a reference. /// /// This is the type used to guarantee that we don't nest expressions /// at this level. Instead, expressions that take arguments take one /// of these, which can only be a constant or a reference. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum ValueOrRef { - Value(Location, Value), - Ref(Location, ArcIntern), + Value(Location, Type, Value), + Ref(Location, Type, ArcIntern), } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef @@ -208,30 +248,50 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - ValueOrRef::Value(_, v) => v.pretty(allocator), - ValueOrRef::Ref(_, v) => allocator.text(v.as_ref().to_string()), + ValueOrRef::Value(_, _, v) => v.pretty(allocator), + ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), } } } impl From for Expression { fn from(value: ValueOrRef) -> Self { - match value { - ValueOrRef::Value(loc, val) => Expression::Value(loc, val), - ValueOrRef::Ref(loc, var) => Expression::Reference(loc, var), - } + Expression::Atomic(value) } } /// A constant in the IR. 
-#[derive(Debug)] +/// +/// The optional argument in numeric types is the base that was used by the +/// user to input the number. By retaining it, we can ensure that if we need +/// to print the number back out, we can do so in the form that the user +/// entered it. +#[derive(Clone, Debug)] pub enum Value { - /// A numerical constant. - /// - /// The optional argument is the base that was used by the user to input - /// the number. By retaining it, we can ensure that if we need to print the - /// number back out, we can do so in the form that the user entered it. - Number(Option, i64), + I8(Option, i8), + I16(Option, i16), + I32(Option, i32), + I64(Option, i64), + U8(Option, u8), + U16(Option, u16), + U32(Option, u32), + U64(Option, u64), +} + +impl Value { + /// Return the type described by this value + pub fn type_of(&self) -> Type { + match self { + Value::I8(_, _) => Type::Primitive(PrimitiveType::I8), + Value::I16(_, _) => Type::Primitive(PrimitiveType::I16), + Value::I32(_, _) => Type::Primitive(PrimitiveType::I32), + Value::I64(_, _) => Type::Primitive(PrimitiveType::I64), + Value::U8(_, _) => Type::Primitive(PrimitiveType::U8), + Value::U16(_, _) => Type::Primitive(PrimitiveType::U16), + Value::U32(_, _) => Type::Primitive(PrimitiveType::U32), + Value::U64(_, _) => Type::Primitive(PrimitiveType::U64), + } + } } impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value @@ -240,19 +300,64 @@ where D: ?Sized + DocAllocator<'a, A>, { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - match self { - Value::Number(opt_base, value) => { - let value_str = match opt_base { - None => format!("{}", value), - Some(2) => format!("0b{:b}", value), - Some(8) => format!("0o{:o}", value), - Some(10) => format!("0d{}", value), - Some(16) => format!("0x{:x}", value), - Some(_) => format!("!!{:x}!!", value), - }; + let pretty_internal = |opt_base: &Option, x, t| { + syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator) + }; - allocator.text(value_str) + 
let pretty_internal_signed = |opt_base, x: i64, t| { + let base = pretty_internal(opt_base, x.unsigned_abs(), t); + + allocator.text("-").append(base) + }; + + match self { + Value::I8(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I8) } + Value::I16(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I16) + } + Value::I32(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I32) + } + Value::I64(opt_base, value) => { + pretty_internal_signed(opt_base, *value, ConstantType::I64) + } + Value::U8(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U8) + } + Value::U16(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U16) + } + Value::U32(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U32) + } + Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Type { + Primitive(PrimitiveType), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + Type::Primitive(pt) => allocator.text(format!("{}", pt)), + } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Primitive(pt) => pt.fmt(f), } } } diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 78b8b0b..6841508 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -1,7 +1,7 @@ use crate::eval::{EvalEnvironment, EvalError, Value}; use crate::ir::{Expression, Program, Statement}; -use super::{Primitive, ValueOrRef}; +use super::{Primitive, Type, ValueOrRef}; impl Program { /// Evaluate the program, returning either an error or a string containing everything @@ -14,12 +14,12 @@ impl Program { for stmt in self.statements.iter() { match stmt { - 
Statement::Binding(_, name, value) => { + Statement::Binding(_, name, _, value) => { let actual_value = value.eval(&env)?; env = env.extend(name.clone(), actual_value); } - Statement::Print(_, name) => { + Statement::Print(_, _, name) => { let value = env.lookup(name.clone())?; let line = format!("{} = {}\n", name, value); stdout.push_str(&line); @@ -34,26 +34,21 @@ impl Program { impl Expression { fn eval(&self, env: &EvalEnvironment) -> Result { match self { - Expression::Value(_, v) => match v { - super::Value::Number(_, v) => Ok(Value::I64(*v)), - }, + Expression::Atomic(x) => x.eval(env), - Expression::Reference(_, n) => Ok(env.lookup(n.clone())?), + Expression::Cast(_, t, valref) => { + let value = valref.eval(env)?; - Expression::Primitive(_, op, args) => { - let mut arg_values = Vec::with_capacity(args.len()); - - // we implement primitive operations by first evaluating each of the - // arguments to the function, and then gathering up all the values - // produced. - for arg in args.iter() { - match arg { - ValueOrRef::Ref(_, n) => arg_values.push(env.lookup(n.clone())?), - ValueOrRef::Value(_, super::Value::Number(_, v)) => { - arg_values.push(Value::I64(*v)) - } - } + match t { + Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), } + } + + Expression::Primitive(_, _, op, args) => { + let arg_values = args + .iter() + .map(|x| x.eval(env)) + .collect::, EvalError>>()?; // and then finally we call `calculate` to run them. trust me, it's nice // to not have to deal with all the nonsense hidden under `calculate`. 
@@ -68,19 +63,38 @@ impl Expression { } } +impl ValueOrRef { + fn eval(&self, env: &EvalEnvironment) -> Result { + match self { + ValueOrRef::Value(_, _, v) => match v { + super::Value::I8(_, v) => Ok(Value::I8(*v)), + super::Value::I16(_, v) => Ok(Value::I16(*v)), + super::Value::I32(_, v) => Ok(Value::I32(*v)), + super::Value::I64(_, v) => Ok(Value::I64(*v)), + super::Value::U8(_, v) => Ok(Value::U8(*v)), + super::Value::U16(_, v) => Ok(Value::U16(*v)), + super::Value::U32(_, v) => Ok(Value::U32(*v)), + super::Value::U64(_, v) => Ok(Value::U64(*v)), + }, + + ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?), + } + } +} + #[test] fn two_plus_three() { let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); - let ir = Program::from(input); + let ir = input.type_infer().expect("test should be type-valid"); let output = ir.eval().expect("runs successfully"); - assert_eq!("x = 5i64\n", &output); + assert_eq!("x = 5u64\n", &output); } #[test] fn lotsa_math() { let input = crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); - let ir = Program::from(input); + let ir = input.type_infer().expect("test should be type-valid"); let output = ir.eval().expect("runs successfully"); - assert_eq!("x = 7i64\n", &output); + assert_eq!("x = 7u64\n", &output); } diff --git a/src/ir/from_syntax.rs b/src/ir/from_syntax.rs deleted file mode 100644 index 46c7c69..0000000 --- a/src/ir/from_syntax.rs +++ /dev/null @@ -1,186 +0,0 @@ -use internment::ArcIntern; -use std::sync::atomic::AtomicUsize; - -use crate::ir::ast as ir; -use crate::syntax; - -use super::ValueOrRef; - -impl From for ir::Program { - /// We implement the top-level conversion of a syntax::Program into an - /// ir::Program using just the standard `From::from`, because we don't - /// need to return any arguments and we shouldn't produce any errors. 
- /// Technically there's an `unwrap` deep under the hood that we could - /// float out, but the validator really should've made sure that never - /// happens, so we're just going to assume. - fn from(mut value: syntax::Program) -> Self { - let mut statements = Vec::new(); - - for stmt in value.statements.drain(..) { - statements.append(&mut stmt.simplify()); - } - - ir::Program { statements } - } -} - -impl From for ir::Program { - /// One interesting thing about this conversion is that there isn't - /// a natural translation from syntax::Statement to ir::Statement, - /// because the syntax version can have nested expressions and the - /// IR version can't. - /// - /// As a result, we can naturally convert a syntax::Statement into - /// an ir::Program, because we can allow the additional binding - /// sites to be generated, instead. And, bonus, it turns out that - /// this is what we wanted anyways. - fn from(value: syntax::Statement) -> Self { - ir::Program { - statements: value.simplify(), - } - } -} - -impl syntax::Statement { - /// Simplify a syntax::Statement into a series of ir::Statements. - /// - /// The reason this function is one-to-many is because we may have to - /// introduce new binding sites in order to avoid having nested - /// expressions. Nested expressions, like `(1 + 2) * 3`, are allowed - /// in syntax::Expression but are expressly *not* allowed in - /// ir::Expression. So this pass converts them into bindings, like - /// this: - /// - /// x = (1 + 2) * 3; - /// - /// ==> - /// - /// x:1 = 1 + 2; - /// x:2 = x:1 * 3; - /// x = x:2 - /// - /// Thus ensuring that things are nice and simple. Note that the - /// binding of `x:2` is not, strictly speaking, necessary, but it - /// makes the code below much easier to read. 
- fn simplify(self) -> Vec { - let mut new_statements = vec![]; - - match self { - // Print statements we don't have to do much with - syntax::Statement::Print(loc, name) => { - new_statements.push(ir::Statement::Print(loc, ArcIntern::new(name))) - } - - // Bindings, however, may involve a single expression turning into - // a series of statements and then an expression. - syntax::Statement::Binding(loc, name, value) => { - let (mut prereqs, new_value) = value.rebind(&name); - new_statements.append(&mut prereqs); - new_statements.push(ir::Statement::Binding( - loc, - ArcIntern::new(name), - new_value.into(), - )) - } - } - - new_statements - } -} - -impl syntax::Expression { - /// This actually does the meat of the simplification work, here, by rebinding - /// any nested expressions into their own variables. We have this return - /// `ValueOrRef` in all cases because it makes for slighly less code; in the - /// case when we actually want an `Expression`, we can just use `into()`. - fn rebind(self, base_name: &str) -> (Vec, ir::ValueOrRef) { - match self { - // Values just convert in the obvious way, and require no prereqs - syntax::Expression::Value(loc, val) => (vec![], ValueOrRef::Value(loc, val.into())), - - // Similarly, references just convert in the obvious way, and require - // no prereqs - syntax::Expression::Reference(loc, name) => { - (vec![], ValueOrRef::Ref(loc, ArcIntern::new(name))) - } - - // Primitive expressions are where we do the real work. - syntax::Expression::Primitive(loc, prim, mut expressions) => { - // generate a fresh new name for the binding site we're going to - // introduce, basing the name on wherever we came from; so if this - // expression was bound to `x` originally, it might become `x:23`. - // - // gensym is guaranteed to give us a name that is unused anywhere - // else in the program. 
- let new_name = gensym(base_name); - let mut prereqs = Vec::new(); - let mut new_exprs = Vec::new(); - - // here we loop through every argument, and recurse on the expressions - // we find. that will give us any new binding sites that *they* introduce, - // and a simple value or reference that we can use in our result. - for expr in expressions.drain(..) { - let (mut cur_prereqs, arg) = expr.rebind(new_name.as_str()); - prereqs.append(&mut cur_prereqs); - new_exprs.push(arg); - } - - // now we're going to use those new arguments to run the primitive, binding - // the results to the new variable we introduced. - let prim = - ir::Primitive::try_from(prim.as_str()).expect("is valid primitive function"); - prereqs.push(ir::Statement::Binding( - loc.clone(), - new_name.clone(), - ir::Expression::Primitive(loc.clone(), prim, new_exprs), - )); - - // and finally, we can return all the new bindings, and a reference to - // the variable we just introduced to hold the value of the primitive - // invocation. - (prereqs, ValueOrRef::Ref(loc, new_name)) - } - } - } -} - -impl From for ir::Value { - fn from(value: syntax::Value) -> Self { - match value { - syntax::Value::Number(base, val) => ir::Value::Number(base, val), - } - } -} - -impl From for ir::Primitive { - fn from(value: String) -> Self { - value.try_into().unwrap() - } -} - -/// Generate a fresh new name based on the given name. -/// -/// The new name is guaranteed to be unique across the entirety of the -/// execution. This is achieved by using characters in the variable name -/// that would not be valid input, and by including a counter that is -/// incremented on every invocation. -fn gensym(name: &str) -> ArcIntern { - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let new_name = format!( - "<{}:{}>", - name, - COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) - ); - ArcIntern::new(new_name) -} - -proptest::proptest! 
{ - #[test] - fn translation_maintains_semantics(input: syntax::Program) { - let syntax_result = input.eval(); - let ir = ir::Program::from(input); - let ir_result = ir.eval(); - assert_eq!(syntax_result, ir_result); - } -} diff --git a/src/ir/strings.rs b/src/ir/strings.rs index f7b291e..70f939f 100644 --- a/src/ir/strings.rs +++ b/src/ir/strings.rs @@ -21,12 +21,12 @@ impl Program { impl Statement { fn register_strings(&self, string_set: &mut HashSet>) { match self { - Statement::Binding(_, name, expr) => { + Statement::Binding(_, name, _, expr) => { string_set.insert(name.clone()); expr.register_strings(string_set); } - Statement::Print(_, name) => { + Statement::Print(_, _, name) => { string_set.insert(name.clone()); } } diff --git a/src/lib.rs b/src/lib.rs index 18ade12..c155075 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,8 +63,11 @@ //! pub mod backend; pub mod eval; +#[cfg(test)] +mod examples; pub mod ir; pub mod syntax; +pub mod type_infer; /// Implementation module for the high-level compiler. mod compiler; diff --git a/src/repl.rs b/src/repl.rs index 5d511b5..fcf6b62 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -1,6 +1,6 @@ use crate::backend::{Backend, BackendError}; -use crate::ir::Program as IR; -use crate::syntax::{Location, ParserError, Statement}; +use crate::syntax::{ConstantType, Location, ParserError, Statement}; +use crate::type_infer::TypeInferenceResult; use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::files::SimpleFiles; use codespan_reporting::term::{self, Config}; @@ -129,17 +129,32 @@ impl REPL { .source(); let syntax = Statement::parse(entry, source)?; - // if this is a variable binding, and we've never defined this variable before, - // we should tell cranelift about it. this is optimistic; if we fail to compile, - // then we won't use this definition until someone tries again. 
- if let Statement::Binding(_, ref name, _) = syntax { - if !self.variable_binding_sites.contains_key(name.as_str()) { - self.jitter.define_string(name)?; - self.jitter.define_variable(name.clone())?; + let program = match syntax { + Statement::Binding(loc, name, expr) => { + // if this is a variable binding, and we've never defined this variable before, + // we should tell cranelift about it. this is optimistic; if we fail to compile, + // then we won't use this definition until someone tries again. + if !self.variable_binding_sites.contains_key(&name.name) { + self.jitter.define_string(&name.name)?; + self.jitter + .define_variable(name.to_string(), ConstantType::U64)?; + } + + crate::syntax::Program { + statements: vec![ + Statement::Binding(loc.clone(), name.clone(), expr), + Statement::Print(loc, name), + ], + } } + + nonbinding => crate::syntax::Program { + statements: vec![nonbinding], + }, }; - let (mut errors, mut warnings) = syntax.validate(&mut self.variable_binding_sites); + let (mut errors, mut warnings) = + program.validate_with_bindings(&mut self.variable_binding_sites); let stop = !errors.is_empty(); let messages = errors .drain(..) @@ -154,13 +169,39 @@ impl REPL { return Ok(()); } - let ir = IR::from(syntax); - let name = format!("line{}", line_no); - let function_id = self.jitter.compile_function(&name, ir)?; - self.jitter.module.finalize_definitions()?; - let compiled_bytes = self.jitter.bytes(function_id); - let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; - compiled_function(); - Ok(()) + match program.type_infer() { + TypeInferenceResult::Failure { + mut errors, + mut warnings, + } => { + let messages = errors + .drain(..) 
+ .map(Into::into) + .chain(warnings.drain(..).map(Into::into)); + + for message in messages { + self.emit_diagnostic(message)?; + } + + Ok(()) + } + + TypeInferenceResult::Success { + result, + mut warnings, + } => { + for message in warnings.drain(..).map(Into::into) { + self.emit_diagnostic(message)?; + } + let name = format!("line{}", line_no); + let function_id = self.jitter.compile_function(&name, result)?; + self.jitter.module.finalize_definitions()?; + let compiled_bytes = self.jitter.bytes(function_id); + let compiled_function = + unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; + compiled_function(); + Ok(()) + } + } } } diff --git a/src/syntax.rs b/src/syntax.rs index 0ed88ee..8d505a1 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -27,7 +27,7 @@ use codespan_reporting::{diagnostic::Diagnostic, files::SimpleFiles}; use lalrpop_util::lalrpop_mod; use logos::Logos; -mod arbitrary; +pub mod arbitrary; mod ast; mod eval; mod location; @@ -40,6 +40,8 @@ lalrpop_mod!( mod pretty; mod validate; +#[cfg(test)] +use crate::syntax::arbitrary::GenerationEnvironment; pub use crate::syntax::ast::*; pub use crate::syntax::location::Location; pub use crate::syntax::parser::{ProgramParser, StatementParser}; @@ -48,7 +50,7 @@ pub use crate::syntax::tokens::{LexerError, Token}; use ::pretty::{Arena, Pretty}; use lalrpop_util::ParseError; #[cfg(test)] -use proptest::{prop_assert, prop_assert_eq}; +use proptest::{arbitrary::Arbitrary, prop_assert, prop_assert_eq}; #[cfg(test)] use std::str::FromStr; use thiserror::Error; @@ -73,12 +75,12 @@ pub enum ParserError { /// Raised when we're parsing the file, and run into a token in a /// place we weren't expecting it. #[error("Unrecognized token")] - UnrecognizedToken(Location, Location, Token, Vec), + UnrecognizedToken(Location, Token, Vec), /// Raised when we were expecting the end of the file, but instead /// got another token. 
#[error("Extra token")] - ExtraToken(Location, Token, Location), + ExtraToken(Location, Token), /// Raised when the lexer just had some sort of internal problem /// and just gave up. @@ -106,30 +108,24 @@ impl ParserError { fn convert(file_idx: usize, err: ParseError) -> Self { match err { ParseError::InvalidToken { location } => { - ParserError::InvalidToken(Location::new(file_idx, location)) - } - ParseError::UnrecognizedEof { location, expected } => { - ParserError::UnrecognizedEOF(Location::new(file_idx, location), expected) + ParserError::InvalidToken(Location::new(file_idx, location..location + 1)) } + ParseError::UnrecognizedEof { location, expected } => ParserError::UnrecognizedEOF( + Location::new(file_idx, location..location + 1), + expected, + ), ParseError::UnrecognizedToken { token: (start, token, end), expected, - } => ParserError::UnrecognizedToken( - Location::new(file_idx, start), - Location::new(file_idx, end), - token, - expected, - ), + } => { + ParserError::UnrecognizedToken(Location::new(file_idx, start..end), token, expected) + } ParseError::ExtraToken { token: (start, token, end), - } => ParserError::ExtraToken( - Location::new(file_idx, start), - token, - Location::new(file_idx, end), - ), + } => ParserError::ExtraToken(Location::new(file_idx, start..end), token), ParseError::User { error } => match error { LexerError::LexFailure(offset) => { - ParserError::LexFailure(Location::new(file_idx, offset)) + ParserError::LexFailure(Location::new(file_idx, offset..offset + 1)) } }, } @@ -180,37 +176,25 @@ impl<'a> From<&'a ParserError> for Diagnostic { ), // encountered a token where it shouldn't be - ParserError::UnrecognizedToken(start, end, token, expected) => { + ParserError::UnrecognizedToken(loc, token, expected) => { let expected_str = format!("unexpected token {}{}", token, display_expected(expected)); let unexpected_str = format!("unexpected token {}", token); - let labels = start.range_label(end); Diagnostic::error() - .with_labels( - 
labels - .into_iter() - .map(|l| l.with_message(unexpected_str.clone())) - .collect(), - ) .with_message(expected_str) + .with_labels(vec![loc.primary_label().with_message(unexpected_str)]) } // I think we get this when we get a token, but were expected EOF - ParserError::ExtraToken(start, token, end) => { + ParserError::ExtraToken(loc, token) => { let expected_str = format!("unexpected token {} after the expected end of file", token); let unexpected_str = format!("unexpected token {}", token); - let labels = start.range_label(end); Diagnostic::error() - .with_labels( - labels - .into_iter() - .map(|l| l.with_message(unexpected_str.clone())) - .collect(), - ) .with_message(expected_str) + .with_labels(vec![loc.primary_label().with_message(unexpected_str)]) } // simple lexer errors @@ -293,24 +277,27 @@ fn order_of_operations() { Program::from_str(muladd1).unwrap(), Program { statements: vec![Statement::Binding( - Location::new(testfile, 0), - "x".to_string(), + Location::new(testfile, 0..1), + Name::manufactured("x"), Expression::Primitive( - Location::new(testfile, 6), + Location::new(testfile, 6..7), "+".to_string(), vec![ - Expression::Value(Location::new(testfile, 4), Value::Number(None, 1)), + Expression::Value( + Location::new(testfile, 4..5), + Value::Number(None, None, 1), + ), Expression::Primitive( - Location::new(testfile, 10), + Location::new(testfile, 10..11), "*".to_string(), vec![ Expression::Value( - Location::new(testfile, 8), - Value::Number(None, 2), + Location::new(testfile, 8..9), + Value::Number(None, None, 2), ), Expression::Value( - Location::new(testfile, 12), - Value::Number(None, 3), + Location::new(testfile, 12..13), + Value::Number(None, None, 3), ), ] ) @@ -350,8 +337,8 @@ proptest::proptest! 
{ } #[test] - fn generated_run_or_overflow(program: Program) { + fn generated_run_or_overflow(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { use crate::eval::{EvalError, PrimOpError}; - assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_))))) + prop_assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_))))); } } diff --git a/src/syntax/arbitrary.rs b/src/syntax/arbitrary.rs index 52f43ab..93759e9 100644 --- a/src/syntax/arbitrary.rs +++ b/src/syntax/arbitrary.rs @@ -1,136 +1,189 @@ -use std::collections::HashSet; - -use crate::syntax::ast::{Expression, Program, Statement, Value}; +use crate::syntax::ast::{ConstantType, Expression, Name, Program, Statement, Value}; use crate::syntax::location::Location; use proptest::sample::select; use proptest::{ prelude::{Arbitrary, BoxedStrategy, Strategy}, strategy::{Just, Union}, }; +use std::collections::HashMap; +use std::ops::Range; const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*"; -#[derive(Debug)] -struct Name(String); +impl ConstantType { + fn get_operators(&self) -> &'static [(&'static str, usize)] { + match self { + ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 => { + &[("+", 2), ("-", 1), ("-", 2), ("*", 2), ("/", 2)] + } + ConstantType::U8 | ConstantType::U16 | ConstantType::U32 | ConstantType::U64 => { + &[("+", 2), ("-", 2), ("*", 2), ("/", 2)] + } + } + } +} + +#[derive(Clone)] +pub struct GenerationEnvironment { + allow_inference: bool, + block_length: Range, + bindings: HashMap, + return_type: ConstantType, +} + +impl Default for GenerationEnvironment { + fn default() -> Self { + GenerationEnvironment { + allow_inference: true, + block_length: 2..10, + bindings: HashMap::new(), + return_type: ConstantType::U64, + } + } +} + +impl GenerationEnvironment { + pub fn new(allow_inference: bool) -> Self { + GenerationEnvironment { + allow_inference, + ..Default::default() + } + } +} 
+ +impl Arbitrary for Program { + type Parameters = GenerationEnvironment; + type Strategy = BoxedStrategy; + + fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { + proptest::collection::vec( + ProgramStatementInfo::arbitrary(), + genenv.block_length.clone(), + ) + .prop_flat_map(move |mut items| { + let mut statements = Vec::new(); + let mut genenv = genenv.clone(); + + for psi in items.drain(..) { + if genenv.bindings.is_empty() || psi.should_be_binding { + genenv.return_type = psi.binding_type; + let expr = Expression::arbitrary_with(genenv.clone()); + genenv.bindings.insert(psi.name.clone(), psi.binding_type); + statements.push( + expr.prop_map(move |expr| { + Statement::Binding(Location::manufactured(), psi.name.clone(), expr) + }) + .boxed(), + ); + } else { + let printers = genenv.bindings.keys().map(|n| { + Just(Statement::Print( + Location::manufactured(), + Name::manufactured(n), + )) + }); + statements.push(Union::new(printers).boxed()); + } + } + + statements + .prop_map(|statements| Program { statements }) + .boxed() + }) + .boxed() + } +} impl Arbitrary for Name { type Parameters = (); type Strategy = BoxedStrategy; fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - VALID_VARIABLE_NAMES.prop_map(Name).boxed() + VALID_VARIABLE_NAMES.prop_map(Name::manufactured).boxed() } } -impl Arbitrary for Program { +#[derive(Debug)] +struct ProgramStatementInfo { + should_be_binding: bool, + name: Name, + binding_type: ConstantType, +} + +impl Arbitrary for ProgramStatementInfo { type Parameters = (); type Strategy = BoxedStrategy; - fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - let optionals = Vec::>::arbitrary(); - - optionals - .prop_flat_map(|mut possible_names| { - let mut statements = Vec::new(); - let mut defined_variables: HashSet = HashSet::new(); - - for possible_name in possible_names.drain(..) 
{ - match possible_name { - None if defined_variables.is_empty() => continue, - None => statements.push( - Union::new(defined_variables.iter().map(|name| { - Just(Statement::Print(Location::manufactured(), name.to_string())) - })) - .boxed(), - ), - Some(new_name) => { - let closures_name = new_name.0.clone(); - let retval = - Expression::arbitrary_with(Some(defined_variables.clone())) - .prop_map(move |exp| { - Statement::Binding( - Location::manufactured(), - closures_name.clone(), - exp, - ) - }) - .boxed(); - - defined_variables.insert(new_name.0); - statements.push(retval); - } - } - } - - statements - }) - .prop_map(|statements| Program { statements }) + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + ( + Union::new(vec![Just(true), Just(true), Just(false)]), + Name::arbitrary(), + ConstantType::arbitrary(), + ) + .prop_map( + |(should_be_binding, name, binding_type)| ProgramStatementInfo { + should_be_binding, + name, + binding_type, + }, + ) .boxed() } } -impl Arbitrary for Statement { - type Parameters = Option>; - type Strategy = BoxedStrategy; - - fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { - let duplicated_args = args.clone(); - let defined_variables = args.unwrap_or_default(); - - let binding_strategy = ( - VALID_VARIABLE_NAMES, - Expression::arbitrary_with(duplicated_args), - ) - .prop_map(|(name, exp)| Statement::Binding(Location::manufactured(), name, exp)) - .boxed(); - - if defined_variables.is_empty() { - binding_strategy - } else { - let print_strategy = Union::new( - defined_variables - .iter() - .map(|x| Just(Statement::Print(Location::manufactured(), x.to_string()))), - ) - .boxed(); - - Union::new([binding_strategy, print_strategy]).boxed() - } - } -} - impl Arbitrary for Expression { - type Parameters = Option>; + type Parameters = GenerationEnvironment; type Strategy = BoxedStrategy; - fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { - let defined_variables = args.unwrap_or_default(); - - let 
value_strategy = Value::arbitrary() - .prop_map(move |x| Expression::Value(Location::manufactured(), x)) + fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { + // Value(Location, Value). These are the easiest variations to create, because we can always + // create one. + let value_strategy = Value::arbitrary_with(genenv.clone()) + .prop_map(|x| Expression::Value(Location::manufactured(), x)) .boxed(); - let leaf_strategy = if defined_variables.is_empty() { + // Reference(Location, String), These are slightly trickier, because we can end up in a situation + // where either no variables are defined, or where none of the defined variables have a type we + // can work with. So what we're going to do is combine this one with the previous one as a "leaf + // strategy" -- our non-recursive items -- if we can, or just set that to be the value strategy + // if we can't actually create an references. + let mut bound_variables_of_type = genenv + .bindings + .iter() + .filter(|(_, v)| genenv.return_type == **v) + .map(|(n, _)| n) + .collect::>(); + let leaf_strategy = if bound_variables_of_type.is_empty() { value_strategy } else { - let reference_strategy = Union::new(defined_variables.iter().map(|x| { - Just(Expression::Reference( - Location::manufactured(), - x.to_owned(), - )) - })) - .boxed(); - Union::new([value_strategy, reference_strategy]).boxed() + let mut strats = bound_variables_of_type + .drain(..) 
+ .map(|x| { + Just(Expression::Reference( + Location::manufactured(), + x.name.clone(), + )) + .boxed() + }) + .collect::>(); + strats.push(value_strategy); + Union::new(strats).boxed() }; + // now we generate our recursive types, given our leaf strategy leaf_strategy - .prop_recursive(3, 64, 2, move |inner| { + .prop_recursive(3, 10, 2, move |strat| { ( - select(super::BINARY_OPERATORS), - proptest::collection::vec(inner, 2), + select(genenv.return_type.get_operators()), + strat.clone(), + strat, ) - .prop_map(move |(operator, exprs)| { - Expression::Primitive(Location::manufactured(), operator.to_string(), exprs) + .prop_map(|((oper, count), left, right)| { + let mut args = vec![left, right]; + while args.len() > count { + args.pop(); + } + Expression::Primitive(Location::manufactured(), oper.to_string(), args) }) }) .boxed() @@ -138,22 +191,57 @@ impl Arbitrary for Expression { } impl Arbitrary for Value { - type Parameters = (); + type Parameters = GenerationEnvironment; type Strategy = BoxedStrategy; - fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - let base_strategy = Union::new([ + fn arbitrary_with(genenv: Self::Parameters) -> Self::Strategy { + let printed_base_strategy = Union::new([ Just(None::), Just(Some(2)), Just(Some(8)), Just(Some(10)), Just(Some(16)), ]); + let value_strategy = u64::arbitrary(); - let value_strategy = i64::arbitrary(); - - (base_strategy, value_strategy) - .prop_map(move |(base, value)| Value::Number(base, value)) + (printed_base_strategy, bool::arbitrary(), value_strategy) + .prop_map(move |(base, declare_type, value)| { + let converted_value = match genenv.return_type { + ConstantType::I8 => value % (i8::MAX as u64), + ConstantType::U8 => value % (u8::MAX as u64), + ConstantType::I16 => value % (i16::MAX as u64), + ConstantType::U16 => value % (u16::MAX as u64), + ConstantType::I32 => value % (i32::MAX as u64), + ConstantType::U32 => value % (u32::MAX as u64), + ConstantType::I64 => value % (i64::MAX as u64), + 
ConstantType::U64 => value, + }; + let ty = if declare_type || !genenv.allow_inference { + Some(genenv.return_type) + } else { + None + }; + Value::Number(base, ty, converted_value) + }) .boxed() } } + +impl Arbitrary for ConstantType { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + Union::new([ + Just(ConstantType::I8), + Just(ConstantType::I16), + Just(ConstantType::I32), + Just(ConstantType::I64), + Just(ConstantType::U8), + Just(ConstantType::U16), + Just(ConstantType::U32), + Just(ConstantType::U64), + ]) + .boxed() + } +} diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index d71e872..6700a9f 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -1,7 +1,10 @@ -use crate::syntax::Location; +use std::fmt; +use std::hash::Hash; -/// The set of valid binary operators. -pub static BINARY_OPERATORS: &[&str] = &["+", "-", "*", "/"]; +use internment::ArcIntern; + +pub use crate::syntax::tokens::ConstantType; +use crate::syntax::Location; /// A structure represented a parsed program. /// @@ -16,6 +19,56 @@ pub struct Program { pub statements: Vec, } +/// A Name. +/// +/// This is basically a string, but annotated with the place the string +/// is in the source file. 
+#[derive(Clone, Debug)] +pub struct Name { + pub name: String, + pub location: Location, +} + +impl Name { + pub fn new(n: S, location: Location) -> Name { + Name { + name: n.to_string(), + location, + } + } + + pub fn manufactured(n: S) -> Name { + Name { + name: n.to_string(), + location: Location::manufactured(), + } + } + + pub fn intern(self) -> ArcIntern { + ArcIntern::new(self.name) + } +} + +impl PartialEq for Name { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } +} + +impl Eq for Name {} + +impl Hash for Name { + fn hash(&self, state: &mut H) { + self.name.hash(state) + } +} + +impl fmt::Display for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.name.fmt(f) + } +} + /// A parsed statement. /// /// Statements are guaranteed to be syntactically valid, but may be @@ -29,8 +82,8 @@ pub struct Program { /// thing, not if they are the exact same statement. #[derive(Clone, Debug)] pub enum Statement { - Binding(Location, String, Expression), - Print(Location, String), + Binding(Location, Name, Expression), + Print(Location, Name), } impl PartialEq for Statement { @@ -58,6 +111,7 @@ impl PartialEq for Statement { pub enum Expression { Value(Location, Value), Reference(Location, String), + Cast(Location, String, Box), Primitive(Location, String, Vec), } @@ -72,6 +126,10 @@ impl PartialEq for Expression { Expression::Reference(_, var2) => var1 == var2, _ => false, }, + Expression::Cast(_, t1, e1) => match other { + Expression::Cast(_, t2, e2) => t1 == t2 && e1 == e2, + _ => false, + }, Expression::Primitive(_, prim1, args1) => match other { Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2, _ => false, @@ -83,6 +141,12 @@ impl PartialEq for Expression { /// A value from the source syntax #[derive(Clone, Debug, PartialEq, Eq)] pub enum Value { - /// The value of the number, and an optional base that it was written in - Number(Option, i64), + /// The value of the number, an optional base that 
it was written in, and any + /// type information provided. + /// + /// u64 is chosen because it should be big enough to carry the amount of + /// information we need, and technically we interpret -4 as the primitive unary + /// operation "-" on the number 4. We'll translate this into a type-specific + /// number at a later time. + Number(Option, Option, u64), } diff --git a/src/syntax/eval.rs b/src/syntax/eval.rs index 6504e26..d6fda74 100644 --- a/src/syntax/eval.rs +++ b/src/syntax/eval.rs @@ -1,7 +1,8 @@ use internment::ArcIntern; -use crate::eval::{EvalEnvironment, EvalError, Value}; -use crate::syntax::{Expression, Program, Statement}; +use crate::eval::{EvalEnvironment, EvalError, PrimitiveType, Value}; +use crate::syntax::{ConstantType, Expression, Program, Statement}; +use std::str::FromStr; impl Program { /// Evaluate the program, returning either an error or what it prints out when run. @@ -24,11 +25,11 @@ impl Program { match stmt { Statement::Binding(_, name, value) => { let actual_value = value.eval(&env)?; - env = env.extend(ArcIntern::new(name.clone()), actual_value); + env = env.extend(name.clone().intern(), actual_value); } Statement::Print(_, name) => { - let value = env.lookup(ArcIntern::new(name.clone()))?; + let value = env.lookup(name.clone().intern())?; let line = format!("{} = {}\n", name, value); stdout.push_str(&line); } @@ -43,11 +44,28 @@ impl Expression { fn eval(&self, env: &EvalEnvironment) -> Result { match self { Expression::Value(_, v) => match v { - super::Value::Number(_, v) => Ok(Value::I64(*v)), + super::Value::Number(_, ty, v) => match ty { + None => Ok(Value::U64(*v)), + // FIXME: make these types validate their input size + Some(ConstantType::I8) => Ok(Value::I8(*v as i8)), + Some(ConstantType::I16) => Ok(Value::I16(*v as i16)), + Some(ConstantType::I32) => Ok(Value::I32(*v as i32)), + Some(ConstantType::I64) => Ok(Value::I64(*v as i64)), + Some(ConstantType::U8) => Ok(Value::U8(*v as u8)), + Some(ConstantType::U16) => 
Ok(Value::U16(*v as u16)), + Some(ConstantType::U32) => Ok(Value::U32(*v as u32)), + Some(ConstantType::U64) => Ok(Value::U64(*v)), + }, }, Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?), + Expression::Cast(_, target, expr) => { + let target_type = PrimitiveType::from_str(target)?; + let value = expr.eval(env)?; + Ok(target_type.safe_cast(&value)?) + } + Expression::Primitive(_, op, args) => { let mut arg_values = Vec::with_capacity(args.len()); @@ -66,12 +84,12 @@ impl Expression { fn two_plus_three() { let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let output = input.eval().expect("runs successfully"); - assert_eq!("x = 5i64\n", &output); + assert_eq!("x = 5u64\n", &output); } #[test] fn lotsa_math() { let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); let output = input.eval().expect("runs successfully"); - assert_eq!("x = 7i64\n", &output); + assert_eq!("x = 7u64\n", &output); } diff --git a/src/syntax/location.rs b/src/syntax/location.rs index 3c97d3d..5acd070 100644 --- a/src/syntax/location.rs +++ b/src/syntax/location.rs @@ -1,13 +1,15 @@ +use std::ops::Range; + use codespan_reporting::diagnostic::{Diagnostic, Label}; /// A source location, for use in pointing users towards warnings and errors. /// /// Internally, locations are very tied to the `codespan_reporting` library, /// and the primary use of them is to serve as anchors within that library. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct Location { file_idx: usize, - offset: usize, + location: Range, } impl Location { @@ -17,8 +19,8 @@ impl Location { /// The file index is based on the file database being used. See the /// `codespan_reporting::files::SimpleFiles::add` function, which is /// normally where we get this index. 
- pub fn new(file_idx: usize, offset: usize) -> Self { - Location { file_idx, offset } + pub fn new(file_idx: usize, location: Range) -> Self { + Location { file_idx, location } } /// Generate a `Location` for a completely manufactured bit of code. @@ -30,7 +32,7 @@ impl Location { pub fn manufactured() -> Self { Location { file_idx: 0, - offset: 0, + location: 0..0, } } @@ -47,7 +49,7 @@ impl Location { /// actually happened), but you'd probably want to make the first location /// the secondary label to help users find it. pub fn primary_label(&self) -> Label { - Label::primary(self.file_idx, self.offset..self.offset) + Label::primary(self.file_idx, self.location.clone()) } /// Generate a secondary label for a [`Diagnostic`], based on this source @@ -64,35 +66,7 @@ impl Location { /// probably want to make the first location the secondary label to help /// users find it. pub fn secondary_label(&self) -> Label { - Label::secondary(self.file_idx, self.offset..self.offset) - } - - /// Given this location and another, generate a primary label that - /// specifies the area between those two locations. - /// - /// See [`Self::primary_label`] for some discussion of primary versus - /// secondary labels. If the two locations are the same, this method does - /// the exact same thing as [`Self::primary_label`]. If this item was - /// generated by [`Self::manufactured`], it will act as if you'd called - /// `primary_label` on the argument. Otherwise, it will generate the obvious - /// span. - /// - /// This function will return `None` only in the case that you provide - /// labels from two different files, which it cannot sensibly handle. 
- pub fn range_label(&self, end: &Location) -> Option> { - if self.file_idx == 0 { - return Some(end.primary_label()); - } - - if self.file_idx != end.file_idx { - return None; - } - - if self.offset > end.offset { - Some(Label::primary(self.file_idx, end.offset..self.offset)) - } else { - Some(Label::primary(self.file_idx, self.offset..end.offset)) - } + Label::secondary(self.file_idx, self.location.clone()) } /// Return an error diagnostic centered at this location. @@ -102,10 +76,7 @@ impl Location { /// this particular location. You'll need to extend it with actually useful /// information, like what kind of error it is. pub fn error(&self) -> Diagnostic { - Diagnostic::error().with_labels(vec![Label::primary( - self.file_idx, - self.offset..self.offset, - )]) + Diagnostic::error().with_labels(vec![Label::primary(self.file_idx, self.location.clone())]) } /// Return an error diagnostic centered at this location, with the given message. @@ -115,10 +86,34 @@ impl Location { /// even more information to ut, using [`Diagnostic::with_labels`], /// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`]. pub fn labelled_error(&self, msg: &str) -> Diagnostic { - Diagnostic::error().with_labels(vec![Label::primary( - self.file_idx, - self.offset..self.offset, - ) - .with_message(msg)]) + Diagnostic::error().with_labels(vec![ + Label::primary(self.file_idx, self.location.clone()).with_message(msg) + ]) + } + + /// Merge two locations into a single location spanning the whole range between + /// them. + /// + /// This function returns None if the locations are from different files; this + /// can happen if one of the locations is manufactured, for example. 
+ pub fn merge(&self, other: &Self) -> Option { + if self.file_idx != other.file_idx { + None + } else { + let start = if self.location.start <= other.location.start { + self.location.start + } else { + other.location.start + }; + let end = if self.location.end >= other.location.end { + self.location.end + } else { + other.location.end + }; + Some(Location { + file_idx: self.file_idx, + location: start..end, + }) + } } } diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index 3d8de29..fc6a0c6 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -9,8 +9,8 @@ //! eventually want to leave lalrpop behind.) //! use crate::syntax::{LexerError, Location}; -use crate::syntax::ast::{Program,Statement,Expression,Value}; -use crate::syntax::tokens::Token; +use crate::syntax::ast::{Program,Statement,Expression,Value,Name}; +use crate::syntax::tokens::{ConstantType, Token}; use internment::ArcIntern; // one cool thing about lalrpop: we can pass arguments. in this case, the @@ -32,6 +32,8 @@ extern { ";" => Token::Semi, "(" => Token::LeftParen, ")" => Token::RightParen, + "<" => Token::LessThan, + ">" => Token::GreaterThan, "print" => Token::Print, @@ -44,7 +46,7 @@ extern { // to name and use "their value", you get their source location. // For these, we want "their value" to be their actual contents, // which is why we put their types in angle brackets. - "" => Token::Number((>,)), + "" => Token::Number((>,>,)), "" => Token::Variable(>), } } @@ -89,10 +91,19 @@ pub Statement: Statement = { // A statement can be a variable binding. Note, here, that we use this // funny @L thing to get the source location before the variable, so that // we can say that this statement spans across everything. 
- "> "=" ";" => Statement::Binding(Location::new(file_idx, l), v.to_string(), e), + "> "=" ";" => + Statement::Binding( + Location::new(file_idx, ls..le), + Name::new(v, Location::new(file_idx, ls..var_end)), + e, + ), // Alternatively, a statement can just be a print statement. - "print" "> ";" => Statement::Print(Location::new(file_idx, l), v.to_string()), + "print" "> ";" => + Statement::Print( + Location::new(file_idx, ls..le), + Name::new(v, Location::new(file_idx, name_start..name_end)), + ), } // Expressions! Expressions are a little fiddly, because we're going to @@ -124,15 +135,27 @@ Expression: Expression = { // we group addition and subtraction under the heading "additive" AdditiveExpression: Expression = { - "+" => Expression::Primitive(Location::new(file_idx, l), "+".to_string(), vec![e1, e2]), - "-" => Expression::Primitive(Location::new(file_idx, l), "-".to_string(), vec![e1, e2]), + "+" => + Expression::Primitive(Location::new(file_idx, ls..le), "+".to_string(), vec![e1, e2]), + "-" => + Expression::Primitive(Location::new(file_idx, ls..le), "-".to_string(), vec![e1, e2]), MultiplicativeExpression, } // similarly, we group multiplication and division under "multiplicative" MultiplicativeExpression: Expression = { - "*" => Expression::Primitive(Location::new(file_idx, l), "*".to_string(), vec![e1, e2]), - "/" => Expression::Primitive(Location::new(file_idx, l), "/".to_string(), vec![e1, e2]), + "*" => + Expression::Primitive(Location::new(file_idx, ls..le), "*".to_string(), vec![e1, e2]), + "/" => + Expression::Primitive(Location::new(file_idx, ls..le), "/".to_string(), vec![e1, e2]), + UnaryExpression, +} + +UnaryExpression: Expression = { + "-" => + Expression::Primitive(Location::new(file_idx, l..le), "-".to_string(), vec![e]), + "<" "> ">" => + Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)), AtomicExpression, } @@ -140,22 +163,9 @@ MultiplicativeExpression: Expression = { // they cannot be further divided into parts 
AtomicExpression: Expression = { // just a variable reference - "> => Expression::Reference(Location::new(file_idx, l), v.to_string()), + "> => Expression::Reference(Location::new(file_idx, l..end), v.to_string()), // just a number - "> => { - let val = Value::Number(n.0, n.1); - Expression::Value(Location::new(file_idx, l), val) - }, - // a tricky case: also just a number, but using a negative sign. an - // alternative way to do this -- and we may do this eventually -- is - // to implement a unary negation expression. this has the odd effect - // that the user never actually writes down a negative number; they just - // write positive numbers which are immediately sent to a negation - // primitive! - "-" "> => { - let val = Value::Number(n.0, -n.1); - Expression::Value(Location::new(file_idx, l), val) - }, + "> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)), // finally, let people parenthesize expressions and get back to a // lower precedence "(" ")" => e, diff --git a/src/syntax/pretty.rs b/src/syntax/pretty.rs index 46a59fb..6a9338f 100644 --- a/src/syntax/pretty.rs +++ b/src/syntax/pretty.rs @@ -1,6 +1,8 @@ -use crate::syntax::ast::{Expression, Program, Statement, Value, BINARY_OPERATORS}; +use crate::syntax::ast::{Expression, Program, Statement, Value}; use pretty::{DocAllocator, DocBuilder, Pretty}; +use super::ConstantType; + impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program where A: 'a, @@ -50,14 +52,14 @@ where match self { Expression::Value(_, val) => val.pretty(allocator), Expression::Reference(_, var) => allocator.text(var.to_string()), - Expression::Primitive(_, op, exprs) if BINARY_OPERATORS.contains(&op.as_ref()) => { - assert_eq!( - exprs.len(), - 2, - "Found binary operator with {} components?", - exprs.len() - ); - + Expression::Cast(_, t, e) => allocator + .text(t.clone()) + .angles() + .append(e.pretty(allocator)), + Expression::Primitive(_, op, exprs) if exprs.len() == 1 => allocator + 
.text(op.to_string()) + .append(exprs[0].pretty(allocator)), + Expression::Primitive(_, op, exprs) if exprs.len() == 2 => { let left = exprs[0].pretty(allocator); let right = exprs[1].pretty(allocator); @@ -84,15 +86,14 @@ where { fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> { match self { - Value::Number(opt_base, value) => { - let sign = if *value < 0 { "-" } else { "" }; + Value::Number(opt_base, ty, value) => { let value_str = match opt_base { - None => format!("{}", value), - Some(2) => format!("{}0b{:b}", sign, value.abs()), - Some(8) => format!("{}0o{:o}", sign, value.abs()), - Some(10) => format!("{}0d{}", sign, value.abs()), - Some(16) => format!("{}0x{:x}", sign, value.abs()), - Some(_) => format!("!!{}{:x}!!", sign, value.abs()), + None => format!("{}{}", value, type_suffix(ty)), + Some(2) => format!("0b{:b}{}", value, type_suffix(ty)), + Some(8) => format!("0o{:o}{}", value, type_suffix(ty)), + Some(10) => format!("0d{}{}", value, type_suffix(ty)), + Some(16) => format!("0x{:x}{}", value, type_suffix(ty)), + Some(_) => format!("!!{:x}{}!!", value, type_suffix(ty)), }; allocator.text(value_str) @@ -101,6 +102,20 @@ where } } +fn type_suffix(x: &Option) -> &'static str { + match x { + None => "", + Some(ConstantType::I8) => "i8", + Some(ConstantType::I16) => "i16", + Some(ConstantType::I32) => "i32", + Some(ConstantType::I64) => "i64", + Some(ConstantType::U8) => "u8", + Some(ConstantType::U16) => "u16", + Some(ConstantType::U32) => "u32", + Some(ConstantType::U64) => "u64", + } +} + #[derive(Clone, Copy)] struct CommaSep {} diff --git a/src/syntax/tokens.rs b/src/syntax/tokens.rs index e20757d..e521987 100644 --- a/src/syntax/tokens.rs +++ b/src/syntax/tokens.rs @@ -40,6 +40,12 @@ pub enum Token { #[token(")")] RightParen, + #[token("<")] + LessThan, + + #[token(">")] + GreaterThan, + // Next we take of any reserved words; I always like to put // these before we start recognizing more complicated regular // expressions. 
I don't think it matters, but it works for me. @@ -53,13 +59,14 @@ pub enum Token { /// Numbers capture both the value we read from the input, /// converted to an `i64`, as well as the base the user used - /// to write the number, if they did so. - #[regex(r"0b[01]+", |v| parse_number(Some(2), v))] - #[regex(r"0o[0-7]+", |v| parse_number(Some(8), v))] - #[regex(r"0d[0-9]+", |v| parse_number(Some(10), v))] - #[regex(r"0x[0-9a-fA-F]+", |v| parse_number(Some(16), v))] - #[regex(r"[0-9]+", |v| parse_number(None, v))] - Number((Option, i64)), + /// to write the number and/or the type the user specified, + /// if they did either. + #[regex(r"0b[01]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(2), v))] + #[regex(r"0o[0-7]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(8), v))] + #[regex(r"0d[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(10), v))] + #[regex(r"0x[0-9a-fA-F]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(Some(16), v))] + #[regex(r"[0-9]+(u8|i8|u16|i16|u32|i32|u64|i64)?", |v| parse_number(None, v))] + Number((Option, Option, u64)), // Variables; this is a very standard, simple set of characters // for variables, but feel free to experiment with more complicated @@ -88,15 +95,29 @@ impl fmt::Display for Token { Token::Semi => write!(f, "';'"), Token::LeftParen => write!(f, "'('"), Token::RightParen => write!(f, "')'"), + Token::LessThan => write!(f, "<"), + Token::GreaterThan => write!(f, ">"), Token::Print => write!(f, "'print'"), Token::Operator(c) => write!(f, "'{}'", c), - Token::Number((None, v)) => write!(f, "'{}'", v), - Token::Number((Some(2), v)) => write!(f, "'0b{:b}'", v), - Token::Number((Some(8), v)) => write!(f, "'0o{:o}'", v), - Token::Number((Some(10), v)) => write!(f, "'{}'", v), - Token::Number((Some(16), v)) => write!(f, "'0x{:x}'", v), - Token::Number((Some(b), v)) => { - write!(f, "Invalidly-based-number", b, v) + Token::Number((None, otype, v)) => write!(f, "'{}{}'", v, 
display_optional_type(otype)), + Token::Number((Some(2), otype, v)) => { + write!(f, "'0b{:b}{}'", v, display_optional_type(otype)) + } + Token::Number((Some(8), otype, v)) => { + write!(f, "'0o{:o}{}'", v, display_optional_type(otype)) + } + Token::Number((Some(10), otype, v)) => { + write!(f, "'{}{}'", v, display_optional_type(otype)) + } + Token::Number((Some(16), otype, v)) => { + write!(f, "'0x{:x}{}'", v, display_optional_type(otype)) + } + Token::Number((Some(b), opt_type, v)) => { + write!( + f, + "Invalidly-based-number", + b, v, opt_type + ) } Token::Variable(s) => write!(f, "'{}'", s), Token::Error => write!(f, ""), @@ -122,6 +143,125 @@ impl Token { } } +#[repr(i64)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ConstantType { + U8 = 10, + U16 = 11, + U32 = 12, + U64 = 13, + I8 = 20, + I16 = 21, + I32 = 22, + I64 = 23, +} + +impl From for cranelift_codegen::ir::Type { + fn from(value: ConstantType) -> Self { + match value { + ConstantType::I8 | ConstantType::U8 => cranelift_codegen::ir::types::I8, + ConstantType::I16 | ConstantType::U16 => cranelift_codegen::ir::types::I16, + ConstantType::I32 | ConstantType::U32 => cranelift_codegen::ir::types::I32, + ConstantType::I64 | ConstantType::U64 => cranelift_codegen::ir::types::I64, + } + } +} + +impl ConstantType { + /// Returns true if the given type is (a) numeric and (b) signed; + pub fn is_signed(&self) -> bool { + matches!( + self, + ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 + ) + } + + /// Return the set of types that can be safely casted into this type. 
+ pub fn safe_casts_to(self) -> Vec { + match self { + ConstantType::I8 => vec![ConstantType::I8], + ConstantType::I16 => vec![ConstantType::I16, ConstantType::I8, ConstantType::U8], + ConstantType::I32 => vec![ + ConstantType::I32, + ConstantType::I16, + ConstantType::I8, + ConstantType::U16, + ConstantType::U8, + ], + ConstantType::I64 => vec![ + ConstantType::I64, + ConstantType::I32, + ConstantType::I16, + ConstantType::I8, + ConstantType::U32, + ConstantType::U16, + ConstantType::U8, + ], + ConstantType::U8 => vec![ConstantType::U8], + ConstantType::U16 => vec![ConstantType::U16, ConstantType::U8], + ConstantType::U32 => vec![ConstantType::U32, ConstantType::U16, ConstantType::U8], + ConstantType::U64 => vec![ + ConstantType::U64, + ConstantType::U32, + ConstantType::U16, + ConstantType::U8, + ], + } + } + + /// Return the set of all currently-available constant types + pub fn all_types() -> Vec { + vec![ + ConstantType::U8, + ConstantType::U16, + ConstantType::U32, + ConstantType::U64, + ConstantType::I8, + ConstantType::I16, + ConstantType::I32, + ConstantType::I64, + ] + } + + /// Return the name of the given type, as a string + pub fn name(&self) -> String { + match self { + ConstantType::I8 => "i8".to_string(), + ConstantType::I16 => "i16".to_string(), + ConstantType::I32 => "i32".to_string(), + ConstantType::I64 => "i64".to_string(), + ConstantType::U8 => "u8".to_string(), + ConstantType::U16 => "u16".to_string(), + ConstantType::U32 => "u32".to_string(), + ConstantType::U64 => "u64".to_string(), + } + } +} + +#[derive(Debug, Error, PartialEq)] +pub enum InvalidConstantType { + #[error("Unrecognized constant {0} for constant type")] + Value(i64), +} + +impl TryFrom for ConstantType { + type Error = InvalidConstantType; + + fn try_from(value: i64) -> Result { + match value { + 10 => Ok(ConstantType::U8), + 11 => Ok(ConstantType::U16), + 12 => Ok(ConstantType::U32), + 13 => Ok(ConstantType::U64), + 20 => Ok(ConstantType::I8), + 21 => Ok(ConstantType::I16), 
+ 22 => Ok(ConstantType::I32), + 23 => Ok(ConstantType::I64), + _ => Err(InvalidConstantType::Value(value)), + } + } +} + /// Parse a number in the given base, return a pair of the base and the /// parsed number. This is just a helper used for all of the number /// regular expression cases, which kicks off to the obvious Rust @@ -129,24 +269,66 @@ impl Token { fn parse_number( base: Option, value: &Lexer, -) -> Result<(Option, i64), ParseIntError> { +) -> Result<(Option, Option, u64), ParseIntError> { let (radix, strval) = match base { None => (10, value.slice()), Some(radix) => (radix, &value.slice()[2..]), }; - let intval = i64::from_str_radix(strval, radix as u32)?; - Ok((base, intval)) + let (declared_type, strval) = if let Some(strval) = strval.strip_suffix("u8") { + (Some(ConstantType::U8), strval) + } else if let Some(strval) = strval.strip_suffix("u16") { + (Some(ConstantType::U16), strval) + } else if let Some(strval) = strval.strip_suffix("u32") { + (Some(ConstantType::U32), strval) + } else if let Some(strval) = strval.strip_suffix("u64") { + (Some(ConstantType::U64), strval) + } else if let Some(strval) = strval.strip_suffix("i8") { + (Some(ConstantType::I8), strval) + } else if let Some(strval) = strval.strip_suffix("i16") { + (Some(ConstantType::I16), strval) + } else if let Some(strval) = strval.strip_suffix("i32") { + (Some(ConstantType::I32), strval) + } else if let Some(strval) = strval.strip_suffix("i64") { + (Some(ConstantType::I64), strval) + } else { + (None, strval) + }; + + let intval = u64::from_str_radix(strval, radix as u32)?; + Ok((base, declared_type, intval)) +} + +fn display_optional_type(otype: &Option) -> &'static str { + match otype { + None => "", + Some(ConstantType::I8) => "i8", + Some(ConstantType::I16) => "i16", + Some(ConstantType::I32) => "i32", + Some(ConstantType::I64) => "i64", + Some(ConstantType::U8) => "u8", + Some(ConstantType::U16) => "u16", + Some(ConstantType::U32) => "u32", + Some(ConstantType::U64) => "u64", + } 
} #[test] fn lex_numbers() { - let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc // 9"); - assert_eq!(lex0.next(), Some(Token::Number((None, 12)))); - assert_eq!(lex0.next(), Some(Token::Number((Some(2), 12)))); - assert_eq!(lex0.next(), Some(Token::Number((Some(8), 12)))); - assert_eq!(lex0.next(), Some(Token::Number((Some(10), 12)))); - assert_eq!(lex0.next(), Some(Token::Number((Some(16), 12)))); + let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc 12u8 0xci64// 9"); + assert_eq!(lex0.next(), Some(Token::Number((None, None, 12)))); + assert_eq!(lex0.next(), Some(Token::Number((Some(2), None, 12)))); + assert_eq!(lex0.next(), Some(Token::Number((Some(8), None, 12)))); + assert_eq!(lex0.next(), Some(Token::Number((Some(10), None, 12)))); + assert_eq!(lex0.next(), Some(Token::Number((Some(16), None, 12)))); + assert_eq!( + lex0.next(), + Some(Token::Number((None, Some(ConstantType::U8), 12))) + ); + assert_eq!( + lex0.next(), + Some(Token::Number((Some(16), Some(ConstantType::I64), 12))) + ); assert_eq!(lex0.next(), None); } @@ -168,6 +350,31 @@ fn lexer_spans() { assert_eq!(lex0.next(), Some((Token::Equals, 2..3))); assert_eq!(lex0.next(), Some((Token::var("x"), 4..5))); assert_eq!(lex0.next(), Some((Token::Operator('+'), 6..7))); - assert_eq!(lex0.next(), Some((Token::Number((None, 1)), 8..9))); + assert_eq!(lex0.next(), Some((Token::Number((None, None, 1)), 8..9))); assert_eq!(lex0.next(), None); } + +#[test] +fn further_spans() { + let mut lex0 = Token::lexer("x = 2i64 + 2i64;\ny = -x;\nprint y;").spanned(); + assert_eq!(lex0.next(), Some((Token::var("x"), 0..1))); + assert_eq!(lex0.next(), Some((Token::Equals, 2..3))); + assert_eq!( + lex0.next(), + Some((Token::Number((None, Some(ConstantType::I64), 2)), 4..8)) + ); + assert_eq!(lex0.next(), Some((Token::Operator('+'), 9..10))); + assert_eq!( + lex0.next(), + Some((Token::Number((None, Some(ConstantType::I64), 2)), 11..15)) + ); + assert_eq!(lex0.next(), Some((Token::Semi, 15..16))); + 
assert_eq!(lex0.next(), Some((Token::var("y"), 17..18))); + assert_eq!(lex0.next(), Some((Token::Equals, 19..20))); + assert_eq!(lex0.next(), Some((Token::Operator('-'), 21..22))); + assert_eq!(lex0.next(), Some((Token::var("x"), 22..23))); + assert_eq!(lex0.next(), Some((Token::Semi, 23..24))); + assert_eq!(lex0.next(), Some((Token::Print, 25..30))); + assert_eq!(lex0.next(), Some((Token::var("y"), 31..32))); + assert_eq!(lex0.next(), Some((Token::Semi, 32..33))); +} diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index c318e93..30afe5c 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -1,6 +1,9 @@ -use crate::syntax::{Expression, Location, Program, Statement}; +use crate::{ + eval::PrimitiveType, + syntax::{Expression, Location, Program, Statement}, +}; use codespan_reporting::diagnostic::Diagnostic; -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; /// An error we found while validating the input program. /// @@ -11,6 +14,7 @@ use std::collections::HashMap; /// and using [`codespan_reporting`] to present them to the user. pub enum Error { UnboundVariable(Location, String), + UnknownType(Location, String), } impl From for Diagnostic { @@ -19,6 +23,10 @@ impl From for Diagnostic { Error::UnboundVariable(location, name) => location .labelled_error("unbound here") .with_message(format!("Unbound variable '{}'", name)), + + Error::UnknownType(location, name) => location + .labelled_error("type referenced here") + .with_message(format!("Unknown type '{}'", name)), } } } @@ -57,12 +65,24 @@ impl Program { /// example, and generates warnings for things that are inadvisable but not /// actually a problem. pub fn validate(&self) -> (Vec, Vec) { + let mut bound_variables = HashMap::new(); + self.validate_with_bindings(&mut bound_variables) + } + + /// Validate that the program makes semantic sense, not just syntactic sense. 
+ /// + /// This checks for things like references to variables that don't exist, for + /// example, and generates warnings for things that are inadvisable but not + /// actually a problem. + pub fn validate_with_bindings( + &self, + bound_variables: &mut HashMap, + ) -> (Vec, Vec) { let mut errors = vec![]; let mut warnings = vec![]; - let mut bound_variables = HashMap::new(); for stmt in self.statements.iter() { - let (mut new_errors, mut new_warnings) = stmt.validate(&mut bound_variables); + let (mut new_errors, mut new_warnings) = stmt.validate(bound_variables); errors.append(&mut new_errors); warnings.append(&mut new_warnings); } @@ -81,7 +101,7 @@ impl Statement { /// occurs. We use a `HashMap` to map these bound locations to the locations /// where their bound, because these locations are handy when generating errors /// and warnings. - pub fn validate( + fn validate( &self, bound_variables: &mut HashMap, ) -> (Vec, Vec) { @@ -97,20 +117,20 @@ impl Statement { errors.append(&mut exp_errors); warnings.append(&mut exp_warnings); - if let Some(original_binding_site) = bound_variables.get(var) { + if let Some(original_binding_site) = bound_variables.get(&var.name) { warnings.push(Warning::ShadowedVariable( original_binding_site.clone(), loc.clone(), - var.clone(), + var.to_string(), )); } else { - bound_variables.insert(var.clone(), loc.clone()); + bound_variables.insert(var.to_string(), loc.clone()); } } - Statement::Print(_, var) if bound_variables.contains_key(var) => {} + Statement::Print(_, var) if bound_variables.contains_key(&var.name) => {} Statement::Print(loc, var) => { - errors.push(Error::UnboundVariable(loc.clone(), var.clone())) + errors.push(Error::UnboundVariable(loc.clone(), var.to_string())) } } @@ -127,6 +147,15 @@ impl Expression { vec![Error::UnboundVariable(loc.clone(), var.clone())], vec![], ), + Expression::Cast(location, t, expr) => { + let (mut errs, warns) = expr.validate(variable_map); + + if PrimitiveType::from_str(t).is_err() { + 
errs.push(Error::UnknownType(location.clone(), t.clone())) + } + + (errs, warns) + } Expression::Primitive(_, _, args) => { let mut errors = vec![]; let mut warnings = vec![]; @@ -142,3 +171,19 @@ impl Expression { } } } + +#[test] +fn cast_checks_are_reasonable() { + let good_stmt = Statement::parse(0, "x = 4u8;").expect("valid test case"); + let (good_errs, good_warns) = good_stmt.validate(&mut HashMap::new()); + + assert!(good_errs.is_empty()); + assert!(good_warns.is_empty()); + + let bad_stmt = Statement::parse(0, "x = 4u8;").expect("valid test case"); + let (bad_errs, bad_warns) = bad_stmt.validate(&mut HashMap::new()); + + assert!(bad_warns.is_empty()); + assert_eq!(bad_errs.len(), 1); + assert!(matches!(bad_errs[0], Error::UnknownType(_, ref x) if x == "apple")); +} diff --git a/src/type_infer.rs b/src/type_infer.rs new file mode 100644 index 0000000..0402998 --- /dev/null +++ b/src/type_infer.rs @@ -0,0 +1,52 @@ +//! A type inference pass for NGR +//! +//! The type checker implemented here is a relatively straightforward one, designed to be +//! fairly easy to understand rather than super fast. So don't be expecting the fastest +//! type checker in the West, here. +//! +//! The actual type checker operates in three phases. In the first phase, we translate +//! the syntax AST into something that's close to the final IR. During the process, we +//! generate a list of type constraints to solve. In the second phase, we try to solve +//! all the constraints we've generated. If that's successful, in the final phase, we +//! do the final conversion to the IR AST, filling in any type information we've learned +//! along the way. 
+mod ast; +mod convert; +mod finalize; +mod solve; + +use self::convert::convert_program; +use self::finalize::finalize_program; +use self::solve::solve_constraints; +pub use self::solve::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning}; +use crate::ir::ast as ir; +use crate::syntax; +#[cfg(test)] +use crate::syntax::arbitrary::GenerationEnvironment; +#[cfg(test)] +use proptest::prelude::Arbitrary; + +impl syntax::Program { + /// Infer the types for the syntactic AST, returning either a type-checked program in + /// the IR, or a series of type errors encountered during inference. + /// + /// You really should have made sure that this program was validated before running + /// this method, otherwise you may experience panics during operation. + pub fn type_infer(self) -> TypeInferenceResult { + let mut constraint_db = vec![]; + let program = convert_program(self, &mut constraint_db); + let inference_result = solve_constraints(constraint_db); + + inference_result.map(|resolutions| finalize_program(program, &resolutions)) + } +} + +proptest::proptest! { + #[test] + fn translation_maintains_semantics(input in syntax::Program::arbitrary_with(GenerationEnvironment::new(false))) { + let syntax_result = input.eval(); + let ir = input.type_infer().expect("arbitrary should generate type-safe programs"); + let ir_result = ir.eval(); + proptest::prop_assert_eq!(syntax_result, ir_result); + } +} diff --git a/src/type_infer/ast.rs b/src/type_infer/ast.rs new file mode 100644 index 0000000..ff91a00 --- /dev/null +++ b/src/type_infer/ast.rs @@ -0,0 +1,336 @@ +pub use crate::ir::ast::Primitive; +/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going +/// to want to use while we're doing type inference, but don't want to keep around +/// afterwards. 
These are: +/// +/// * A notion of a type variable +/// * An unknown numeric constant form +/// +use crate::{ + eval::PrimitiveType, + syntax::{self, ConstantType, Location}, +}; +use internment::ArcIntern; +use pretty::{DocAllocator, Pretty}; +use std::fmt; +use std::sync::atomic::AtomicUsize; + +/// We're going to represent variables as interned strings. +/// +/// These should be fast enough for comparison that it's OK, since it's going to end up +/// being pretty much the pointer to the string. +type Variable = ArcIntern; + +/// The representation of a program within our IR. For now, this is exactly one file. +/// +/// In addition, for the moment there's not really much of interest to hold here besides +/// the list of statements read from the file. Order is important. In the future, you +/// could imagine caching analysis information in this structure. +/// +/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used +/// to print the structure whenever possible, especially if you value your or your +/// user's time. The latter is useful for testing that conversions of `Program` retain +/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be +/// syntactically valid, although they may contain runtime issue like over- or underflow. +#[derive(Debug)] +pub struct Program { + // For now, a program is just a vector of statements. In the future, we'll probably + // extend this to include a bunch of other information, but for now: just a list. + pub(crate) statements: Vec, +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + let mut result = allocator.nil(); + + for stmt in self.statements.iter() { + // there's probably a better way to do this, rather than constantly + // adding to the end, but this works. 
+ result = result + .append(stmt.pretty(allocator)) + .append(allocator.text(";")) + .append(allocator.hardline()); + } + + result + } +} + +/// The representation of a statement in the language. +/// +/// For now, this is either a binding site (`x = 4`) or a print statement +/// (`print x`). Someday, though, more! +/// +/// As with `Program`, this type implements [`Pretty`], which should +/// be used to display the structure whenever possible. It does not +/// implement [`Arbitrary`], though, mostly because it's slightly +/// complicated to do so. +/// +#[derive(Debug)] +pub enum Statement { + Binding(Location, Variable, Type, Expression), + Print(Location, Type, Variable), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + Statement::Binding(_, var, _, expr) => allocator + .text(var.as_ref().to_string()) + .append(allocator.space()) + .append(allocator.text("=")) + .append(allocator.space()) + .append(expr.pretty(allocator)), + Statement::Print(_, _, var) => allocator + .text("print") + .append(allocator.space()) + .append(allocator.text(var.as_ref().to_string())), + } + } +} + +/// The representation of an expression. +/// +/// Note that expressions, like everything else in this syntax tree, +/// supports [`Pretty`], and it's strongly encouraged that you use +/// that trait/module when printing these structures. +/// +/// Also, Expressions at this point in the compiler are explicitly +/// defined so that they are *not* recursive. By this point, if an +/// expression requires some other data (like, for example, invoking +/// a primitive), any subexpressions have been bound to variables so +/// that the referenced data will always either be a constant or a +/// variable reference. 
+#[derive(Debug, PartialEq)] +pub enum Expression { + Atomic(ValueOrRef), + Cast(Location, Type, ValueOrRef), + Primitive(Location, Type, Primitive, Vec), +} + +impl Expression { + /// Return a reference to the type of the expression, as inferred or recently + /// computed. + pub fn type_of(&self) -> &Type { + match self { + Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t, + Expression::Atomic(ValueOrRef::Value(_, t, _)) => t, + Expression::Cast(_, t, _) => t, + Expression::Primitive(_, t, _, _) => t, + } + } + + /// Return a reference to the location associated with the expression. + pub fn location(&self) -> &Location { + match self { + Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l, + Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, + Expression::Cast(l, _, _) => l, + Expression::Primitive(l, _, _, _) => l, + } + } +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + Expression::Atomic(x) => x.pretty(allocator), + Expression::Cast(_, t, e) => allocator + .text("<") + .append(t.pretty(allocator)) + .append(allocator.text(">")) + .append(e.pretty(allocator)), + Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => { + op.pretty(allocator).append(exprs[0].pretty(allocator)) + } + Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => { + let left = exprs[0].pretty(allocator); + let right = exprs[1].pretty(allocator); + + left.append(allocator.space()) + .append(op.pretty(allocator)) + .append(allocator.space()) + .append(right) + .parens() + } + Expression::Primitive(_, _, op, exprs) => { + allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) + } + } + } +} + +/// An expression that is always either a value or a reference. +/// +/// This is the type used to guarantee that we don't nest expressions +/// at this level. 
Instead, expressions that take arguments take one +/// of these, which can only be a constant or a reference. +#[derive(Clone, Debug, PartialEq)] +pub enum ValueOrRef { + Value(Location, Type, Value), + Ref(Location, Type, ArcIntern), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + ValueOrRef::Value(_, _, v) => v.pretty(allocator), + ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), + } + } +} + +impl From for Expression { + fn from(value: ValueOrRef) -> Self { + Expression::Atomic(value) + } +} + +/// A constant in the IR. +/// +/// The optional argument in numeric types is the base that was used by the +/// user to input the number. By retaining it, we can ensure that if we need +/// to print the number back out, we can do so in the form that the user +/// entered it. +#[derive(Clone, Debug, PartialEq)] +pub enum Value { + Unknown(Option, u64), + I8(Option, i8), + I16(Option, i16), + I32(Option, i32), + I64(Option, i64), + U8(Option, u8), + U16(Option, u16), + U32(Option, u32), + U64(Option, u64), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + let pretty_internal = |opt_base: &Option, x, t| { + syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator) + }; + + let pretty_internal_signed = |opt_base, x: i64, t| { + let base = pretty_internal(opt_base, x.unsigned_abs(), t); + + allocator.text("-").append(base) + }; + + match self { + Value::Unknown(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::U64) + } + Value::I8(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I8) + } + Value::I16(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I16) + } + 
Value::I32(opt_base, value) => { + pretty_internal_signed(opt_base, *value as i64, ConstantType::I32) + } + Value::I64(opt_base, value) => { + pretty_internal_signed(opt_base, *value, ConstantType::I64) + } + Value::U8(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U8) + } + Value::U16(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U16) + } + Value::U32(opt_base, value) => { + pretty_internal(opt_base, *value as u64, ConstantType::U32) + } + Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Type { + Variable(Location, ArcIntern), + Primitive(PrimitiveType), +} + +impl Type { + pub fn is_concrete(&self) -> bool { + !matches!(self, Type::Variable(_, _)) + } +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + Type::Variable(_, x) => allocator.text(x.to_string()), + Type::Primitive(pt) => allocator.text(format!("{}", pt)), + } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Variable(_, x) => write!(f, "{}", x), + Type::Primitive(pt) => pt.fmt(f), + } + } +} + +/// Generate a fresh new name based on the given name. +/// +/// The new name is guaranteed to be unique across the entirety of the +/// execution. This is achieved by using characters in the variable name +/// that would not be valid input, and by including a counter that is +/// incremented on every invocation. +pub fn gensym(name: &str) -> ArcIntern { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let new_name = format!( + "<{}:{}>", + name, + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + ); + ArcIntern::new(new_name) +} + +/// Generate a fresh new type; this will be a unique new type variable. 
+/// +/// The new name is guaranteed to be unique across the entirety of the +/// execution. This is achieved by using characters in the variable name +/// that would not be valid input, and by including a counter that is +/// incremented on every invocation. +pub fn gentype() -> Type { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let name = ArcIntern::new(format!( + "t<{}>", + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + )); + + Type::Variable(Location::manufactured(), name) +} diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs new file mode 100644 index 0000000..e53cdb6 --- /dev/null +++ b/src/type_infer/convert.rs @@ -0,0 +1,378 @@ +use super::ast as ir; +use super::ast::Type; +use crate::eval::PrimitiveType; +use crate::syntax::{self, ConstantType}; +use crate::type_infer::solve::Constraint; +use internment::ArcIntern; +use std::collections::HashMap; +use std::str::FromStr; + +/// This function takes a syntactic program and converts it into the IR version of the +/// program, with appropriate type variables introduced and their constraints added to +/// the given database. +/// +/// If the input function has been validated (which it should be), then this should run +/// into no error conditions. However, if you failed to validate the input, then this +/// function can panic. +pub fn convert_program( + mut program: syntax::Program, + constraint_db: &mut Vec, +) -> ir::Program { + let mut statements = Vec::new(); + let mut renames = HashMap::new(); + let mut bindings = HashMap::new(); + + for stmt in program.statements.drain(..) { + statements.append(&mut convert_statement( + stmt, + constraint_db, + &mut renames, + &mut bindings, + )); + } + + ir::Program { statements } +} + +/// This function takes a syntactic statements and converts it into a series of +/// IR statements, adding type variables and constraints as necessary. 
+/// +/// We generate a series of statements because we're going to flatten all +/// incoming expressions so that they are no longer recursive. This will +/// generate a bunch of new bindings for all the subexpressions, which we +/// return as a bundle. +/// +/// See the safety warning on [`convert_program`]! This function assumes that +/// you have run [`Statement::validate`], and will trigger panics in error +/// conditions if you have run that and had it come back clean. +fn convert_statement( + statement: syntax::Statement, + constraint_db: &mut Vec, + renames: &mut HashMap, ArcIntern>, + bindings: &mut HashMap, Type>, +) -> Vec { + match statement { + syntax::Statement::Print(loc, name) => { + let iname = ArcIntern::new(name.to_string()); + let final_name = renames + .get(&iname) + .map(Clone::clone) + .unwrap_or_else(|| iname.clone()); + let varty = bindings + .get(&final_name) + .expect("print variable defined before use") + .clone(); + + constraint_db.push(Constraint::Printable(loc.clone(), varty.clone())); + + vec![ir::Statement::Print(loc, varty, iname)] + } + + syntax::Statement::Binding(loc, name, expr) => { + let (mut prereqs, expr, ty) = + convert_expression(expr, constraint_db, renames, bindings); + let iname = ArcIntern::new(name.to_string()); + let final_name = if bindings.contains_key(&iname) { + let new_name = ir::gensym(iname.as_str()); + renames.insert(iname, new_name.clone()); + new_name + } else { + iname + }; + + bindings.insert(final_name.clone(), ty.clone()); + prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr)); + prereqs + } + } +} + +/// This function takes a syntactic expression and converts it into a series +/// of IR statements, adding type variables and constraints as necessary. +/// +/// We generate a series of statements because we're going to flatten all +/// incoming expressions so that they are no longer recursive. 
This will +/// generate a bunch of new bindings for all the subexpressions, which we +/// return as a bundle. +/// +/// See the safety warning on [`convert_program`]! This function assumes that +/// you have run [`Statement::validate`], and will trigger panics in error +/// conditions if you have run that and had it come back clean. +fn convert_expression( + expression: syntax::Expression, + constraint_db: &mut Vec, + renames: &HashMap, ArcIntern>, + bindings: &mut HashMap, Type>, +) -> (Vec, ir::Expression, Type) { + match expression { + syntax::Expression::Value(loc, val) => match val { + syntax::Value::Number(base, mctype, value) => { + let (newval, newtype) = match mctype { + None => { + let newtype = ir::gentype(); + let newval = ir::Value::Unknown(base, value); + + constraint_db.push(Constraint::ConstantNumericType( + loc.clone(), + newtype.clone(), + )); + (newval, newtype) + } + Some(ConstantType::U8) => ( + ir::Value::U8(base, value as u8), + ir::Type::Primitive(PrimitiveType::U8), + ), + Some(ConstantType::U16) => ( + ir::Value::U16(base, value as u16), + ir::Type::Primitive(PrimitiveType::U16), + ), + Some(ConstantType::U32) => ( + ir::Value::U32(base, value as u32), + ir::Type::Primitive(PrimitiveType::U32), + ), + Some(ConstantType::U64) => ( + ir::Value::U64(base, value), + ir::Type::Primitive(PrimitiveType::U64), + ), + Some(ConstantType::I8) => ( + ir::Value::I8(base, value as i8), + ir::Type::Primitive(PrimitiveType::I8), + ), + Some(ConstantType::I16) => ( + ir::Value::I16(base, value as i16), + ir::Type::Primitive(PrimitiveType::I16), + ), + Some(ConstantType::I32) => ( + ir::Value::I32(base, value as i32), + ir::Type::Primitive(PrimitiveType::I32), + ), + Some(ConstantType::I64) => ( + ir::Value::I64(base, value as i64), + ir::Type::Primitive(PrimitiveType::I64), + ), + }; + + constraint_db.push(Constraint::FitsInNumType( + loc.clone(), + newtype.clone(), + value, + )); + ( + vec![], + ir::Expression::Atomic(ir::ValueOrRef::Value(loc, 
newtype.clone(), newval)), + newtype, + ) + } + }, + + syntax::Expression::Reference(loc, name) => { + let iname = ArcIntern::new(name); + let final_name = renames.get(&iname).cloned().unwrap_or(iname); + let rtype = bindings + .get(&final_name) + .cloned() + .expect("variable bound before use"); + let refexp = + ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name)); + + (vec![], refexp, rtype) + } + + syntax::Expression::Cast(loc, target, expr) => { + let (mut stmts, nexpr, etype) = + convert_expression(*expr, constraint_db, renames, bindings); + let val_or_ref = simplify_expr(nexpr, &mut stmts); + let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast"); + let target_type = Type::Primitive(target_prim_type); + let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); + + constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone())); + + (stmts, res, target_type) + } + + syntax::Expression::Primitive(loc, fun, mut args) => { + let primop = ir::Primitive::from_str(&fun).expect("valid primitive"); + let mut stmts = vec![]; + let mut nargs = vec![]; + let mut atypes = vec![]; + let ret_type = ir::gentype(); + + for arg in args.drain(..) 
{ + let (mut astmts, aexp, atype) = + convert_expression(arg, constraint_db, renames, bindings); + + stmts.append(&mut astmts); + nargs.push(simplify_expr(aexp, &mut stmts)); + atypes.push(atype); + } + + constraint_db.push(Constraint::ProperPrimitiveArgs( + loc.clone(), + primop, + atypes.clone(), + ret_type.clone(), + )); + + ( + stmts, + ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs), + ret_type, + ) + } + } +} + +fn simplify_expr(expr: ir::Expression, stmts: &mut Vec) -> ir::ValueOrRef { + match expr { + ir::Expression::Atomic(v_or_ref) => v_or_ref, + expr => { + let etype = expr.type_of().clone(); + let loc = expr.location().clone(); + let nname = ir::gensym("g"); + let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr); + + stmts.push(nbinding); + ir::ValueOrRef::Ref(loc, etype, nname) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::syntax::Location; + + fn one() -> syntax::Expression { + syntax::Expression::Value( + Location::manufactured(), + syntax::Value::Number(None, None, 1), + ) + } + + fn vec_contains bool>(x: &[T], f: F) -> bool { + for x in x.iter() { + if f(x) { + return true; + } + } + false + } + + fn infer_expression( + x: syntax::Expression, + ) -> (ir::Expression, Vec, Vec, Type) { + let mut constraints = Vec::new(); + let renames = HashMap::new(); + let mut bindings = HashMap::new(); + let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings); + (expr, stmts, constraints, ty) + } + + fn infer_statement(x: syntax::Statement) -> (Vec, Vec) { + let mut constraints = Vec::new(); + let mut renames = HashMap::new(); + let mut bindings = HashMap::new(); + let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings); + (res, constraints) + } + + #[test] + fn constant_one() { + let (expr, stmts, constraints, ty) = infer_expression(one()); + assert!(stmts.is_empty()); + assert!(matches!( + expr, + 
ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1))) + )); + assert!(vec_contains(&constraints, |x| matches!( + x, + Constraint::FitsInNumType(_, _, 1) + ))); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty) + )); + } + + #[test] + fn one_plus_one() { + let opo = syntax::Expression::Primitive( + Location::manufactured(), + "+".to_string(), + vec![one(), one()], + ); + let (expr, stmts, constraints, ty) = infer_expression(opo); + assert!(stmts.is_empty()); + assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty)); + assert!(vec_contains(&constraints, |x| matches!( + x, + Constraint::FitsInNumType(_, _, 1) + ))); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty) + )); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty) + )); + } + + #[test] + fn one_plus_one_plus_one() { + let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse"); + let (stmts, constraints) = infer_statement(stmt); + assert_eq!(stmts.len(), 2); + let ir::Statement::Binding( + _args, + name1, + temp_ty1, + ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1), + ) = stmts.get(0).expect("item two") + else { + panic!("Failed to match first statement"); + }; + let ir::Statement::Binding( + _args, + name2, + temp_ty2, + ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2), + ) = stmts.get(1).expect("item two") + else { + panic!("Failed to match second statement"); + }; + let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] = + &primargs1[..] + else { + panic!("Failed to match first arguments"); + }; + let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] = + &primargs2[..] 
+ else { + panic!("Failed to match first arguments"); + }; + assert_ne!(name1, name2); + assert_ne!(temp_ty1, temp_ty2); + assert_ne!(primty1, primty2); + assert_eq!(name1, left2name); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty) + )); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty) + )); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty) + )); + for (i, s) in stmts.iter().enumerate() { + println!("{}: {:?}", i, s); + } + for (i, c) in constraints.iter().enumerate() { + println!("{}: {:?}", i, c); + } + } +} diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs new file mode 100644 index 0000000..9c70944 --- /dev/null +++ b/src/type_infer/finalize.rs @@ -0,0 +1,187 @@ +use super::{ast as input, solve::TypeResolutions}; +use crate::{eval::PrimitiveType, ir as output}; + +pub fn finalize_program( + mut program: input::Program, + resolutions: &TypeResolutions, +) -> output::Program { + output::Program { + statements: program + .statements + .drain(..) 
+ .map(|x| finalize_statement(x, resolutions)) + .collect(), + } +} + +fn finalize_statement( + statement: input::Statement, + resolutions: &TypeResolutions, +) -> output::Statement { + match statement { + input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding( + loc, + var, + finalize_type(ty, resolutions), + finalize_expression(expr, resolutions), + ), + input::Statement::Print(loc, ty, var) => { + output::Statement::Print(loc, finalize_type(ty, resolutions), var) + } + } +} + +fn finalize_expression( + expression: input::Expression, + resolutions: &TypeResolutions, +) -> output::Expression { + match expression { + input::Expression::Atomic(val_or_ref) => { + output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions)) + } + input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast( + loc, + finalize_type(target, resolutions), + finalize_val_or_ref(val_or_ref, resolutions), + ), + input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive( + loc, + finalize_type(ty, resolutions), + prim, + args.drain(..) 
+ .map(|x| finalize_val_or_ref(x, resolutions)) + .collect(), + ), + } +} + +fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type { + match ty { + input::Type::Primitive(x) => output::Type::Primitive(x), + input::Type::Variable(_, tvar) => match resolutions.get(&tvar) { + None => panic!("Did not resolve type for type variable {}", tvar), + Some(pt) => output::Type::Primitive(*pt), + }, + } +} + +fn finalize_val_or_ref( + valref: input::ValueOrRef, + resolutions: &TypeResolutions, +) -> output::ValueOrRef { + match valref { + input::ValueOrRef::Ref(loc, ty, var) => { + output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var) + } + input::ValueOrRef::Value(loc, ty, val) => { + let new_type = finalize_type(ty, resolutions); + + match val { + input::Value::Unknown(base, value) => match new_type { + output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::U8(base, value as u8), + ), + output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::U16(base, value as u16), + ), + output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::U32(base, value as u32), + ), + output::Type::Primitive(PrimitiveType::U64) => { + output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value)) + } + output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I8(base, value as i8), + ), + output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I16(base, value as i16), + ), + output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I32(base, value as i32), + ), + output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value( + loc, + new_type, + output::Value::I64(base, value as i64), + ), + }, + + input::Value::U8(base, value) 
=> { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U8) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value)) + } + + input::Value::U16(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U16) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value)) + } + + input::Value::U32(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U32) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value)) + } + + input::Value::U64(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::U64) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value)) + } + + input::Value::I8(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I8) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value)) + } + + input::Value::I16(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I16) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value)) + } + + input::Value::I32(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I32) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value)) + } + + input::Value::I64(base, value) => { + assert!(matches!( + new_type, + output::Type::Primitive(PrimitiveType::I64) + )); + output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value)) + } + } + } + } +} diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs new file mode 100644 index 0000000..ec33384 --- /dev/null +++ b/src/type_infer/solve.rs @@ -0,0 +1,542 @@ +use super::ast as ir; +use super::ast::Type; +use crate::{eval::PrimitiveType, syntax::Location}; +use codespan_reporting::diagnostic::Diagnostic; +use internment::ArcIntern; +use std::{collections::HashMap, fmt}; + +/// A 
type inference constraint that we're going to need to solve. +#[derive(Debug)] +pub enum Constraint { + /// The given type must be printable using the `print` built-in + Printable(Location, Type), + /// The provided numeric value fits in the given constant type + FitsInNumType(Location, Type, u64), + /// The given primitive has the proper arguments types associated with it + ProperPrimitiveArgs(Location, ir::Primitive, Vec, Type), + /// The given type can be casted to the target type safely + CanCastTo(Location, Type, Type), + /// The given type must be some numeric type, but this is not a constant + /// value, so don't try to default it if we can't figure it out + NumericType(Location, Type), + /// The given type is attached to a constant and must be some numeric type. + /// If we can't figure it out, we should warn the user and then just use a + /// default. + ConstantNumericType(Location, Type), + /// The two types should be equivalent + Equivalent(Location, Type, Type), +} + +impl fmt::Display for Constraint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty), + Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty), + Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => { + write!(f, "PRIM {} {} -> {}", op, args[0], ret) + } + Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => { + write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret) + } + Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret), + Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2), + Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty), + Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty), + Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2), + } + } +} + +pub type TypeResolutions = HashMap, PrimitiveType>; + +/// The 
results of type inference; like [`Result`], but with a bit more information. +/// +/// This result is parameterized, because sometimes it's handy to return slightly +/// different things; there's a [`TypeInferenceResult::map`] function for performing +/// those sorts of conversions. +pub enum TypeInferenceResult { + Success { + result: Result, + warnings: Vec, + }, + Failure { + errors: Vec, + warnings: Vec, + }, +} + +impl TypeInferenceResult { + // If this was a successful type inference, run the function over the result to + // create a new result. + // + // This is the moral equivalent of [`Result::map`], but for type inference results. + pub fn map(self, f: F) -> TypeInferenceResult + where + F: FnOnce(R) -> U, + { + match self { + TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success { + result: f(result), + warnings, + }, + + TypeInferenceResult::Failure { errors, warnings } => { + TypeInferenceResult::Failure { errors, warnings } + } + } + } + + // Return the final result, or panic if it's not a success + pub fn expect(self, msg: &str) -> R { + match self { + TypeInferenceResult::Success { result, .. } => result, + TypeInferenceResult::Failure { .. } => { + panic!("tried to get value from failed type inference: {}", msg) + } + } + } +} + +/// The various kinds of errors that can occur while doing type inference. +pub enum TypeInferenceError { + /// The user provide a constant that is too large for its inferred type. + ConstantTooLarge(Location, PrimitiveType, u64), + /// The two types needed to be equivalent, but weren't. + NotEquivalent(Location, PrimitiveType, PrimitiveType), + /// We cannot safely cast the first type to the second type. + CannotSafelyCast(Location, PrimitiveType, PrimitiveType), + /// The primitive invocation provided the wrong number of arguments. + WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize), + /// We had a constraint we just couldn't solve. 
+ CouldNotSolve(Constraint), +} + +impl From for Diagnostic { + fn from(value: TypeInferenceError) -> Self { + match value { + TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc + .labelled_error("constant too large for type") + .with_message(format!( + "Type {} has a max value of {}, which is smaller than {}", + primty, + primty.max_value(), + value + )), + TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc + .labelled_error("type inference error") + .with_message(format!("Expected type {}, received type {}", ty1, ty2)), + TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc + .labelled_error("unsafe type cast") + .with_message(format!("Cannot safely cast {} to {}", ty1, ty2)), + TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc + .labelled_error("wrong number of arguments") + .with_message(format!( + "expected {} for {}, received {}", + if lower == upper && lower > 1 { + format!("{} arguments", lower) + } else if lower == upper { + format!("{} argument", lower) + } else { + format!("{}-{} arguments", lower, upper) + }, + prim, + observed + )), + TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => { + loc.labelled_error("internal error").with_message(format!( + "could not determine if it was safe to cast from {} to {:#?}", + a, b + )) + } + TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => { + loc.labelled_error("internal error").with_message(format!( + "could not determine if {} and {:#?} were equivalent", + a, b + )) + } + TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => { + loc.labelled_error("internal error").with_message(format!( + "Could not determine if {} could fit in {}", + val, ty + )) + } + TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc + .labelled_error("internal error") + .with_message(format!("Could not determine if {} was a numeric type", ty)), + 
TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) => + panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty), + TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc + .labelled_error("internal error") + .with_message(format!("Could not determine if type {} was printable", ty)), + TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => { + loc.labelled_error("internal error").with_message(format!( + "Could not tell if primitive {} received the proper argument types", + prim + )) + } + } + } +} + +/// Warnings that we might want to tell the user about. +/// +/// These are fine, probably, but could indicate some behavior the user might not +/// expect, and so they might want to do something about them. +pub enum TypeInferenceWarning { + DefaultedTo(Location, Type), +} + +impl From for Diagnostic { + fn from(value: TypeInferenceWarning) -> Self { + match value { + TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning() + .with_labels(vec![loc.primary_label().with_message("unknown type")]) + .with_message(format!("Defaulted unknown type to {}", ty)), + } + } +} + +/// Solve all the constraints in the provided database. +/// +/// This process can take a bit, so you might not want to do it multiple times. Basically, +/// it's going to grind on these constraints until either it figures them out, or it stops +/// making progress. I haven't done the math on the constraints to even figure out if this +/// is guaranteed to halt, though, let alone terminate in some reasonable amount of time. +/// +/// The return value is a type inference result, which pairs some warnings with either a +/// successful set of type resolutions (mappings from type variables to their values), or +/// a series of inference errors. 
+pub fn solve_constraints( + mut constraint_db: Vec, +) -> TypeInferenceResult { + let mut errors = vec![]; + let mut warnings = vec![]; + let mut resolutions = HashMap::new(); + let mut changed_something = true; + + // We want to run this inference endlessly, until either we have solved all of our + // constraints. Internal to the loop, we have a check that will make sure that we + // do (eventually) stop. + while changed_something && !constraint_db.is_empty() { + // Set this to false at the top of the loop. We'll set this to true if we make + // progress in any way further down, but having this here prevents us from going + // into an infinite look when we can't figure stuff out. + changed_something = false; + // This is sort of a double-buffering thing; we're going to rename constraint_db + // and then set it to a new empty vector, which we'll add to as we find new + // constraints or find ourselves unable to solve existing ones. + let mut local_constraints = constraint_db; + constraint_db = vec![]; + + // OK. First thing we're going to do is run through all of our constraints, + // and see if we can solve any, or reduce them to theoretically more simple + // constraints. + for constraint in local_constraints.drain(..) { + match constraint { + // Currently, all of our types are printable + Constraint::Printable(_loc, _ty) => changed_something = true, + + // Case #1: We have two primitive types. If they're equal, we've discharged this + // constraint! We can just continue. If they're not equal, add an error and then + // see what else we come up with. + Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => { + if t1 != t2 { + errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2)); + } + changed_something = true; + } + + // Case #2: One of the two constraints is a primitive, and the other is a variable. + // In this case, we'll check to see if we've resolved the variable, and check for + // equivalence if we have. 
If we haven't, we'll set that variable to be primitive + // type. + Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name)) + | Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => { + match resolutions.get(&name) { + None => { + resolutions.insert(name, t); + } + Some(t2) if &t == t2 => {} + Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)), + } + changed_something = true; + } + + // Case #3: They're both variables. In which case, we'll have to do much the same + // check, but now on their resolutions. + Constraint::Equivalent( + ref loc, + Type::Variable(_, ref name1), + Type::Variable(_, ref name2), + ) => match (resolutions.get(name1), resolutions.get(name2)) { + (None, None) => { + constraint_db.push(constraint); + } + (Some(pt), None) => { + resolutions.insert(name2.clone(), *pt); + changed_something = true; + } + (None, Some(pt)) => { + resolutions.insert(name1.clone(), *pt); + changed_something = true; + } + (Some(pt1), Some(pt2)) if pt1 == pt2 => { + changed_something = true; + } + (Some(pt1), Some(pt2)) => { + errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2)); + changed_something = true; + } + }, + + // Make sure that the provided number fits within the provided constant type. For the + // moment, we're going to call an error here a failure, although this could be a + // warning in the future. 
+ Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => { + if ctype.max_value() < val { + errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val)); + } + + changed_something = true; + } + + // If we have a non-constant type, then let's see if we can advance this to a constant + // type + Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => { + match resolutions.get(&var) { + None => constraint_db.push(Constraint::FitsInNumType( + loc, + Type::Variable(vloc, var), + val, + )), + Some(nt) => { + constraint_db.push(Constraint::FitsInNumType( + loc, + Type::Primitive(*nt), + val, + )); + changed_something = true; + } + } + } + + // If the left type in a "can cast to" check is a variable, let's see if we can advance + // it into something more tangible + Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => { + match resolutions.get(&var) { + None => constraint_db.push(Constraint::CanCastTo( + loc, + Type::Variable(vloc, var), + to_type, + )), + Some(nt) => { + constraint_db.push(Constraint::CanCastTo( + loc, + Type::Primitive(*nt), + to_type, + )); + changed_something = true; + } + } + } + + // If the right type in a "can cast to" check is a variable, same deal + Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => { + match resolutions.get(&var) { + None => constraint_db.push(Constraint::CanCastTo( + loc, + from_type, + Type::Variable(vloc, var), + )), + Some(nt) => { + constraint_db.push(Constraint::CanCastTo( + loc, + from_type, + Type::Primitive(*nt), + )); + changed_something = true; + } + } + } + + // If both of them are types, then we can actually do the test. yay! 
+ Constraint::CanCastTo( + loc, + Type::Primitive(from_type), + Type::Primitive(to_type), + ) => { + if !from_type.can_cast_to(&to_type) { + errors.push(TypeInferenceError::CannotSafelyCast( + loc, from_type, to_type, + )); + } + changed_something = true; + } + + // As per usual, if we're trying to test if a type variable is numeric, first + // we try to advance it to a primitive + Constraint::NumericType(loc, Type::Variable(vloc, var)) => { + match resolutions.get(&var) { + None => constraint_db + .push(Constraint::NumericType(loc, Type::Variable(vloc, var))), + Some(nt) => { + constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt))); + changed_something = true; + } + } + } + + // Of course, if we get to a primitive type, then it's true, because all of our + // primitive types are numbers + Constraint::NumericType(_, Type::Primitive(_)) => { + changed_something = true; + } + + // As per usual, if we're trying to test if a type variable is numeric, first + // we try to advance it to a primitive + Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => { + match resolutions.get(&var) { + None => constraint_db.push(Constraint::ConstantNumericType( + loc, + Type::Variable(vloc, var), + )), + Some(nt) => { + constraint_db + .push(Constraint::ConstantNumericType(loc, Type::Primitive(*nt))); + changed_something = true; + } + } + } + + // Of course, if we get to a primitive type, then it's true, because all of our + // primitive types are numbers + Constraint::ConstantNumericType(_, Type::Primitive(_)) => { + changed_something = true; + } + + // OK, this one could be a little tricky if we tried to do it all at once, but + // instead what we're going to do here is just use this constraint to generate + // a bunch more constraints, and then go have the engine solve those. 
The only + // real errors we're going to come up with here are "arity errors"; errors we + // find by discovering that the number of arguments provided doesn't make sense + // given the primitive being used. + Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim { + ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide + if args.len() != 2 => + { + errors.push(TypeInferenceError::WrongPrimitiveArity( + loc, + prim, + 2, + 2, + args.len(), + )); + changed_something = true; + } + + ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => { + let right = args.pop().expect("2 > 0"); + let left = args.pop().expect("2 > 1"); + + // technically testing that both are numeric is redundant, but it might give + // a slightly helpful type error if we do both. + constraint_db.push(Constraint::NumericType(loc.clone(), left.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), right.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); + constraint_db.push(Constraint::Equivalent( + loc.clone(), + left.clone(), + right, + )); + constraint_db.push(Constraint::Equivalent(loc, left, ret)); + changed_something = true; + } + + ir::Primitive::Minus if args.is_empty() || args.len() > 2 => { + errors.push(TypeInferenceError::WrongPrimitiveArity( + loc, + prim, + 1, + 2, + args.len(), + )); + changed_something = true; + } + + ir::Primitive::Minus if args.len() == 1 => { + let arg = args.pop().expect("1 > 0"); + constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); + constraint_db.push(Constraint::Equivalent(loc, arg, ret)); + changed_something = true; + } + + ir::Primitive::Minus => { + let right = args.pop().expect("2 > 0"); + let left = args.pop().expect("2 > 1"); + + // technically testing that both are numeric is redundant, but it might give + // a slightly helpful type error if we do both. 
+ constraint_db.push(Constraint::NumericType(loc.clone(), left.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), right.clone())); + constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); + constraint_db.push(Constraint::Equivalent( + loc.clone(), + left.clone(), + right, + )); + constraint_db.push(Constraint::Equivalent(loc.clone(), left, ret)); + changed_something = true; + } + }, + } + } + + // If that didn't actually come up with anything, and we just recycled all the constraints + // back into the database unchanged, then let's take a look for cases in which we just + // wanted something we didn't know to be a number. Basically, those are cases where the + // user just wrote a number, but didn't tell us what type it was, and there isn't enough + // information in the context to tell us. If that happens, we'll just set that type to + // be u64, and warn the user that we did so. + if !changed_something && !constraint_db.is_empty() { + local_constraints = constraint_db; + constraint_db = vec![]; + + for constraint in local_constraints.drain(..) { + match constraint { + Constraint::ConstantNumericType(loc, t @ Type::Variable(_, _)) => { + let resty = Type::Primitive(PrimitiveType::U64); + constraint_db.push(Constraint::Equivalent( + loc.clone(), + t, + Type::Primitive(PrimitiveType::U64), + )); + warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty)); + changed_something = true; + } + + _ => constraint_db.push(constraint), + } + } + } + } + + // OK, we left our loop. Which means that either we solved everything, or we didn't. + // If we didn't, turn the unsolved constraints into type inference errors, and add + // them to our error list. + let mut unsolved_constraint_errors = constraint_db + .drain(..) + .map(TypeInferenceError::CouldNotSolve) + .collect(); + errors.append(&mut unsolved_constraint_errors); + + // How'd we do? 
+ if errors.is_empty() { + TypeInferenceResult::Success { + result: resolutions, + warnings, + } + } else { + TypeInferenceResult::Failure { errors, warnings } + } +}