diff --git a/src/backend.rs b/src/backend.rs index 4ac8e83..c6f217b 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -39,6 +39,7 @@ use cranelift_codegen::{isa, settings}; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{default_libcall_names, DataDescription, DataId, FuncId, Linkage, Module}; use cranelift_object::{ObjectBuilder, ObjectModule}; +use internment::ArcIntern; use std::collections::HashMap; use target_lexicon::Triple; @@ -58,7 +59,7 @@ pub struct Backend { data_ctx: DataDescription, runtime_functions: RuntimeFunctions, defined_strings: HashMap, - defined_symbols: HashMap, + defined_symbols: HashMap, (DataId, ConstantType)>, output_buffer: Option, platform: Triple, } @@ -181,7 +182,8 @@ impl Backend { .declare_data(&name, Linkage::Export, true, false)?; self.module.define_data(id, &self.data_ctx)?; self.data_ctx.clear(); - self.defined_symbols.insert(name, (id, ctype)); + self.defined_symbols + .insert(ArcIntern::new(name), (id, ctype)); Ok(id) } diff --git a/src/backend/error.rs b/src/backend/error.rs index 7509da3..f7252e9 100644 --- a/src/backend/error.rs +++ b/src/backend/error.rs @@ -41,6 +41,8 @@ pub enum BackendError { Write(#[from] cranelift_object::object::write::Error), #[error("Invalid type cast from {from} to {to}")] InvalidTypeCast { from: PrimitiveType, to: Type }, + #[error("Unknown string constant '{0}'")] + UnknownString(ArcIntern), } impl From for Diagnostic { @@ -69,6 +71,8 @@ impl From for Diagnostic { BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message( format!("Internal error trying to cast from {} to {}", from, to), ), + BackendError::UnknownString(str) => Diagnostic::error() + .with_message(format!("Unknown string found trying to compile: '{}'", str)), } } } @@ -119,6 +123,11 @@ impl PartialEq for BackendError { } => from1 == from2 && to1 == to2, _ => false, }, + + BackendError::UnknownString(a) => match other { + BackendError::UnknownString(b) => a == b, + _ => false, + }, } } } 
diff --git a/src/backend/eval.rs b/src/backend/eval.rs index 78ed7a5..81129eb 100644 --- a/src/backend/eval.rs +++ b/src/backend/eval.rs @@ -1,12 +1,14 @@ use crate::backend::Backend; use crate::eval::EvalError; -use crate::ir::Program; +use crate::ir::{Expression, Program, TopLevel, Type}; #[cfg(test)] use crate::syntax::arbitrary::GenerationEnvironment; +use crate::syntax::Location; use cranelift_jit::JITModule; use cranelift_object::ObjectModule; #[cfg(test)] use proptest::arbitrary::Arbitrary; +use std::collections::HashMap; use std::path::Path; use target_lexicon::Triple; @@ -24,9 +26,36 @@ impl Backend { /// library do. So, if you're validating equivalence between them, you'll want to weed /// out examples that overflow/underflow before checking equivalence. (This is the behavior /// of the built-in test systems.) - pub fn eval(program: Program) -> Result { + pub fn eval(program: Program) -> Result { let mut jitter = Backend::jit(Some(String::new()))?; - let function_id = jitter.compile_function("test", program)?; + let mut function_map = HashMap::new(); + let mut main_function_body = vec![]; + + for item in program.items { + match item { + TopLevel::Function(name, args, rettype, body) => { + let function_id = + jitter.compile_function(name.as_str(), args.as_slice(), rettype, body)?; + function_map.insert(name, function_id); + } + + TopLevel::Statement(stmt) => { + main_function_body.push(stmt); + } + } + } + + let main_function_body = Expression::Block( + Location::manufactured(), + Type::Primitive(crate::eval::PrimitiveType::Void), + main_function_body, + ); + let function_id = jitter.compile_function( + "___test_jit_eval___", + &[], + Type::Primitive(crate::eval::PrimitiveType::Void), + main_function_body, + )?; jitter.module.finalize_definitions()?; let compiled_bytes = jitter.bytes(function_id); let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; @@ -51,17 +80,44 @@ impl Backend { /// library do. 
So, if you're validating equivalence between them, you'll want to weed /// out examples that overflow/underflow before checking equivalence. (This is the behavior /// of the built-in test systems.) - pub fn eval(program: Program) -> Result { + pub fn eval(program: Program) -> Result { //use pretty::{Arena, Pretty}; //let allocator = Arena::<()>::new(); //program.pretty(&allocator).render(80, &mut std::io::stdout())?; - let mut backend = Self::object_file(Triple::host())?; + let mut function_map = HashMap::new(); + let mut main_function_body = vec![]; + + for item in program.items { + match item { + TopLevel::Function(name, args, rettype, body) => { + let function_id = + backend.compile_function(name.as_str(), args.as_slice(), rettype, body)?; + function_map.insert(name, function_id); + } + + TopLevel::Statement(stmt) => { + main_function_body.push(stmt); + } + } + } + + let main_function_body = Expression::Block( + Location::manufactured(), + Type::Primitive(crate::eval::PrimitiveType::Void), + main_function_body, + ); + let my_directory = tempfile::tempdir()?; let object_path = my_directory.path().join("object.o"); let executable_path = my_directory.path().join("test_executable"); - backend.compile_function("gogogo", program)?; + backend.compile_function( + "gogogo", + &[], + Type::Primitive(crate::eval::PrimitiveType::Void), + main_function_body, + )?; let bytes = backend.bytes()?; std::fs::write(&object_path, bytes)?; Self::link(&object_path, &executable_path)?; diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index a9989de..0f77339 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -1,15 +1,15 @@ use std::collections::HashMap; use crate::eval::PrimitiveType; -use crate::ir::{Expression, Primitive, Program, Statement, TopLevel, Type, Value, ValueOrRef}; -use crate::syntax::ConstantType; -use cranelift_codegen::entity::EntityRef; +use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; 
+use crate::syntax::{ConstantType, Location}; +use crate::util::scoped_map::ScopedMap; use cranelift_codegen::ir::{ - self, entities, types, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName, + self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, Signature, UserFuncName, }; use cranelift_codegen::isa::CallConv; use cranelift_codegen::Context; -use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Variable}; +use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext}; use cranelift_module::{FuncId, Linkage, Module}; use internment::ArcIntern; @@ -24,25 +24,101 @@ use crate::backend::Backend; /// This just a handy type alias to avoid a lot of confusion in the functions. type StringTable = HashMap, GlobalValue>; +/// When we're talking about variables, it's handy to just have a table that points +/// from a variable to "what to do if you want to reference this variable", which is +/// agnostic about whether the variable is local, global, an argument, etc. Since +/// the type of that function is a little bit annoying, we summarize it here. +struct ReferenceBuilder { + ir_type: ConstantType, + cranelift_type: cranelift_codegen::ir::Type, + local_data: GlobalValue, +} + +impl ReferenceBuilder { + fn refer_to(&self, builder: &mut FunctionBuilder) -> (entities::Value, ConstantType) { + let value = builder.ins().symbol_value(self.cranelift_type, self.local_data); + (value, self.ir_type) + } +} + impl Backend { - /// Compile the given `Program` into a function with the given name. + /// Translate the given IR type into an ABI parameter type for cranelift, as + /// best as possible. + fn translate_type(&self, t: &Type) -> AbiParam { + let (value_type, extension) = match t { + Type::Function(_, _) => ( + types::Type::triple_pointer_type(&self.platform), + ir::ArgumentExtension::None, + ), + Type::Primitive(PrimitiveType::Void) => (types::I8, ir::ArgumentExtension::None), // FIXME? 
+ Type::Primitive(PrimitiveType::I8) => (types::I8, ir::ArgumentExtension::Sext), + Type::Primitive(PrimitiveType::I16) => (types::I16, ir::ArgumentExtension::Sext), + Type::Primitive(PrimitiveType::I32) => (types::I32, ir::ArgumentExtension::Sext), + Type::Primitive(PrimitiveType::I64) => (types::I64, ir::ArgumentExtension::Sext), + Type::Primitive(PrimitiveType::U8) => (types::I8, ir::ArgumentExtension::Uext), + Type::Primitive(PrimitiveType::U16) => (types::I16, ir::ArgumentExtension::Uext), + Type::Primitive(PrimitiveType::U32) => (types::I32, ir::ArgumentExtension::Uext), + Type::Primitive(PrimitiveType::U64) => (types::I64, ir::ArgumentExtension::Uext), + }; + + AbiParam { + value_type, + purpose: ir::ArgumentPurpose::Normal, + extension, + } + } + + /// Compile the given program. /// - /// At some point, the use of `Program` is going to change; however, for the - /// moment, we have no notion of a function in our language so the whole input - /// is converted into a single output function. The type of the generated - /// function is, essentially, `fn() -> ()`: it takes no arguments and returns - /// no value. - /// - /// The function provided can then be either written to a file (if using a - /// static Cranelift backend) or executed directly (if using the Cranelift JIT). + /// The returned value is a `FuncId` that represents a function that runs all the statements + /// found in the program, which will be compiled using the given function name. (If there + /// are no such statements, the function will do nothing.) 
+ pub fn compile_program( + &mut self, + function_name: &str, + program: Program, + ) -> Result { + let mut generated_body = vec![]; + + for item in program.items { + match item { + TopLevel::Function(name, args, rettype, body) => { + self.compile_function(name.as_str(), &args, rettype, body)?; + } + + TopLevel::Statement(stmt) => { + generated_body.push(stmt); + } + } + } + + let void = Type::Primitive(PrimitiveType::Void); + self.compile_function( + function_name, + &[], + void.clone(), + Expression::Block(Location::manufactured(), void, generated_body), + ) + } + + /// Compile the given function. pub fn compile_function( &mut self, function_name: &str, - mut program: Program, + arguments: &[(Variable, Type)], + return_type: Type, + body: Expression, ) -> Result { let basic_signature = Signature { - params: vec![], - returns: vec![], + params: arguments + .iter() + .map(|(_, t)| self.translate_type(t)) + .collect(), + returns: if return_type == Type::Primitive(PrimitiveType::Void) { + vec![] + } else { + vec![self.translate_type(&return_type)] + }, call_conv: CallConv::triple_default(&self.platform), }; @@ -63,13 +139,6 @@ impl Backend { let user_func_name = UserFuncName::user(0, func_id.as_u32()); ctx.func = Function::with_name_signature(user_func_name, basic_signature); - // We generate a table of every string that we use in the program, here. - // Cranelift is going to require us to have this in a particular structure - // (`GlobalValue`) so that we can reference them later, and it's going to - // be tricky to generate those on the fly. So we just generate the set we - // need here, and then have ir around in the table for later. - let string_table = self.build_string_table(&mut ctx.func, &program)?; - // In the future, we might want to see what runtime functions the function // we were given uses, and then only include those functions that we care // about. 
Presumably, we'd use some sort of lookup table like we do for @@ -82,25 +151,32 @@ impl Backend { &mut ctx.func, )?; - // In the case of the JIT, there may be symbols we've already defined outside - // the context of this particular `Progam`, which we might want to reference. - // Just like with strings, generating the `GlobalValue`s we need can potentially - // be a little tricky to do on the fly, so we generate the complete list right - // here and then use it later. - let pre_defined_symbols: HashMap = self - .defined_symbols - .iter() - .map(|(k, (v, t))| { - let local_data = self.module.declare_data_in_func(*v, &mut ctx.func); - (k.clone(), (local_data, *t)) - }) - .collect(); + // Let's start creating the variable table we'll use when we're dereferencing + // them later. This table is a little interesting because instead of pointing + // from data to data, we're going to point from data (the variable) to an + // action to take if we encounter that variable at some later point. This + // makes it nice and easy to have many different ways to access data, such + // as globals, function arguments, etc. + let mut variables: ScopedMap, ReferenceBuilder> = ScopedMap::new(); - // The last table we're going to need is our local variable table, to store - // variables used in this `Program` but not used outside of it. For whatever - // reason, Cranelift requires us to generate unique indexes for each of our - // variables; we just use a simple incrementing counter for that. - let mut variable_table = HashMap::new(); + // At the outer-most scope of things, we'll put global variables we've defined + // elsewhere in the program. 
+ for (name, (data_id, ty)) in self.defined_symbols.iter() { + let local_data = self.module.declare_data_in_func(*data_id, &mut ctx.func); + let cranelift_type = ir::Type::from(*ty); + variables.insert( + name.clone(), + ReferenceBuilder { cranelift_type, local_data, ir_type: *ty }, + ); + } + + // Once we have these, we're going to actually push a level of scope and + // add our arguments. We push scope because if there happen to be any with + // the same name (there shouldn't be, but just in case), we want the arguments + // to win. + variables.new_scope(); + + // FIXME: Add arguments let mut next_var_num = 1; // Finally (!), we generate the function builder that we're going to use to @@ -114,98 +190,13 @@ impl Backend { let main_block = builder.create_block(); builder.switch_to_block(main_block); - // Compiling a function is just compiling each of the statements in order. - // At the moment, we do the pattern match for statements here, and then - // directly compile the statements. If/when we add more statement forms, - // this is likely to become more cumbersome, and we'll want to separate - // these off. But for now, given the amount of tables we keep around to track - // state, it's easier to just include them. - for item in program.items.drain(..) { - match item { - TopLevel::Function(_, _, _, _) => unimplemented!(), - - // Print statements are fairly easy to compile: we just lookup the - // output buffer, the address of the string to print, and the value - // of whatever variable we're printing. Then we just call print. - TopLevel::Statement(Statement::Print(ann, t, var)) => { - // Get the output buffer (or null) from our general compilation context. - let buffer_ptr = self.output_buffer_ptr(); - let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); - - // Get a reference to the string we want to print. 
- let local_name_ref = string_table.get(&var).unwrap(); - let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref); - - // Look up the value for the variable. Because this might be a - // global variable (and that requires special logic), we just turn - // this into an `Expression` and re-use the logic in that implementation. - let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane( - &mut builder, - &variable_table, - &pre_defined_symbols, - )?; - - let vtype_repr = builder.ins().iconst(types::I64, vtype as i64); - - let casted_val = match vtype { - ConstantType::U64 | ConstantType::I64 => val, - ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => { - builder.ins().sextend(types::I64, val) - } - ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => { - builder.ins().uextend(types::I64, val) - } - }; - - // Finally, we can generate the call to print. - builder.ins().call( - print_func_ref, - &[buffer_ptr, name_ptr, vtype_repr, casted_val], - ); - } - - // Variable binding is a little more con - TopLevel::Statement(Statement::Binding(_, var_name, _, value)) => { - // Kick off to the `Expression` implementation to see what value we're going - // to bind to this variable. - let (val, etype) = - value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?; - - // Now the question is: is this a local variable, or a global one? - if let Some((global_id, ctype)) = pre_defined_symbols.get(var_name.as_str()) { - // It's a global variable! In this case, we assume that someone has already - // dedicated some space in memory to store this value. We look this location - // up, and then tell Cranelift to store the value there. - assert_eq!(etype, *ctype); - let val_ptr = builder - .ins() - .symbol_value(ir::Type::from(*ctype), *global_id); - builder.ins().store(MemFlags::new(), val, val_ptr, 0); - } else { - // It's a local variable! 
In this case, we need to allocate a new Cranelift - // `Variable` for this variable, which we do using our `next_var_num` counter. - // (While we're doing this, we also increment `next_var_num`, so that we get - // a fresh `Variable` next time. This is one of those very narrow cases in which - // I wish Rust had an increment expression.) - let var = Variable::new(next_var_num); - next_var_num += 1; - - // We can add the variable directly to our local variable map; it's `Copy`. - variable_table.insert(var_name, (var, etype)); - - // Now we tell Cranelift about our new variable! - builder.declare_var(var, ir::Type::from(etype)); - builder.def_var(var, val); - } - } - } - } + let (value, _) = self.compile_expression(body, &mut variables, &mut builder)?; // Now that we're done, inject a return function (one with no actual value; basically // the equivalent of Rust's `return;`). We then seal the block (which lets Cranelift // know that the block is done), and then finalize the function (which lets Cranelift // know we're done with the function). - builder.ins().return_(&[]); + builder.ins().return_(&[value]); builder.seal_block(main_block); builder.finalize(); @@ -219,45 +210,18 @@ impl Backend { Ok(func_id) } - // Build the string table for use in referencing strings later. - // - // This function is slightly smart, in that it only puts strings in the table that - // are used by the `Program`. (Thanks to `Progam::strings()`!) If the strings have - // been declared globally, via `Backend::define_string()`, we will re-use that data. - // Otherwise, this will define the string for you. - fn build_string_table( + /// Compile an expression, returning the Cranelift Value for the expression and + /// its type. 
+ fn compile_expression( &mut self, - func: &mut Function, - program: &Program, - ) -> Result { - let mut string_table = HashMap::new(); - - for interned_value in program.strings().drain() { - let global_id = match self.defined_strings.get(interned_value.as_str()) { - Some(x) => *x, - None => self.define_string(interned_value.as_str())?, - }; - let local_data = self.module.declare_data_in_func(global_id, func); - string_table.insert(interned_value, local_data); - } - - Ok(string_table) - } -} - -impl Expression { - fn into_crane( - self, + expr: Expression, + variables: &mut ScopedMap, builder: &mut FunctionBuilder, - local_variables: &HashMap, (Variable, ConstantType)>, - global_variables: &HashMap, ) -> Result<(entities::Value, ConstantType), BackendError> { - match self { - Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables), - - Expression::Cast(_, target_type, expr) => { - let (val, val_type) = - expr.into_crane(builder, local_variables, global_variables)?; + match expr { + Expression::Atomic(x) => self.compile_value_or_ref(x, variables, builder), + Expression::Cast(_, target_type, valref) => { + let (val, val_type) = self.compile_value_or_ref(valref, variables, builder)?; match (val_type, &target_type) { (ConstantType::I8, Type::Primitive(PrimitiveType::I8)) => Ok((val, val_type)), @@ -325,7 +289,7 @@ impl Expression { for val in vals.drain(..) 
{ let (compiled, compiled_type) = - val.into_crane(builder, local_variables, global_variables)?; + self.compile_value_or_ref(val, variables, builder)?; if let Some(leftmost_type) = first_type { assert_eq!(leftmost_type, compiled_type); @@ -355,22 +319,79 @@ impl Expression { Primitive::Divide => Ok((builder.ins().udiv(values[0], values[1]), first_type)), } } + + Expression::Block(_, _, mut exprs) => match exprs.pop() { + None => Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)), + Some(last) => { + for inner in exprs { + // we can ignore the values these produce, because we don't use + // them anywhere; compilation errors are still propagated via `?` + self.compile_expression(inner, variables, builder)?; + } + // instead, we just return the last one + self.compile_expression(last, variables, builder) + } + }, + + Expression::Print(ann, var) => { + // Get the output buffer (or null) from our general compilation context. + let buffer_ptr = self.output_buffer_ptr(); + let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); + + // Get a reference to the string we want to print. + let string_data_id = self + .defined_strings + .get(var.as_ref()) + .ok_or_else(|| BackendError::UnknownString(var.clone()))?; + let local_name_ref = self + .module + .declare_data_in_func(*string_data_id, builder.func); + let name_ptr = builder.ins().symbol_value(types::I64, local_name_ref); + + // Look up the value for the variable. Because this might be a + // global variable (and that requires special logic), we just turn + // this into an `Expression` and re-use the logic in that implementation. 
+ let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var); + let (val, vtype) = self.compile_value_or_ref(fake_ref, variables, builder)?; + + let vtype_repr = builder.ins().iconst(types::I64, vtype as i64); + + let casted_val = match vtype { + ConstantType::U64 | ConstantType::I64 => val, + ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => { + builder.ins().sextend(types::I64, val) + } + ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => { + builder.ins().uextend(types::I64, val) + } + }; + + // Finally, we can generate the call to print. + let print_func_ref = self.runtime_functions.include_runtime_function( + "print", + &mut self.module, + builder.func, + )?; + builder.ins().call( + print_func_ref, + &[buffer_ptr, name_ptr, vtype_repr, casted_val], + ); + Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)) + } + + Expression::Bind(_, _, _, _) => unimplemented!(), } } -} -// Just to avoid duplication, this just leverages the `From` trait implementation -// for `ValueOrRef` to compile this via the `Expression` logic, above. -impl ValueOrRef { - fn into_crane( - self, + /// Compile a value or reference into Cranelift, returning the Cranelift Value for + /// the expression and its type. + fn compile_value_or_ref( + &self, + valref: ValueOrRef, + variables: &ScopedMap, builder: &mut FunctionBuilder, - local_variables: &HashMap, (Variable, ConstantType)>, - global_variables: &HashMap, ) -> Result<(entities::Value, ConstantType), BackendError> { - match self { - // Values are pretty straightforward to compile, mostly because we only - // have one type of variable, and it's an integer type. 
+ match valref { ValueOrRef::Value(_, _, val) => match val { Value::I8(_, v) => { Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8)) @@ -400,31 +421,217 @@ impl ValueOrRef { ConstantType::U64, )), }, - - ValueOrRef::Ref(_, _, name) => { - // first we see if this is a local variable (which is nicer, from an - // optimization point of view.) - if let Some((local_var, etype)) = local_variables.get(&name) { - return Ok((builder.use_var(*local_var), *etype)); - } - - // then we check to see if this is a global reference, which requires us to - // first lookup where the value is stored, and then load it. - if let Some((global_var, etype)) = global_variables.get(name.as_ref()) { - let cranelift_type = ir::Type::from(*etype); - let val_ptr = builder.ins().symbol_value(cranelift_type, *global_var); - return Ok(( - builder - .ins() - .load(cranelift_type, MemFlags::new(), val_ptr, 0), - *etype, - )); - } - - // this should never happen, because we should have made sure that there are - // no unbound variables a long time before this. but still ... - Err(BackendError::VariableLookupFailure(name)) - } + ValueOrRef::Ref(_, _, name) => match variables.get(&name) { + None => Err(BackendError::VariableLookupFailure(name)), + Some(x) => Ok(x.refer_to(builder)), + }, } } + + // Compiling a function is just compiling each of the statements in order. + // At the moment, we do the pattern match for statements here, and then + // directly compile the statements. If/when we add more statement forms, + // this is likely to become more cumbersome, and we'll want to separate + // these off. But for now, given the amount of tables we keep around to track + // state, it's easier to just include them. + // for item in program.items.drain(..) 
{ + // match item { + // TopLevel::Function(_, _, _) => unimplemented!(), + // + // // Print statements are fairly easy to compile: we just lookup the + // // output buffer, the address of the string to print, and the value + // // of whatever variable we're printing. Then we just call print. + // TopLevel::Statement(Statement::Print(ann, t, var)) => { + // // Get the output buffer (or null) from our general compilation context. + // let buffer_ptr = self.output_buffer_ptr(); + // let buffer_ptr = builder.ins().iconst(types::I64, buffer_ptr as i64); + // + // // Get a reference to the string we want to print. + // let local_name_ref = string_table.get(&var).unwrap(); + // let name_ptr = builder.ins().symbol_value(types::I64, *local_name_ref); + // + // // Look up the value for the variable. Because this might be a + // // global variable (and that requires special logic), we just turn + // // this into an `Expression` and re-use the logic in that implementation. + // let (val, vtype) = ValueOrRef::Ref(ann, t, var).into_crane( + // &mut builder, + // &variable_table, + // &pre_defined_symbols, + // )?; + // + // let vtype_repr = builder.ins().iconst(types::I64, vtype as i64); + // + // let casted_val = match vtype { + // ConstantType::U64 | ConstantType::I64 => val, + // ConstantType::I8 | ConstantType::I16 | ConstantType::I32 => { + // builder.ins().sextend(types::I64, val) + // } + // ConstantType::U8 | ConstantType::U16 | ConstantType::U32 => { + // builder.ins().uextend(types::I64, val) + // } + // }; + // + // // Finally, we can generate the call to print. + // builder.ins().call( + // print_func_ref, + // &[buffer_ptr, name_ptr, vtype_repr, casted_val], + // ); + // } + // + // // Variable binding is a little more con + // TopLevel::Statement(Statement::Binding(_, var_name, _, value)) => { + // // Kick off to the `Expression` implementation to see what value we're going + // // to bind to this variable. 
+ // let (val, etype) = + // value.into_crane(&mut builder, &variable_table, &pre_defined_symbols)?; + // + // // Now the question is: is this a local variable, or a global one? + // if let Some((global_id, ctype)) = pre_defined_symbols.get(var_name.as_str()) { + // // It's a global variable! In this case, we assume that someone has already + // // dedicated some space in memory to store this value. We look this location + // // up, and then tell Cranelift to store the value there. + // assert_eq!(etype, *ctype); + // let val_ptr = builder + // .ins() + // .symbol_value(ir::Type::from(*ctype), *global_id); + // builder.ins().store(MemFlags::new(), val, val_ptr, 0); + // } else { + // // It's a local variable! In this case, we need to allocate a new Cranelift + // // `Variable` for this variable, which we do using our `next_var_num` counter. + // // (While we're doing this, we also increment `next_var_num`, so that we get + // // a fresh `Variable` next time. This is one of those very narrow cases in which + // // I wish Rust had an increment expression.) + // let var = Variable::new(next_var_num); + // next_var_num += 1; + // + // // We can add the variable directly to our local variable map; it's `Copy`. + // variable_table.insert(var_name, (var, etype)); + // + // // Now we tell Cranelift about our new variable! + // builder.declare_var(var, ir::Type::from(etype)); + // builder.def_var(var, val); + // } + // } + // } + // } + + // Build the string table for use in referencing strings later. + // + // This function is slightly smart, in that it only puts strings in the table that + // are used by the `Program`. (Thanks to `Progam::strings()`!) If the strings have + // been declared globally, via `Backend::define_string()`, we will re-use that data. + // Otherwise, this will define the string for you. 
+ // fn build_string_table( + // &mut self, + // func: &mut Function, + // program: &Expression, + // ) -> Result { + // let mut string_table = HashMap::new(); + // + // for interned_value in program.strings().drain() { + // let global_id = match self.defined_strings.get(interned_value.as_str()) { + // Some(x) => *x, + // None => self.define_string(interned_value.as_str())?, + // }; + // let local_data = self.module.declare_data_in_func(global_id, func); + // string_table.insert(interned_value, local_data); + // } + // + // Ok(string_table) + // } } + +//impl Expression { +// fn into_crane( +// self, +// builder: &mut FunctionBuilder, +// local_variables: &HashMap, (Variable, ConstantType)>, +// global_variables: &HashMap, +// ) -> Result<(entities::Value, ConstantType), BackendError> { +// match self { +// Expression::Atomic(x) => x.into_crane(builder, local_variables, global_variables), +// +// Expression::Cast(_, target_type, expr) => { +// let (val, val_type) = +// expr.into_crane(builder, local_variables, global_variables)?; +// +// match (val_type, &target_type) { +// } +// } +// +// Expression::Primitive(_, _, prim, mut vals) => { +// } +// } +// } +//} +// +//// Just to avoid duplication, this just leverages the `From` trait implementation +//// for `ValueOrRef` to compile this via the `Expression` logic, above. +//impl ValueOrRef { +// fn into_crane( +// self, +// builder: &mut FunctionBuilder, +// local_variables: &HashMap, (Variable, ConstantType)>, +// global_variables: &HashMap, +// ) -> Result<(entities::Value, ConstantType), BackendError> { +// match self { +// // Values are pretty straightforward to compile, mostly because we only +// // have one type of variable, and it's an integer type. 
+// ValueOrRef::Value(_, _, val) => match val { +// Value::I8(_, v) => { +// Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::I8)) +// } +// Value::I16(_, v) => Ok(( +// builder.ins().iconst(types::I16, v as i64), +// ConstantType::I16, +// )), +// Value::I32(_, v) => Ok(( +// builder.ins().iconst(types::I32, v as i64), +// ConstantType::I32, +// )), +// Value::I64(_, v) => Ok((builder.ins().iconst(types::I64, v), ConstantType::I64)), +// Value::U8(_, v) => { +// Ok((builder.ins().iconst(types::I8, v as i64), ConstantType::U8)) +// } +// Value::U16(_, v) => Ok(( +// builder.ins().iconst(types::I16, v as i64), +// ConstantType::U16, +// )), +// Value::U32(_, v) => Ok(( +// builder.ins().iconst(types::I32, v as i64), +// ConstantType::U32, +// )), +// Value::U64(_, v) => Ok(( +// builder.ins().iconst(types::I64, v as i64), +// ConstantType::U64, +// )), +// }, +// +// ValueOrRef::Ref(_, _, name) => { +// // first we see if this is a local variable (which is nicer, from an +// // optimization point of view.) +// if let Some((local_var, etype)) = local_variables.get(&name) { +// return Ok((builder.use_var(*local_var), *etype)); +// } +// +// // then we check to see if this is a global reference, which requires us to +// // first lookup where the value is stored, and then load it. +// if let Some((global_var, etype)) = global_variables.get(name.as_ref()) { +// let cranelift_type = ir::Type::from(*etype); +// let val_ptr = builder.ins().symbol_value(cranelift_type, *global_var); +// return Ok(( +// builder +// .ins() +// .load(cranelift_type, MemFlags::new(), val_ptr, 0), +// *etype, +// )); +// } +// +// // this should never happen, because we should have made sure that there are +// // no unbound variables a long time before this. but still ... 
+// Err(BackendError::VariableLookupFailure(name)) +// } +// } +// } +//} +// diff --git a/src/compiler.rs b/src/compiler.rs index b94a200..22c8029 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -134,7 +134,7 @@ impl Compiler { // Finally, send all this to Cranelift for conversion into an object file. let mut backend = Backend::object_file(Triple::host())?; - backend.compile_function("gogogo", ir)?; + backend.compile_program("gogogo", ir)?; Ok(Some(backend.bytes()?)) } diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index de1e6eb..3c5818e 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -6,6 +6,7 @@ use std::{fmt::Display, str::FromStr}; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum PrimitiveType { + Void, U8, U16, U32, @@ -19,6 +20,7 @@ pub enum PrimitiveType { impl Display for PrimitiveType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + PrimitiveType::Void => write!(f, "void"), PrimitiveType::I8 => write!(f, "i8"), PrimitiveType::I16 => write!(f, "i16"), PrimitiveType::I32 => write!(f, "i32"), @@ -100,6 +102,7 @@ impl PrimitiveType { /// Return true if this type can be safely cast into the target type. 
pub fn can_cast_to(&self, target: &PrimitiveType) -> bool { match self { + PrimitiveType::Void => matches!(target, PrimitiveType::Void), PrimitiveType::U8 => matches!( target, PrimitiveType::U8 @@ -175,16 +178,17 @@ impl PrimitiveType { } } - pub fn max_value(&self) -> u64 { + pub fn max_value(&self) -> Option { match self { - PrimitiveType::U8 => u8::MAX as u64, - PrimitiveType::U16 => u16::MAX as u64, - PrimitiveType::U32 => u32::MAX as u64, - PrimitiveType::U64 => u64::MAX, - PrimitiveType::I8 => i8::MAX as u64, - PrimitiveType::I16 => i16::MAX as u64, - PrimitiveType::I32 => i32::MAX as u64, - PrimitiveType::I64 => i64::MAX as u64, + PrimitiveType::Void => None, + PrimitiveType::U8 => Some(u8::MAX as u64), + PrimitiveType::U16 => Some(u16::MAX as u64), + PrimitiveType::U32 => Some(u32::MAX as u64), + PrimitiveType::U64 => Some(u64::MAX), + PrimitiveType::I8 => Some(i8::MAX as u64), + PrimitiveType::I16 => Some(i16::MAX as u64), + PrimitiveType::I32 => Some(i32::MAX as u64), + PrimitiveType::I64 => Some(i64::MAX as u64), } } } diff --git a/src/ir/ast.rs b/src/ir/ast.rs index c63f5db..728a4d0 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -9,36 +9,58 @@ use proptest::{ prelude::Arbitrary, strategy::{BoxedStrategy, Strategy}, }; -use std::{fmt, str::FromStr}; +use std::{fmt, str::FromStr, sync::atomic::AtomicUsize}; /// We're going to represent variables as interned strings. /// /// These should be fast enough for comparison that it's OK, since it's going to end up /// being pretty much the pointer to the string. -type Variable = ArcIntern; +pub type Variable = ArcIntern; + +/// Generate a new symbol that is guaranteed to be different from every other symbol +/// currently known. +/// +/// This function will use the provided string as a base name for the symbol, but +/// extend it with numbers and characters to make it unique. While technically you +/// could roll-over these symbols, you probably don't need to worry about it. 
+pub fn gensym(base: &str) -> Variable { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + ArcIntern::new(format!( + "{}<{}>", + base, + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + )) +} /// The representation of a program within our IR. For now, this is exactly one file. /// -/// In addition, for the moment there's not really much of interest to hold here besides -/// the list of statements read from the file. Order is important. In the future, you -/// could imagine caching analysis information in this structure. +/// A program consists of a series of statements and functions. The statements should +/// be executed in order. The functions currently may not reference any variables +/// at the top level, so their order only matters in relation to each other (functions +/// may not be referenced before they are defined). /// /// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used /// to print the structure whenever possible, especially if you value your or your /// user's time. The latter is useful for testing that conversions of `Program` retain /// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be /// syntactically valid, although they may contain runtime issue like over- or underflow. +/// +/// The type variable is, somewhat confusingly, the current definition of a type within +/// the IR. Since the makeup of this structure may change over the life of the compiler, +/// it's easiest to just make it an argument. #[derive(Debug)] -pub struct Program { +pub struct Program { // For now, a program is just a vector of statements. In the future, we'll probably // extend this to include a bunch of other information, but for now: just a list. 
- pub(crate) items: Vec, + pub(crate) items: Vec>, } -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program +impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b Program where A: 'a, D: ?Sized + DocAllocator<'a, A>, + &'b Type: Pretty<'a, D, A>, { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { let mut result = allocator.nil(); @@ -56,17 +78,18 @@ where } } -impl Arbitrary for Program { +impl Arbitrary for Program { type Parameters = crate::syntax::arbitrary::GenerationEnvironment; type Strategy = BoxedStrategy; fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { - crate::syntax::Program::arbitrary_with(args) - .prop_map(|x| { - x.type_infer() - .expect("arbitrary_with should generate type-correct programs") - }) - .boxed() + unimplemented!() + //crate::syntax::Program::arbitrary_with(args) + // .prop_map(|x| { + // x.type_infer() + // .expect("arbitrary_with should generate type-correct programs") + // }) + // .boxed() } } @@ -76,84 +99,35 @@ impl Arbitrary for Program { /// will likely be added in the future, but for now: just statements /// and functions #[derive(Debug)] -pub enum TopLevel { - Statement(Statement), - Function(Variable, Vec, Vec, Expression), +pub enum TopLevel { + Statement(Expression), + Function(Variable, Vec<(Variable, Type)>, Type, Expression), } -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b TopLevel +impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b TopLevel where A: 'a, D: ?Sized + DocAllocator<'a, A>, + &'b Type: Pretty<'a, D, A>, { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - TopLevel::Function(name, args, stmts, expr) => { - let base = allocator - .text("function") - .append(allocator.space()) - .append(allocator.text(name.as_ref().to_string())) - .append(allocator.space()) - .append( - pretty_comma_separated( - allocator, - &args.iter().map(PrettySymbol::from).collect(), - ) - .parens(), - ) - .append(allocator.space()); - - let mut body = allocator.nil(); - for stmt in 
stmts { - body = body - .append(stmt.pretty(allocator)) - .append(allocator.text(";")) - .append(allocator.hardline()); - } - body = body.append(expr.pretty(allocator)); - body = body.append(allocator.hardline()); - body = body.braces(); - base.append(body) - } - - TopLevel::Statement(stmt) => stmt.pretty(allocator), - } - } -} - -/// The representation of a statement in the language. -/// -/// For now, this is either a binding site (`x = 4`) or a print statement -/// (`print x`). Someday, though, more! -/// -/// As with `Program`, this type implements [`Pretty`], which should -/// be used to display the structure whenever possible. It does not -/// implement [`Arbitrary`], though, mostly because it's slightly -/// complicated to do so. -/// -#[derive(Debug)] -pub enum Statement { - Binding(Location, Variable, Type, Expression), - Print(Location, Type, Variable), -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - match self { - Statement::Binding(_, var, _, expr) => allocator - .text(var.as_ref().to_string()) + TopLevel::Function(name, args, _, expr) => allocator + .text("function") .append(allocator.space()) - .append(allocator.text("=")) + .append(allocator.text(name.as_ref().to_string())) + .append(allocator.space()) + .append( + pretty_comma_separated( + allocator, + &args.iter().map(|(x, _)| PrettySymbol::from(x)).collect(), + ) + .parens(), + ) .append(allocator.space()) .append(expr.pretty(allocator)), - Statement::Print(_, _, var) => allocator - .text("print") - .append(allocator.space()) - .append(allocator.text(var.as_ref().to_string())), + + TopLevel::Statement(stmt) => stmt.pretty(allocator), } } } @@ -171,21 +145,27 @@ where /// that the referenced data will always either be a constant or a /// variable reference. 
#[derive(Debug)] -pub enum Expression { - Atomic(ValueOrRef), - Cast(Location, Type, ValueOrRef), - Primitive(Location, Type, Primitive, Vec), +pub enum Expression { + Atomic(ValueOrRef), + Cast(Location, Type, ValueOrRef), + Primitive(Location, Type, Primitive, Vec>), + Block(Location, Type, Vec>), + Print(Location, Variable), + Bind(Location, Variable, Type, Box>), } -impl Expression { +impl Expression { /// Return a reference to the type of the expression, as inferred or recently /// computed. - pub fn type_of(&self) -> &Type { + pub fn type_of(&self) -> Type { match self { - Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t, - Expression::Atomic(ValueOrRef::Value(_, t, _)) => t, - Expression::Cast(_, t, _) => t, - Expression::Primitive(_, t, _, _) => t, + Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t.clone(), + Expression::Atomic(ValueOrRef::Value(_, t, _)) => t.clone(), + Expression::Cast(_, t, _) => t.clone(), + Expression::Primitive(_, t, _, _) => t.clone(), + Expression::Block(_, t, _) => t.clone(), + Expression::Print(_, _) => Type::void(), + Expression::Bind(_, _, _, _) => Type::void(), } } @@ -196,14 +176,18 @@ impl Expression { Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, Expression::Cast(l, _, _) => l, Expression::Primitive(l, _, _, _) => l, + Expression::Block(l, _, _) => l, + Expression::Print(l, _) => l, + Expression::Bind(l, _, _, _) => l, } } } -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression +impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b Expression where A: 'a, D: ?Sized + DocAllocator<'a, A>, + &'b Type: Pretty<'a, D, A>, { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { @@ -229,6 +213,35 @@ where Expression::Primitive(_, _, op, exprs) => { allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) } + Expression::Block(_, _, exprs) => match exprs.split_last() { + None => allocator.text("()"), + Some((last, &[])) => last.pretty(allocator), + Some((last, start)) => { + let mut 
result = allocator.text("{").append(allocator.hardline()); + + for stmt in start.iter() { + result = result + .append(stmt.pretty(allocator)) + .append(allocator.text(";")) + .append(allocator.hardline()); + } + + result + .append(last.pretty(allocator)) + .append(allocator.hardline()) + .append(allocator.text("}")) + } + }, + Expression::Print(_, var) => allocator + .text("print") + .append(allocator.space()) + .append(allocator.text(var.as_ref().to_string())), + Expression::Bind(_, var, _, expr) => allocator + .text(var.as_ref().to_string()) + .append(allocator.space()) + .append(allocator.text("=")) + .append(allocator.space()) + .append(expr.pretty(allocator)), } } } @@ -288,12 +301,12 @@ impl fmt::Display for Primitive { /// at this level. Instead, expressions that take arguments take one /// of these, which can only be a constant or a reference. #[derive(Clone, Debug)] -pub enum ValueOrRef { +pub enum ValueOrRef { Value(Location, Type, Value), Ref(Location, Type, ArcIntern), } -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef +impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b ValueOrRef where A: 'a, D: ?Sized + DocAllocator<'a, A>, @@ -306,8 +319,8 @@ where } } -impl From for Expression { - fn from(value: ValueOrRef) -> Self { +impl From> for Expression { + fn from(value: ValueOrRef) -> Self { Expression::Atomic(value) } } @@ -434,3 +447,121 @@ impl fmt::Display for Type { } } } + +impl From for Type { + fn from(value: PrimitiveType) -> Self { + Type::Primitive(value) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum TypeOrVar { + Primitive(PrimitiveType), + Variable(Location, ArcIntern), + Function(Vec, Box), +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b TypeOrVar +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + match self { + TypeOrVar::Primitive(x) => allocator.text(format!("{}", x)), + TypeOrVar::Variable(_, x) => allocator.text(x.to_string()), + 
TypeOrVar::Function(args, rettype) => { + pretty_comma_separated(allocator, &args.iter().collect()) + .parens() + .append(allocator.space()) + .append(allocator.text("->")) + .append(allocator.space()) + .append(rettype.pretty(allocator)) + } + } + } +} + +impl fmt::Display for TypeOrVar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TypeOrVar::Primitive(x) => x.fmt(f), + TypeOrVar::Variable(_, v) => write!(f, "{}", v), + TypeOrVar::Function(args, rettype) => { + write!(f, " write!(f, "()")?, + Some((single, &[])) => { + write!(f, "({})", single)?; + } + Some((last_one, rest)) => { + write!(f, "(")?; + for arg in rest.iter() { + write!(f, "{}, ", arg); + } + write!(f, "{})", last_one)?; + } + } + write!(f, "->")?; + rettype.fmt(f)?; + write!(f, ">") + } + } + } +} + +impl TypeOrVar { + /// Generate a fresh type variable that is different from all previous type variables. + /// + /// This type variable is guaranteed to be unique across the process lifetime. Overuse + /// of this function could potentially cause overflow problems, but you're going to have + /// to try really hard (like, 2^64 times) to make that happen. The location bound to + /// this address will be purely manufactured; if you want to specify a location, use + /// [`TypeOrVar::new_located`]. + pub fn new() -> Self { + Self::new_located(Location::manufactured()) + } + + /// Generate a fresh type variable that is different from all previous type variables. + /// + /// This type variable is guaranteed to be unique across the process lifetime. Overuse + /// of this function could potentially cause overflow problems, but you're going to have + /// to try really hard (like, 2^64 times) to make that happen. 
+ pub fn new_located(loc: Location) -> Self { + TypeOrVar::Variable(loc, gensym("t")) + } +} + +trait TypeWithVoid { + fn void() -> Self; +} + +impl TypeWithVoid for Type { + fn void() -> Self { + Type::Primitive(PrimitiveType::Void) + } +} + +impl TypeWithVoid for TypeOrVar { + fn void() -> Self { + TypeOrVar::Primitive(PrimitiveType::Void) + } +} + +//impl From for TypeOrVar { +// fn from(value: Type) -> Self { +// TypeOrVar::Type(value) +// } +//} + +impl> From for TypeOrVar { + fn from(value: T) -> Self { + match value.into() { + Type::Primitive(p) => TypeOrVar::Primitive(p), + Type::Function(args, ret) => TypeOrVar::Function( + args.into_iter().map(Into::into).collect(), + Box::new((*ret).into()), + ), + } + } +} diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 93d56b0..af3e31e 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -1,8 +1,8 @@ use super::{Primitive, Type, ValueOrRef}; use crate::eval::{EvalEnvironment, EvalError, Value}; -use crate::ir::{Expression, Program, Statement, TopLevel}; +use crate::ir::{Expression, Program, TopLevel}; -impl Program { +impl Program { /// Evaluate the program, returning either an error or a string containing everything /// the program printed out. 
/// @@ -15,16 +15,7 @@ impl Program { match stmt { TopLevel::Function(_, _, _, _) => unimplemented!(), - TopLevel::Statement(Statement::Binding(_, name, _, value)) => { - let actual_value = value.eval(&env)?; - env = env.extend(name.clone(), actual_value); - } - - TopLevel::Statement(Statement::Print(_, _, name)) => { - let value = env.lookup(name.clone())?; - let line = format!("{} = {}\n", name, value); - stdout.push_str(&line); - } + TopLevel::Statement(_) => unimplemented!(), } } @@ -32,17 +23,21 @@ impl Program { } } -impl Expression { +impl Expression +where + T: Clone + Into, +{ fn eval(&self, env: &EvalEnvironment) -> Result { match self { Expression::Atomic(x) => x.eval(env), Expression::Cast(_, t, valref) => { let value = valref.eval(env)?; + let ty = t.clone().into(); - match t { + match ty { Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), - Type::Function(_, _) => Err(EvalError::CastToFunction(t.to_string())), + Type::Function(_, _) => Err(EvalError::CastToFunction(ty.to_string())), } } @@ -61,11 +56,19 @@ impl Expression { Primitive::Divide => Ok(Value::calculate("/", arg_values)?), } } + + Expression::Block(_, _, _) => { + unimplemented!() + } + + Expression::Print(_, _) => unimplemented!(), + + Expression::Bind(_, _, _, _) => unimplemented!(), } } } -impl ValueOrRef { +impl ValueOrRef { fn eval(&self, env: &EvalEnvironment) -> Result { match self { ValueOrRef::Value(_, _, v) => match v { diff --git a/src/ir/strings.rs b/src/ir/strings.rs index dc05432..5d6d07f 100644 --- a/src/ir/strings.rs +++ b/src/ir/strings.rs @@ -1,8 +1,8 @@ -use super::ast::{Expression, Program, Statement, TopLevel}; +use super::ast::{Expression, Program, TopLevel}; use internment::ArcIntern; use std::collections::HashSet; -impl Program { +impl Program { /// Get the complete list of strings used within the program. 
/// /// For the purposes of this function, strings are the variables used in @@ -18,37 +18,18 @@ impl Program { } } -impl TopLevel { +impl TopLevel { fn register_strings(&self, string_set: &mut HashSet>) { match self { - TopLevel::Function(_, _, stmts, body) => { - for stmt in stmts.iter() { - stmt.register_strings(string_set); - } - body.register_strings(string_set); - } + TopLevel::Function(_, _, _, body) => body.register_strings(string_set), TopLevel::Statement(stmt) => stmt.register_strings(string_set), } } } -impl Statement { - fn register_strings(&self, string_set: &mut HashSet>) { - match self { - Statement::Binding(_, name, _, expr) => { - string_set.insert(name.clone()); - expr.register_strings(string_set); - } - - Statement::Print(_, _, name) => { - string_set.insert(name.clone()); - } - } - } -} - -impl Expression { +impl Expression { fn register_strings(&self, _string_set: &mut HashSet>) { // nothing has a string in here, at the moment + unimplemented!() } } diff --git a/src/repl.rs b/src/repl.rs index 1600a0f..ea4803e 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -1,4 +1,5 @@ use crate::backend::{Backend, BackendError}; +use crate::eval::PrimitiveType; use crate::syntax::{ConstantType, Location, ParserError, Statement, TopLevel}; use crate::type_infer::TypeInferenceResult; use crate::util::scoped_map::ScopedMap; @@ -130,10 +131,6 @@ impl REPL { let syntax = TopLevel::parse(entry, source)?; let program = match syntax { - TopLevel::Function(_, _, _) => { - unimplemented!() - } - TopLevel::Statement(Statement::Binding(loc, name, expr)) => { // if this is a variable binding, and we've never defined this variable before, // we should tell cranelift about it. 
this is optimistic; if we fail to compile, @@ -152,9 +149,7 @@ impl REPL { } } - TopLevel::Statement(nonbinding) => crate::syntax::Program { - items: vec![TopLevel::Statement(nonbinding)], - }, + x => crate::syntax::Program { items: vec![x] }, }; let (mut errors, mut warnings) = @@ -197,8 +192,9 @@ impl REPL { for message in warnings.drain(..).map(Into::into) { self.emit_diagnostic(message)?; } + let name = format!("line{}", line_no); - let function_id = self.jitter.compile_function(&name, result)?; + let function_id = self.jitter.compile_program(&name, result)?; self.jitter.module.finalize_definitions()?; let compiled_bytes = self.jitter.bytes(function_id); let compiled_function = diff --git a/src/type_infer.rs b/src/type_infer.rs index 08cd2da..d47c7f4 100644 --- a/src/type_infer.rs +++ b/src/type_infer.rs @@ -10,7 +10,6 @@ //! all the constraints we've generated. If that's successful, in the final phase, we //! do the final conversion to the IR AST, filling in any type information we've learned //! along the way. -mod ast; mod convert; mod finalize; mod solve; @@ -32,9 +31,8 @@ impl syntax::Program { /// /// You really should have made sure that this program was validated before running /// this method, otherwise you may experience panics during operation. 
- pub fn type_infer(self) -> TypeInferenceResult { - let mut constraint_db = vec![]; - let program = convert_program(self, &mut constraint_db); + pub fn type_infer(self) -> TypeInferenceResult> { + let (program, constraint_db) = convert_program(self); let inference_result = solve_constraints(constraint_db); inference_result.map(|resolutions| finalize_program(program, &resolutions)) diff --git a/src/type_infer/ast.rs b/src/type_infer/ast.rs deleted file mode 100644 index 21eb2fc..0000000 --- a/src/type_infer/ast.rs +++ /dev/null @@ -1,408 +0,0 @@ -pub use crate::ir::ast::Primitive; -/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going -/// to want to use while we're doing type inference, but don't want to keep around -/// afterwards. These are: -/// -/// * A notion of a type variable -/// * An unknown numeric constant form -/// -use crate::{ - eval::PrimitiveType, - syntax::{self, ConstantType, Location}, - util::pretty::{pretty_comma_separated, PrettySymbol}, -}; -use internment::ArcIntern; -use pretty::{DocAllocator, Pretty}; -use std::fmt; -use std::sync::atomic::AtomicUsize; - -/// We're going to represent variables as interned strings. -/// -/// These should be fast enough for comparison that it's OK, since it's going to end up -/// being pretty much the pointer to the string. -type Variable = ArcIntern; - -/// The representation of a program within our IR. For now, this is exactly one file. -/// -/// In addition, for the moment there's not really much of interest to hold here besides -/// the list of statements read from the file. Order is important. In the future, you -/// could imagine caching analysis information in this structure. -/// -/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used -/// to print the structure whenever possible, especially if you value your or your -/// user's time. The latter is useful for testing that conversions of `Program` retain -/// their meaning. 
All `Program`s generated through [`Arbitrary`] are guaranteed to be -/// syntactically valid, although they may contain runtime issue like over- or underflow. -#[derive(Debug)] -pub struct Program { - // For now, a program is just a vector of statements. In the future, we'll probably - // extend this to include a bunch of other information, but for now: just a list. - pub(crate) items: Vec, -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - let mut result = allocator.nil(); - - for stmt in self.items.iter() { - // there's probably a better way to do this, rather than constantly - // adding to the end, but this works. - result = result - .append(stmt.pretty(allocator)) - .append(allocator.text(";")) - .append(allocator.hardline()); - } - - result - } -} - -/// A thing that can sit at the top level of a file. -/// -/// For the moment, these are statements and functions. 
Other things -/// will likely be added in the future, but for now: just statements -/// and functions -#[derive(Debug)] -pub enum TopLevel { - Statement(Statement), - Function(Variable, Vec, Vec, Expression), -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b TopLevel -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - match self { - TopLevel::Function(name, args, stmts, expr) => { - let base = allocator - .text("function") - .append(allocator.space()) - .append(allocator.text(name.as_ref().to_string())) - .append(allocator.space()) - .append( - pretty_comma_separated( - allocator, - &args.iter().map(PrettySymbol::from).collect(), - ) - .parens(), - ) - .append(allocator.space()); - - let mut body = allocator.nil(); - for stmt in stmts { - body = body - .append(stmt.pretty(allocator)) - .append(allocator.text(";")) - .append(allocator.hardline()); - } - body = body.append(expr.pretty(allocator)); - body = body.append(allocator.hardline()); - body = body.braces(); - base.append(body) - } - TopLevel::Statement(stmt) => stmt.pretty(allocator), - } - } -} - -/// The representation of a statement in the language. -/// -/// For now, this is either a binding site (`x = 4`) or a print statement -/// (`print x`). Someday, though, more! -/// -/// As with `Program`, this type implements [`Pretty`], which should -/// be used to display the structure whenever possible. It does not -/// implement [`Arbitrary`], though, mostly because it's slightly -/// complicated to do so. 
-/// -#[derive(Debug)] -pub enum Statement { - Binding(Location, Variable, Type, Expression), - Print(Location, Type, Variable), -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - match self { - Statement::Binding(_, var, _, expr) => allocator - .text(var.as_ref().to_string()) - .append(allocator.space()) - .append(allocator.text("=")) - .append(allocator.space()) - .append(expr.pretty(allocator)), - Statement::Print(_, _, var) => allocator - .text("print") - .append(allocator.space()) - .append(allocator.text(var.as_ref().to_string())), - } - } -} - -/// The representation of an expression. -/// -/// Note that expressions, like everything else in this syntax tree, -/// supports [`Pretty`], and it's strongly encouraged that you use -/// that trait/module when printing these structures. -/// -/// Also, Expressions at this point in the compiler are explicitly -/// defined so that they are *not* recursive. By this point, if an -/// expression requires some other data (like, for example, invoking -/// a primitive), any subexpressions have been bound to variables so -/// that the referenced data will always either be a constant or a -/// variable reference. -#[derive(Debug, PartialEq)] -pub enum Expression { - Atomic(ValueOrRef), - Cast(Location, Type, ValueOrRef), - Primitive(Location, Type, Primitive, Vec), -} - -impl Expression { - /// Return a reference to the type of the expression, as inferred or recently - /// computed. - pub fn type_of(&self) -> &Type { - match self { - Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t, - Expression::Atomic(ValueOrRef::Value(_, t, _)) => t, - Expression::Cast(_, t, _) => t, - Expression::Primitive(_, t, _, _) => t, - } - } - - /// Return a reference to the location associated with the expression. 
- pub fn location(&self) -> &Location { - match self { - Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l, - Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, - Expression::Cast(l, _, _) => l, - Expression::Primitive(l, _, _, _) => l, - } - } -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - match self { - Expression::Atomic(x) => x.pretty(allocator), - Expression::Cast(_, t, e) => allocator - .text("<") - .append(t.pretty(allocator)) - .append(allocator.text(">")) - .append(e.pretty(allocator)), - Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => { - op.pretty(allocator).append(exprs[0].pretty(allocator)) - } - Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => { - let left = exprs[0].pretty(allocator); - let right = exprs[1].pretty(allocator); - - left.append(allocator.space()) - .append(op.pretty(allocator)) - .append(allocator.space()) - .append(right) - .parens() - } - Expression::Primitive(_, _, op, exprs) => { - allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len())) - } - } - } -} - -/// An expression that is always either a value or a reference. -/// -/// This is the type used to guarantee that we don't nest expressions -/// at this level. Instead, expressions that take arguments take one -/// of these, which can only be a constant or a reference. 
-#[derive(Clone, Debug, PartialEq)] -pub enum ValueOrRef { - Value(Location, Type, Value), - Ref(Location, Type, ArcIntern), -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - match self { - ValueOrRef::Value(_, _, v) => v.pretty(allocator), - ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), - } - } -} - -impl From for Expression { - fn from(value: ValueOrRef) -> Self { - Expression::Atomic(value) - } -} - -/// A constant in the IR. -/// -/// The optional argument in numeric types is the base that was used by the -/// user to input the number. By retaining it, we can ensure that if we need -/// to print the number back out, we can do so in the form that the user -/// entered it. -#[derive(Clone, Debug, PartialEq)] -pub enum Value { - Unknown(Option, u64), - I8(Option, i8), - I16(Option, i16), - I32(Option, i32), - I64(Option, i64), - U8(Option, u8), - U16(Option, u16), - U32(Option, u32), - U64(Option, u64), -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - let pretty_internal = |opt_base: &Option, x, t| { - syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator) - }; - - let pretty_internal_signed = |opt_base, x: i64, t| { - let base = pretty_internal(opt_base, x.unsigned_abs(), t); - - allocator.text("-").append(base) - }; - - match self { - Value::Unknown(opt_base, value) => { - pretty_internal_signed(opt_base, *value as i64, ConstantType::U64) - } - Value::I8(opt_base, value) => { - pretty_internal_signed(opt_base, *value as i64, ConstantType::I8) - } - Value::I16(opt_base, value) => { - pretty_internal_signed(opt_base, *value as i64, ConstantType::I16) - } - Value::I32(opt_base, value) => { - pretty_internal_signed(opt_base, *value as i64, ConstantType::I32) - } - 
Value::I64(opt_base, value) => { - pretty_internal_signed(opt_base, *value, ConstantType::I64) - } - Value::U8(opt_base, value) => { - pretty_internal(opt_base, *value as u64, ConstantType::U8) - } - Value::U16(opt_base, value) => { - pretty_internal(opt_base, *value as u64, ConstantType::U16) - } - Value::U32(opt_base, value) => { - pretty_internal(opt_base, *value as u64, ConstantType::U32) - } - Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64), - } - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum Type { - Variable(Location, ArcIntern), - Primitive(PrimitiveType), - Function(Vec, Box), -} - -impl Type { - pub fn is_concrete(&self) -> bool { - !matches!(self, Type::Variable(_, _)) - } -} - -impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { - match self { - Type::Variable(_, x) => allocator.text(x.to_string()), - Type::Primitive(pt) => allocator.text(format!("{}", pt)), - Type::Function(args, rettype) => { - pretty_comma_separated(allocator, &args.iter().collect()) - .parens() - .append(allocator.space()) - .append(allocator.text("->")) - .append(allocator.space()) - .append(rettype.pretty(allocator)) - } - } - } -} - -impl fmt::Display for Type { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Type::Variable(_, x) => write!(f, "{}", x), - Type::Primitive(pt) => pt.fmt(f), - Type::Function(args, ret) => { - write!(f, "(")?; - let mut argiter = args.iter().peekable(); - while let Some(arg) = argiter.next() { - arg.fmt(f)?; - if argiter.peek().is_some() { - write!(f, ",")?; - } - } - write!(f, "->")?; - ret.fmt(f) - } - } - } -} - -/// Generate a fresh new name based on the given name. -/// -/// The new name is guaranteed to be unique across the entirety of the -/// execution. 
This is achieved by using characters in the variable name -/// that would not be valid input, and by including a counter that is -/// incremented on every invocation. -pub fn gensym(name: &str) -> ArcIntern { - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let new_name = format!( - "<{}:{}>", - name, - COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) - ); - ArcIntern::new(new_name) -} - -/// Generate a fresh new type; this will be a unique new type variable. -/// -/// The new name is guaranteed to be unique across the entirety of the -/// execution. This is achieved by using characters in the variable name -/// that would not be valid input, and by including a counter that is -/// incremented on every invocation. -pub fn gentype() -> Type { - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let name = ArcIntern::new(format!( - "t<{}>", - COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) - )); - - Type::Variable(Location::manufactured(), name) -} diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index e6dc161..d67f3d8 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -1,10 +1,9 @@ -use super::ast as ir; -use super::ast::Type; use crate::eval::PrimitiveType; +use crate::ir; use crate::syntax::{self, ConstantType}; use crate::type_infer::solve::Constraint; +use crate::util::scoped_map::ScopedMap; use internment::ArcIntern; -use std::collections::HashMap; use std::str::FromStr; /// This function takes a syntactic program and converts it into the IR version of the @@ -16,22 +15,22 @@ use std::str::FromStr; /// function can panic. 
pub fn convert_program( mut program: syntax::Program, - constraint_db: &mut Vec, -) -> ir::Program { +) -> (ir::Program, Vec) { + let mut constraint_db = Vec::new(); let mut items = Vec::new(); - let mut renames = HashMap::new(); - let mut bindings = HashMap::new(); + let mut renames = ScopedMap::new(); + let mut bindings = ScopedMap::new(); for item in program.items.drain(..) { - items.append(&mut convert_top_level( + items.push(convert_top_level( item, - constraint_db, + &mut constraint_db, &mut renames, &mut bindings, )); } - ir::Program { items } + (ir::Program { items }, constraint_db) } /// This function takes a top-level item and converts it into the IR version of the @@ -40,9 +39,9 @@ pub fn convert_program( pub fn convert_top_level( top_level: syntax::TopLevel, constraint_db: &mut Vec, - renames: &mut HashMap, ArcIntern>, - bindings: &mut HashMap, Type>, -) -> Vec { + renames: &mut ScopedMap, ArcIntern>, + bindings: &mut ScopedMap, ir::TypeOrVar>, +) -> ir::TopLevel { match top_level { syntax::TopLevel::Function(name, args, expr) => { // First, let us figure out what we're going to name this function. If the user @@ -59,9 +58,9 @@ pub fn convert_top_level( // Now we manufacture types for the inputs and outputs, and then a type for the // function itself. We're not going to make any claims on these types, yet; they're // all just unknown type variables we need to work out. - let argtypes: Vec = args.iter().map(|_| ir::gentype()).collect(); - let rettype = ir::gentype(); - let funtype = Type::Function(argtypes.clone(), Box::new(rettype.clone())); + let argtypes: Vec = args.iter().map(|_| ir::TypeOrVar::new()).collect(); + let rettype = ir::TypeOrVar::new(); + let funtype = ir::TypeOrVar::Function(argtypes.clone(), Box::new(rettype.clone())); // Now let's bind these types into the environment. First, we bind our function // namae to the function type we just generated. 
@@ -71,20 +70,20 @@ pub fn convert_top_level( let iargs: Vec> = args.iter().map(|x| ArcIntern::new(x.to_string())).collect(); assert_eq!(argtypes.len(), iargs.len()); + let mut function_args = vec![]; for (arg_name, arg_type) in iargs.iter().zip(argtypes) { bindings.insert(arg_name.clone(), arg_type.clone()); + function_args.push((arg_name.clone(), arg_type)); } - let (stmts, expr, ty) = convert_expression(expr, constraint_db, renames, bindings); - constraint_db.push(Constraint::Equivalent(expr.location().clone(), rettype, ty)); + let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings); + constraint_db.push(Constraint::Equivalent(expr.location().clone(), rettype.clone(), ty)); - vec![ir::TopLevel::Function(funname, iargs, stmts, expr)] + ir::TopLevel::Function(funname, function_args, rettype, expr) } + syntax::TopLevel::Statement(stmt) => { - convert_statement(stmt, constraint_db, renames, bindings) - .drain(..) - .map(ir::TopLevel::Statement) - .collect() + ir::TopLevel::Statement(convert_statement(stmt, constraint_db, renames, bindings)) } } } @@ -103,9 +102,9 @@ pub fn convert_top_level( fn convert_statement( statement: syntax::Statement, constraint_db: &mut Vec, - renames: &mut HashMap, ArcIntern>, - bindings: &mut HashMap, Type>, -) -> Vec { + renames: &mut ScopedMap, ArcIntern>, + bindings: &mut ScopedMap, ir::TypeOrVar>, +) -> ir::Expression { match statement { syntax::Statement::Print(loc, name) => { let iname = ArcIntern::new(name.to_string()); @@ -120,17 +119,14 @@ fn convert_statement( constraint_db.push(Constraint::Printable(loc.clone(), varty.clone())); - vec![ir::Statement::Print(loc, varty, iname)] + ir::Expression::Print(loc, final_name) } syntax::Statement::Binding(loc, name, expr) => { - let (mut prereqs, expr, ty) = - convert_expression(expr, constraint_db, renames, bindings); + let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings); let final_name = finalize_name(bindings, renames, name); - 
bindings.insert(final_name.clone(), ty.clone()); - prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr)); - prereqs + ir::Expression::Bind(loc, final_name, ty, Box::new(expr)) } } } @@ -149,16 +145,18 @@ fn convert_statement( fn convert_expression( expression: syntax::Expression, constraint_db: &mut Vec, - renames: &HashMap, ArcIntern>, - bindings: &mut HashMap, Type>, -) -> (Vec, ir::Expression, Type) { + renames: &ScopedMap, ArcIntern>, + bindings: &mut ScopedMap, ir::TypeOrVar>, +) -> (ir::Expression, ir::TypeOrVar) { match expression { + // converting values is mostly tedious, because there's so many cases + // involved syntax::Expression::Value(loc, val) => match val { syntax::Value::Number(base, mctype, value) => { let (newval, newtype) = match mctype { None => { - let newtype = ir::gentype(); - let newval = ir::Value::Unknown(base, value); + let newtype = ir::TypeOrVar::new(); + let newval = ir::Value::U64(base, value); constraint_db.push(Constraint::ConstantNumericType( loc.clone(), @@ -168,35 +166,35 @@ fn convert_expression( } Some(ConstantType::U8) => ( ir::Value::U8(base, value as u8), - ir::Type::Primitive(PrimitiveType::U8), + ir::TypeOrVar::Primitive(PrimitiveType::U8), ), Some(ConstantType::U16) => ( ir::Value::U16(base, value as u16), - ir::Type::Primitive(PrimitiveType::U16), + ir::TypeOrVar::Primitive(PrimitiveType::U16), ), Some(ConstantType::U32) => ( ir::Value::U32(base, value as u32), - ir::Type::Primitive(PrimitiveType::U32), + ir::TypeOrVar::Primitive(PrimitiveType::U32), ), Some(ConstantType::U64) => ( ir::Value::U64(base, value), - ir::Type::Primitive(PrimitiveType::U64), + ir::TypeOrVar::Primitive(PrimitiveType::U64), ), Some(ConstantType::I8) => ( ir::Value::I8(base, value as i8), - ir::Type::Primitive(PrimitiveType::I8), + ir::TypeOrVar::Primitive(PrimitiveType::I8), ), Some(ConstantType::I16) => ( ir::Value::I16(base, value as i16), - ir::Type::Primitive(PrimitiveType::I16), + ir::TypeOrVar::Primitive(PrimitiveType::I16), ), 
Some(ConstantType::I32) => ( ir::Value::I32(base, value as i32), - ir::Type::Primitive(PrimitiveType::I32), + ir::TypeOrVar::Primitive(PrimitiveType::I32), ), Some(ConstantType::I64) => ( ir::Value::I64(base, value as i64), - ir::Type::Primitive(PrimitiveType::I64), + ir::TypeOrVar::Primitive(PrimitiveType::I64), ), }; @@ -206,7 +204,6 @@ fn convert_expression( value, )); ( - vec![], ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)), newtype, ) @@ -223,35 +220,37 @@ fn convert_expression( let refexp = ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name)); - (vec![], refexp, rtype) + (refexp, rtype) } syntax::Expression::Cast(loc, target, expr) => { - let (mut stmts, nexpr, etype) = - convert_expression(*expr, constraint_db, renames, bindings); - let val_or_ref = simplify_expr(nexpr, &mut stmts); - let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast"); - let target_type = Type::Primitive(target_prim_type); + let (nexpr, etype) = convert_expression(*expr, constraint_db, renames, bindings); + let (prereqs, val_or_ref) = simplify_expr(nexpr); + let target_type: ir::TypeOrVar = PrimitiveType::from_str(&target) + .expect("valid type for cast") + .into(); let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone())); - (stmts, res, target_type) + (finalize_expression(prereqs, res), target_type) } syntax::Expression::Primitive(loc, fun, mut args) => { let primop = ir::Primitive::from_str(&fun).expect("valid primitive"); - let mut stmts = vec![]; + let mut prereqs = vec![]; let mut nargs = vec![]; let mut atypes = vec![]; - let ret_type = ir::gentype(); + let ret_type = ir::TypeOrVar::new(); for arg in args.drain(..) 
{ - let (mut astmts, aexp, atype) = - convert_expression(arg, constraint_db, renames, bindings); + let (aexp, atype) = convert_expression(arg, constraint_db, renames, bindings); + let (aprereqs, asimple) = simplify_expr(aexp); - stmts.append(&mut astmts); - nargs.push(simplify_expr(aexp, &mut stmts)); + if let Some(prereq) = aprereqs { + prereqs.push(prereq); + } + nargs.push(asimple); atypes.push(atype); } @@ -262,33 +261,56 @@ fn convert_expression( ret_type.clone(), )); - ( - stmts, - ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs), - ret_type, - ) + let last_call = ir::Expression::Primitive(loc.clone(), ret_type.clone(), primop, nargs); + + if prereqs.is_empty() { + (last_call, ret_type) + } else { + prereqs.push(last_call); + (ir::Expression::Block(loc, ret_type.clone(), prereqs), ret_type) + } } } } -fn simplify_expr(expr: ir::Expression, stmts: &mut Vec) -> ir::ValueOrRef { +fn simplify_expr( + expr: ir::Expression, +) -> ( + Option>, + ir::ValueOrRef, +) { match expr { - ir::Expression::Atomic(v_or_ref) => v_or_ref, + ir::Expression::Atomic(v_or_ref) => (None, v_or_ref), expr => { let etype = expr.type_of().clone(); let loc = expr.location().clone(); let nname = ir::gensym("g"); - let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr); + let nbinding = + ir::Expression::Bind(loc.clone(), nname.clone(), etype.clone(), Box::new(expr)); - stmts.push(nbinding); - ir::ValueOrRef::Ref(loc, etype, nname) + (Some(nbinding), ir::ValueOrRef::Ref(loc, etype, nname)) } } } +fn finalize_expression( + prereq: Option>, + actual: ir::Expression, +) -> ir::Expression { + if let Some(prereq) = prereq { + ir::Expression::Block( + prereq.location().clone(), + actual.type_of().clone(), + vec![prereq, actual], + ) + } else { + actual + } +} + fn finalize_name( - bindings: &HashMap, Type>, - renames: &mut HashMap, ArcIntern>, + bindings: &ScopedMap, ir::TypeOrVar>, + renames: &mut ScopedMap, ArcIntern>, name: syntax::Name, ) -> 
ArcIntern { if bindings.contains_key(&ArcIntern::new(name.name.clone())) { @@ -302,139 +324,139 @@ fn finalize_name( #[cfg(test)] mod tests { - use super::*; - use crate::syntax::Location; - - fn one() -> syntax::Expression { - syntax::Expression::Value( - Location::manufactured(), - syntax::Value::Number(None, None, 1), - ) - } - - fn vec_contains bool>(x: &[T], f: F) -> bool { - for x in x.iter() { - if f(x) { - return true; - } - } - false - } - - fn infer_expression( - x: syntax::Expression, - ) -> (ir::Expression, Vec, Vec, Type) { - let mut constraints = Vec::new(); - let renames = HashMap::new(); - let mut bindings = HashMap::new(); - let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings); - (expr, stmts, constraints, ty) - } - - fn infer_top_level(x: syntax::TopLevel) -> (Vec, Vec) { - let mut constraints = Vec::new(); - let mut renames = HashMap::new(); - let mut bindings = HashMap::new(); - let res = convert_top_level(x, &mut constraints, &mut renames, &mut bindings); - (res, constraints) - } - - #[test] - fn constant_one() { - let (expr, stmts, constraints, ty) = infer_expression(one()); - assert!(stmts.is_empty()); - assert!(matches!( - expr, - ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1))) - )); - assert!(vec_contains(&constraints, |x| matches!( - x, - Constraint::FitsInNumType(_, _, 1) - ))); - assert!(vec_contains( - &constraints, - |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty) - )); - } - - #[test] - fn one_plus_one() { - let opo = syntax::Expression::Primitive( - Location::manufactured(), - "+".to_string(), - vec![one(), one()], - ); - let (expr, stmts, constraints, ty) = infer_expression(opo); - assert!(stmts.is_empty()); - assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty)); - assert!(vec_contains(&constraints, |x| matches!( - x, - Constraint::FitsInNumType(_, _, 1) - ))); - assert!(vec_contains( - &constraints, - |x| 
matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty) - )); - assert!(vec_contains( - &constraints, - |x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty) - )); - } - - #[test] - fn one_plus_one_plus_one() { - let stmt = syntax::TopLevel::parse(1, "x = 1 + 1 + 1;").expect("basic parse"); - let (stmts, constraints) = infer_top_level(stmt); - assert_eq!(stmts.len(), 2); - let ir::TopLevel::Statement(ir::Statement::Binding( - _args, - name1, - temp_ty1, - ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1), - )) = stmts.get(0).expect("item two") - else { - panic!("Failed to match first statement"); - }; - let ir::TopLevel::Statement(ir::Statement::Binding( - _args, - name2, - temp_ty2, - ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2), - )) = stmts.get(1).expect("item two") - else { - panic!("Failed to match second statement"); - }; - let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] = - &primargs1[..] - else { - panic!("Failed to match first arguments"); - }; - let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] = - &primargs2[..] 
- else { - panic!("Failed to match first arguments"); - }; - assert_ne!(name1, name2); - assert_ne!(temp_ty1, temp_ty2); - assert_ne!(primty1, primty2); - assert_eq!(name1, left2name); - assert!(vec_contains( - &constraints, - |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty) - )); - assert!(vec_contains( - &constraints, - |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty) - )); - assert!(vec_contains( - &constraints, - |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty) - )); - for (i, s) in stmts.iter().enumerate() { - println!("{}: {:?}", i, s); - } - for (i, c) in constraints.iter().enumerate() { - println!("{}: {:?}", i, c); - } - } + // use super::*; + // use crate::syntax::Location; + // + // fn one() -> syntax::Expression { + // syntax::Expression::Value( + // Location::manufactured(), + // syntax::Value::Number(None, None, 1), + // ) + // } + // + // fn vec_contains bool>(x: &[T], f: F) -> bool { + // for x in x.iter() { + // if f(x) { + // return true; + // } + // } + // false + // } + // + // fn infer_expression( + // x: syntax::Expression, + // ) -> (ir::Expression, Vec, Vec, Type) { + // let mut constraints = Vec::new(); + // let renames = HashMap::new(); + // let mut bindings = HashMap::new(); + // let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings); + // (expr, stmts, constraints, ty) + // } + // + // fn infer_top_level(x: syntax::TopLevel) -> (Vec, Vec) { + // let mut constraints = Vec::new(); + // let mut renames = HashMap::new(); + // let mut bindings = HashMap::new(); + // let res = convert_top_level(x, &mut constraints, &mut renames, &mut bindings); + // (res, constraints) + // } + // + // #[test] + // fn constant_one() { + // let (expr, stmts, constraints, ty) = infer_expression(one()); + // assert!(stmts.is_empty()); + // assert!(matches!( + // expr, + // ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1))) + // )); + 
// assert!(vec_contains(&constraints, |x| matches!( + // x, + // Constraint::FitsInNumType(_, _, 1) + // ))); + // assert!(vec_contains( + // &constraints, + // |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty) + // )); + // } + // + // #[test] + // fn one_plus_one() { + // let opo = syntax::Expression::Primitive( + // Location::manufactured(), + // "+".to_string(), + // vec![one(), one()], + // ); + // let (expr, stmts, constraints, ty) = infer_expression(opo); + // assert!(stmts.is_empty()); + // assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty)); + // assert!(vec_contains(&constraints, |x| matches!( + // x, + // Constraint::FitsInNumType(_, _, 1) + // ))); + // assert!(vec_contains( + // &constraints, + // |x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty) + // )); + // assert!(vec_contains( + // &constraints, + // |x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty) + // )); + // } + // + // #[test] + // fn one_plus_one_plus_one() { + // let stmt = syntax::TopLevel::parse(1, "x = 1 + 1 + 1;").expect("basic parse"); + // let (stmts, constraints) = infer_top_level(stmt); + // assert_eq!(stmts.len(), 2); + // let ir::TopLevel::Statement(ir::Statement::Binding( + // _args, + // name1, + // temp_ty1, + // ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1), + // )) = stmts.get(0).expect("item two") + // else { + // panic!("Failed to match first statement"); + // }; + // let ir::TopLevel::Statement(ir::Statement::Binding( + // _args, + // name2, + // temp_ty2, + // ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2), + // )) = stmts.get(1).expect("item two") + // else { + // panic!("Failed to match second statement"); + // }; + // let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] = + // &primargs1[..] 
+ // else { + // panic!("Failed to match first arguments"); + // }; + // let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] = + // &primargs2[..] + // else { + // panic!("Failed to match first arguments"); + // }; + // assert_ne!(name1, name2); + // assert_ne!(temp_ty1, temp_ty2); + // assert_ne!(primty1, primty2); + // assert_eq!(name1, left2name); + // assert!(vec_contains( + // &constraints, + // |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty) + // )); + // assert!(vec_contains( + // &constraints, + // |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty) + // )); + // assert!(vec_contains( + // &constraints, + // |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty) + // )); + // for (i, s) in stmts.iter().enumerate() { + // println!("{}: {:?}", i, s); + // } + // for (i, c) in constraints.iter().enumerate() { + // println!("{}: {:?}", i, c); + // } + // } } diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 763574c..577f4d8 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -1,11 +1,12 @@ -use super::{ast as input, solve::TypeResolutions}; -use crate::{eval::PrimitiveType, ir as output}; +use super::solve::TypeResolutions; +use crate::eval::PrimitiveType; +use crate::ir::{Expression, Program, TopLevel, Type, TypeOrVar, Value, ValueOrRef}; pub fn finalize_program( - mut program: input::Program, + mut program: Program, resolutions: &TypeResolutions, -) -> output::Program { - output::Program { +) -> Program { + Program { items: program .items .drain(..) @@ -14,53 +15,36 @@ pub fn finalize_program( } } -fn finalize_top_level(item: input::TopLevel, resolutions: &TypeResolutions) -> output::TopLevel { +fn finalize_top_level(item: TopLevel, resolutions: &TypeResolutions) -> TopLevel { match item { - input::TopLevel::Function(name, args, mut body, expr) => output::TopLevel::Function( - name, - args, - body.drain(..) 
- .map(|x| finalize_statement(x, resolutions)) - .collect(), - finalize_expression(expr, resolutions), - ), - input::TopLevel::Statement(stmt) => { - output::TopLevel::Statement(finalize_statement(stmt, resolutions)) - } - } -} - -fn finalize_statement( - statement: input::Statement, - resolutions: &TypeResolutions, -) -> output::Statement { - match statement { - input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding( - loc, - var, - finalize_type(ty, resolutions), - finalize_expression(expr, resolutions), - ), - input::Statement::Print(loc, ty, var) => { - output::Statement::Print(loc, finalize_type(ty, resolutions), var) + TopLevel::Function(name, args, rettype, expr) => { + TopLevel::Function( + name, + args.into_iter().map(|(name, t)| (name, finalize_type(t, resolutions))).collect(), + finalize_type(rettype, resolutions), + finalize_expression(expr, resolutions) + ) } + TopLevel::Statement(expr) => TopLevel::Statement(finalize_expression(expr, resolutions)), } } fn finalize_expression( - expression: input::Expression, + expression: Expression, resolutions: &TypeResolutions, -) -> output::Expression { +) -> Expression { match expression { - input::Expression::Atomic(val_or_ref) => { - output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions)) + Expression::Atomic(val_or_ref) => { + Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions)) } - input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast( + + Expression::Cast(loc, target, val_or_ref) => Expression::Cast( loc, finalize_type(target, resolutions), finalize_val_or_ref(val_or_ref, resolutions), ), - input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive( + + Expression::Primitive(loc, ty, prim, mut args) => Expression::Primitive( loc, finalize_type(ty, resolutions), prim, @@ -68,17 +52,42 @@ fn finalize_expression( .map(|x| finalize_val_or_ref(x, resolutions)) .collect(), ), + + Expression::Block(loc, ty, mut 
exprs) => { + let mut final_exprs = Vec::with_capacity(exprs.len()); + + for expr in exprs { + let newexpr = finalize_expression(expr, resolutions); + + if let Expression::Block(_, _, mut subexprs) = newexpr { + final_exprs.append(&mut subexprs); + } else { + final_exprs.push(newexpr); + } + } + + Expression::Block(loc, finalize_type(ty, resolutions), final_exprs) + } + + Expression::Print(loc, var) => Expression::Print(loc, var), + + Expression::Bind(loc, var, ty, subexp) => Expression::Bind( + loc, + var, + finalize_type(ty, resolutions), + Box::new(finalize_expression(*subexp, resolutions)), + ), } } -fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type { +fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type { match ty { - input::Type::Primitive(x) => output::Type::Primitive(x), - input::Type::Variable(_, tvar) => match resolutions.get(&tvar) { + TypeOrVar::Primitive(x) => Type::Primitive(x), + TypeOrVar::Variable(_, tvar) => match resolutions.get(&tvar) { None => panic!("Did not resolve type for type variable {}", tvar), - Some(pt) => output::Type::Primitive(*pt), + Some(pt) => Type::Primitive(*pt), }, - input::Type::Function(mut args, ret) => output::Type::Function( + TypeOrVar::Function(mut args, ret) => Type::Function( args.drain(..) 
.map(|x| finalize_type(x, resolutions)) .collect(), @@ -88,123 +97,82 @@ fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type } fn finalize_val_or_ref( - valref: input::ValueOrRef, + valref: ValueOrRef, resolutions: &TypeResolutions, -) -> output::ValueOrRef { +) -> ValueOrRef { match valref { - input::ValueOrRef::Ref(loc, ty, var) => { - output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var) - } - input::ValueOrRef::Value(loc, ty, val) => { + ValueOrRef::Ref(loc, ty, var) => ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var), + ValueOrRef::Value(loc, ty, val) => { let new_type = finalize_type(ty, resolutions); match val { - input::Value::Unknown(base, value) => match new_type { - output::Type::Function(_, _) => { + // U64 is essentially "unknown" for us, so we use the inferred type + Value::U64(base, value) => match new_type { + Type::Function(_, _) => { panic!("Somehow inferred that a constant was a function") } - output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::U8(base, value as u8), - ), - output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::U16(base, value as u16), - ), - output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::U32(base, value as u32), - ), - output::Type::Primitive(PrimitiveType::U64) => { - output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value)) + Type::Primitive(PrimitiveType::Void) => { + panic!("Somehow inferred that a constant was void") + } + Type::Primitive(PrimitiveType::U8) => { + ValueOrRef::Value(loc, new_type, Value::U8(base, value as u8)) + } + Type::Primitive(PrimitiveType::U16) => { + ValueOrRef::Value(loc, new_type, Value::U16(base, value as u16)) + } + Type::Primitive(PrimitiveType::U32) => { + ValueOrRef::Value(loc, new_type, Value::U32(base, value as u32)) + } + 
Type::Primitive(PrimitiveType::U64) => { + ValueOrRef::Value(loc, new_type, Value::U64(base, value)) + } + Type::Primitive(PrimitiveType::I8) => { + ValueOrRef::Value(loc, new_type, Value::I8(base, value as i8)) + } + Type::Primitive(PrimitiveType::I16) => { + ValueOrRef::Value(loc, new_type, Value::I16(base, value as i16)) + } + Type::Primitive(PrimitiveType::I32) => { + ValueOrRef::Value(loc, new_type, Value::I32(base, value as i32)) + } + Type::Primitive(PrimitiveType::I64) => { + ValueOrRef::Value(loc, new_type, Value::I64(base, value as i64)) } - output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::I8(base, value as i8), - ), - output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::I16(base, value as i16), - ), - output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::I32(base, value as i32), - ), - output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::I64(base, value as i64), - ), }, - input::Value::U8(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::U8) - )); - output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value)) + Value::U8(base, value) => { + assert!(matches!(new_type, Type::Primitive(PrimitiveType::U8))); + ValueOrRef::Value(loc, new_type, Value::U8(base, value)) } - input::Value::U16(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::U16) - )); - output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value)) + Value::U16(base, value) => { + assert!(matches!(new_type, Type::Primitive(PrimitiveType::U16))); + ValueOrRef::Value(loc, new_type, Value::U16(base, value)) } - input::Value::U32(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::U32) - )); - output::ValueOrRef::Value(loc, 
new_type, output::Value::U32(base, value)) + Value::U32(base, value) => { + assert!(matches!(new_type, Type::Primitive(PrimitiveType::U32))); + ValueOrRef::Value(loc, new_type, Value::U32(base, value)) } - input::Value::U64(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::U64) - )); - output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value)) + Value::I8(base, value) => { + assert!(matches!(new_type, Type::Primitive(PrimitiveType::I8))); + ValueOrRef::Value(loc, new_type, Value::I8(base, value)) } - input::Value::I8(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::I8) - )); - output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value)) + Value::I16(base, value) => { + assert!(matches!(new_type, Type::Primitive(PrimitiveType::I16))); + ValueOrRef::Value(loc, new_type, Value::I16(base, value)) } - input::Value::I16(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::I16) - )); - output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value)) + Value::I32(base, value) => { + assert!(matches!(new_type, Type::Primitive(PrimitiveType::I32))); + ValueOrRef::Value(loc, new_type, Value::I32(base, value)) } - input::Value::I32(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::I32) - )); - output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value)) - } - - input::Value::I64(base, value) => { - assert!(matches!( - new_type, - output::Type::Primitive(PrimitiveType::I64) - )); - output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value)) + Value::I64(base, value) => { + assert!(matches!(new_type, Type::Primitive(PrimitiveType::I64))); + ValueOrRef::Value(loc, new_type, Value::I64(base, value)) } } } diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index 76d072b..f3ba8c8 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ 
-1,6 +1,6 @@ -use super::ast as ir; -use super::ast::Type; -use crate::{eval::PrimitiveType, syntax::Location}; +use crate::eval::PrimitiveType; +use crate::ir::{Primitive, TypeOrVar}; +use crate::syntax::Location; use codespan_reporting::diagnostic::Diagnostic; use internment::ArcIntern; use std::{collections::HashMap, fmt}; @@ -9,22 +9,22 @@ use std::{collections::HashMap, fmt}; #[derive(Debug)] pub enum Constraint { /// The given type must be printable using the `print` built-in - Printable(Location, Type), + Printable(Location, TypeOrVar), /// The provided numeric value fits in the given constant type - FitsInNumType(Location, Type, u64), + FitsInNumType(Location, TypeOrVar, u64), /// The given primitive has the proper arguments types associated with it - ProperPrimitiveArgs(Location, ir::Primitive, Vec, Type), + ProperPrimitiveArgs(Location, Primitive, Vec, TypeOrVar), /// The given type can be casted to the target type safely - CanCastTo(Location, Type, Type), + CanCastTo(Location, TypeOrVar, TypeOrVar), /// The given type must be some numeric type, but this is not a constant /// value, so don't try to default it if we can't figure it out - NumericType(Location, Type), + NumericType(Location, TypeOrVar), /// The given type is attached to a constant and must be some numeric type. /// If we can't figure it out, we should warn the user and then just use a /// default. - ConstantNumericType(Location, Type), + ConstantNumericType(Location, TypeOrVar), /// The two types should be equivalent - Equivalent(Location, Type, Type), + Equivalent(Location, TypeOrVar, TypeOrVar), } impl fmt::Display for Constraint { @@ -101,20 +101,22 @@ impl TypeInferenceResult { pub enum TypeInferenceError { /// The user provide a constant that is too large for its inferred type. ConstantTooLarge(Location, PrimitiveType, u64), + /// Somehow we're trying to use a non-number as a number + NotANumber(Location, PrimitiveType), /// The two types needed to be equivalent, but weren't. 
- NotEquivalent(Location, Type, Type), + NotEquivalent(Location, TypeOrVar, TypeOrVar), /// We cannot safely cast the first type to the second type. CannotSafelyCast(Location, PrimitiveType, PrimitiveType), /// The primitive invocation provided the wrong number of arguments. - WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize), + WrongPrimitiveArity(Location, Primitive, usize, usize, usize), /// We cannot cast between function types at the moment. - CannotCastBetweenFunctinoTypes(Location, Type, Type), + CannotCastBetweenFunctinoTypes(Location, TypeOrVar, TypeOrVar), /// We cannot cast from a function type to something else. - CannotCastFromFunctionType(Location, Type), + CannotCastFromFunctionType(Location, TypeOrVar), /// We cannot cast to a function type from something else. - CannotCastToFunctionType(Location, Type), + CannotCastToFunctionType(Location, TypeOrVar), /// We cannot turn a number into a function. - CannotMakeNumberAFunction(Location, Type, Option), + CannotMakeNumberAFunction(Location, TypeOrVar, Option), /// We had a constraint we just couldn't solve. CouldNotSolve(Constraint), } @@ -127,9 +129,15 @@ impl From for Diagnostic { .with_message(format!( "Type {} has a max value of {}, which is smaller than {}", primty, - primty.max_value(), + primty.max_value().expect("constant type has max value"), value )), + TypeInferenceError::NotANumber(loc, primty) => loc + .labelled_error("not a numeric type") + .with_message(format!( + "For some reason, we're trying to use {} as a numeric type", + primty, + )), TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc .labelled_error("type inference error") .with_message(format!("Expected type {}, received type {}", ty1, ty2)), @@ -214,7 +222,7 @@ impl From for Diagnostic { /// These are fine, probably, but could indicate some behavior the user might not /// expect, and so they might want to do something about them. 
pub enum TypeInferenceWarning { - DefaultedTo(Location, Type), + DefaultedTo(Location, TypeOrVar), } impl From for Diagnostic { @@ -270,7 +278,11 @@ pub fn solve_constraints( // Case #1a: We have two primitive types. If they're equal, we've discharged this // constraint! We can just continue. If they're not equal, add an error and then // see what else we come up with. - Constraint::Equivalent(loc, a @ Type::Primitive(_), b @ Type::Primitive(_)) => { + Constraint::Equivalent( + loc, + a @ TypeOrVar::Primitive(_), + b @ TypeOrVar::Primitive(_), + ) => { if a != b { errors.push(TypeInferenceError::NotEquivalent(loc, a, b)); } @@ -281,8 +293,16 @@ pub fn solve_constraints( // In this case, we'll check to see if we've resolved the variable, and check for // equivalence if we have. If we haven't, we'll set that variable to be primitive // type. - Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name)) - | Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => { + Constraint::Equivalent( + loc, + TypeOrVar::Primitive(t), + TypeOrVar::Variable(_, name), + ) + | Constraint::Equivalent( + loc, + TypeOrVar::Variable(_, name), + TypeOrVar::Primitive(t), + ) => { match resolutions.get(&name) { None => { resolutions.insert(name, t); @@ -290,8 +310,8 @@ pub fn solve_constraints( Some(t2) if &t == t2 => {} Some(t2) => errors.push(TypeInferenceError::NotEquivalent( loc, - Type::Primitive(t), - Type::Primitive(*t2), + TypeOrVar::Primitive(t), + TypeOrVar::Primitive(*t2), )), } changed_something = true; @@ -301,8 +321,8 @@ pub fn solve_constraints( // check, but now on their resolutions. 
Constraint::Equivalent( ref loc, - Type::Variable(_, ref name1), - Type::Variable(_, ref name2), + TypeOrVar::Variable(_, ref name1), + TypeOrVar::Variable(_, ref name2), ) => match (resolutions.get(name1), resolutions.get(name2)) { (None, None) => { constraint_db.push(constraint); @@ -321,8 +341,8 @@ pub fn solve_constraints( (Some(pt1), Some(pt2)) => { errors.push(TypeInferenceError::NotEquivalent( loc.clone(), - Type::Primitive(*pt1), - Type::Primitive(*pt2), + TypeOrVar::Primitive(*pt1), + TypeOrVar::Primitive(*pt2), )); changed_something = true; } @@ -339,8 +359,8 @@ pub fn solve_constraints( // function types. Constraint::Equivalent( loc, - ref a @ Type::Function(ref args1, ref ret1), - ref b @ Type::Function(ref args2, ref ret2), + ref a @ TypeOrVar::Function(ref args1, ref ret1), + ref b @ TypeOrVar::Function(ref args2, ref ret2), ) => { if args1.len() != args2.len() { errors.push(TypeInferenceError::NotEquivalent( @@ -377,27 +397,33 @@ pub fn solve_constraints( // Make sure that the provided number fits within the provided constant type. For the // moment, we're going to call an error here a failure, although this could be a // warning in the future. 
- Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => { - if ctype.max_value() < val { - errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val)); - } + Constraint::FitsInNumType(loc, TypeOrVar::Primitive(ctype), val) => { + match ctype.max_value() { + None => { + errors.push(TypeInferenceError::NotANumber(loc, ctype)); + } + Some(max_value) if max_value < val => { + errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val)); + } + Some(_) => {} + }; changed_something = true; } // If we have a non-constant type, then let's see if we can advance this to a constant // type - Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => { + Constraint::FitsInNumType(loc, TypeOrVar::Variable(vloc, var), val) => { match resolutions.get(&var) { None => constraint_db.push(Constraint::FitsInNumType( loc, - Type::Variable(vloc, var), + TypeOrVar::Variable(vloc, var), val, )), Some(nt) => { constraint_db.push(Constraint::FitsInNumType( loc, - Type::Primitive(*nt), + TypeOrVar::Primitive(*nt), val, )); changed_something = true; @@ -406,7 +432,7 @@ pub fn solve_constraints( } // Function types definitely do not fit in numeric types - Constraint::FitsInNumType(loc, t @ Type::Function(_, _), val) => { + Constraint::FitsInNumType(loc, t @ TypeOrVar::Function(_, _), val) => { errors.push(TypeInferenceError::CannotMakeNumberAFunction( loc, t.clone(), @@ -416,17 +442,17 @@ pub fn solve_constraints( // If the left type in a "can cast to" check is a variable, let's see if we can advance // it into something more tangible - Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => { + Constraint::CanCastTo(loc, TypeOrVar::Variable(vloc, var), to_type) => { match resolutions.get(&var) { None => constraint_db.push(Constraint::CanCastTo( loc, - Type::Variable(vloc, var), + TypeOrVar::Variable(vloc, var), to_type, )), Some(nt) => { constraint_db.push(Constraint::CanCastTo( loc, - Type::Primitive(*nt), + TypeOrVar::Primitive(*nt), to_type, )); 
changed_something = true; @@ -435,18 +461,18 @@ pub fn solve_constraints( } // If the right type in a "can cast to" check is a variable, same deal - Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => { + Constraint::CanCastTo(loc, from_type, TypeOrVar::Variable(vloc, var)) => { match resolutions.get(&var) { None => constraint_db.push(Constraint::CanCastTo( loc, from_type, - Type::Variable(vloc, var), + TypeOrVar::Variable(vloc, var), )), Some(nt) => { constraint_db.push(Constraint::CanCastTo( loc, from_type, - Type::Primitive(*nt), + TypeOrVar::Primitive(*nt), )); changed_something = true; } @@ -456,8 +482,8 @@ pub fn solve_constraints( // If both of them are types, then we can actually do the test. yay! Constraint::CanCastTo( loc, - Type::Primitive(from_type), - Type::Primitive(to_type), + TypeOrVar::Primitive(from_type), + TypeOrVar::Primitive(to_type), ) => { if !from_type.can_cast_to(&to_type) { errors.push(TypeInferenceError::CannotSafelyCast( @@ -471,8 +497,8 @@ pub fn solve_constraints( // are equivalent. 
Constraint::CanCastTo( loc, - t1 @ Type::Function(_, _), - t2 @ Type::Function(_, _), + t1 @ TypeOrVar::Function(_, _), + t2 @ TypeOrVar::Function(_, _), ) => { if t1 != t2 { errors.push(TypeInferenceError::CannotCastBetweenFunctinoTypes( @@ -484,7 +510,11 @@ pub fn solve_constraints( changed_something = true; } - Constraint::CanCastTo(loc, t @ Type::Function(_, _), Type::Primitive(_)) => { + Constraint::CanCastTo( + loc, + t @ TypeOrVar::Function(_, _), + TypeOrVar::Primitive(_), + ) => { errors.push(TypeInferenceError::CannotCastFromFunctionType( loc, t.clone(), @@ -492,19 +522,24 @@ pub fn solve_constraints( changed_something = true; } - Constraint::CanCastTo(loc, Type::Primitive(_), t @ Type::Function(_, _)) => { + Constraint::CanCastTo( + loc, + TypeOrVar::Primitive(_), + t @ TypeOrVar::Function(_, _), + ) => { errors.push(TypeInferenceError::CannotCastToFunctionType(loc, t.clone())); changed_something = true; } // As per usual, if we're trying to test if a type variable is numeric, first // we try to advance it to a primitive - Constraint::NumericType(loc, Type::Variable(vloc, var)) => { + Constraint::NumericType(loc, TypeOrVar::Variable(vloc, var)) => { match resolutions.get(&var) { None => constraint_db - .push(Constraint::NumericType(loc, Type::Variable(vloc, var))), + .push(Constraint::NumericType(loc, TypeOrVar::Variable(vloc, var))), Some(nt) => { - constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt))); + constraint_db + .push(Constraint::NumericType(loc, TypeOrVar::Primitive(*nt))); changed_something = true; } } @@ -512,12 +547,12 @@ pub fn solve_constraints( // Of course, if we get to a primitive type, then it's true, because all of our // primitive types are numbers - Constraint::NumericType(_, Type::Primitive(_)) => { + Constraint::NumericType(_, TypeOrVar::Primitive(_)) => { changed_something = true; } // But functions are definitely not numbers - Constraint::NumericType(loc, t @ Type::Function(_, _)) => { + 
Constraint::NumericType(loc, t @ TypeOrVar::Function(_, _)) => { errors.push(TypeInferenceError::CannotMakeNumberAFunction( loc, t.clone(), @@ -528,15 +563,17 @@ pub fn solve_constraints( // As per usual, if we're trying to test if a type variable is numeric, first // we try to advance it to a primitive - Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => { + Constraint::ConstantNumericType(loc, TypeOrVar::Variable(vloc, var)) => { match resolutions.get(&var) { None => constraint_db.push(Constraint::ConstantNumericType( loc, - Type::Variable(vloc, var), + TypeOrVar::Variable(vloc, var), )), Some(nt) => { - constraint_db - .push(Constraint::ConstantNumericType(loc, Type::Primitive(*nt))); + constraint_db.push(Constraint::ConstantNumericType( + loc, + TypeOrVar::Primitive(*nt), + )); changed_something = true; } } @@ -544,12 +581,12 @@ pub fn solve_constraints( // Of course, if we get to a primitive type, then it's true, because all of our // primitive types are numbers - Constraint::ConstantNumericType(_, Type::Primitive(_)) => { + Constraint::ConstantNumericType(_, TypeOrVar::Primitive(_)) => { changed_something = true; } // But functions are definitely not numbers - Constraint::ConstantNumericType(loc, t @ Type::Function(_, _)) => { + Constraint::ConstantNumericType(loc, t @ TypeOrVar::Function(_, _)) => { errors.push(TypeInferenceError::CannotMakeNumberAFunction( loc, t.clone(), @@ -565,9 +602,7 @@ pub fn solve_constraints( // find by discovering that the number of arguments provided doesn't make sense // given the primitive being used. 
Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim { - ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide - if args.len() != 2 => - { + Primitive::Plus | Primitive::Times | Primitive::Divide if args.len() != 2 => { errors.push(TypeInferenceError::WrongPrimitiveArity( loc, prim, @@ -578,7 +613,7 @@ pub fn solve_constraints( changed_something = true; } - ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => { + Primitive::Plus | Primitive::Times | Primitive::Divide => { let right = args.pop().expect("2 > 0"); let left = args.pop().expect("2 > 1"); @@ -596,7 +631,7 @@ pub fn solve_constraints( changed_something = true; } - ir::Primitive::Minus if args.is_empty() || args.len() > 2 => { + Primitive::Minus if args.is_empty() || args.len() > 2 => { errors.push(TypeInferenceError::WrongPrimitiveArity( loc, prim, @@ -607,7 +642,7 @@ pub fn solve_constraints( changed_something = true; } - ir::Primitive::Minus if args.len() == 1 => { + Primitive::Minus if args.len() == 1 => { let arg = args.pop().expect("1 > 0"); constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone())); constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); @@ -615,7 +650,7 @@ pub fn solve_constraints( changed_something = true; } - ir::Primitive::Minus => { + Primitive::Minus => { let right = args.pop().expect("2 > 0"); let left = args.pop().expect("2 > 1"); @@ -648,12 +683,12 @@ pub fn solve_constraints( for constraint in local_constraints.drain(..) 
{ match constraint { - Constraint::ConstantNumericType(loc, t @ Type::Variable(_, _)) => { - let resty = Type::Primitive(PrimitiveType::U64); + Constraint::ConstantNumericType(loc, t @ TypeOrVar::Variable(_, _)) => { + let resty = TypeOrVar::Primitive(PrimitiveType::U64); constraint_db.push(Constraint::Equivalent( loc.clone(), t, - Type::Primitive(PrimitiveType::U64), + TypeOrVar::Primitive(PrimitiveType::U64), )); warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty)); changed_something = true; diff --git a/src/util/pretty.rs b/src/util/pretty.rs index e5c478a..2cfa55e 100644 --- a/src/util/pretty.rs +++ b/src/util/pretty.rs @@ -8,7 +8,9 @@ pub struct PrettySymbol { impl<'a> From<&'a ArcIntern> for PrettySymbol { fn from(value: &'a ArcIntern) -> Self { - PrettySymbol { name: value.clone() } + PrettySymbol { + name: value.clone(), + } } }