diff --git a/Cargo.toml b/Cargo.toml index 296f284..4e6b807 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,15 +9,15 @@ name = "ngr" path = "src/lib.rs" [dependencies] -clap = { version = "4.4.11", features = ["derive"] } +clap = { version = "4.4.18", features = ["derive"] } codespan = "0.11.1" codespan-reporting = "0.11.1" -cranelift-codegen = "0.103.0" -cranelift-jit = "0.103.0" -cranelift-frontend = "0.103.0" -cranelift-module = "0.103.0" -cranelift-native = "0.103.0" -cranelift-object = "0.103.0" +cranelift-codegen = "0.104.0" +cranelift-jit = "0.104.0" +cranelift-frontend = "0.104.0" +cranelift-module = "0.104.0" +cranelift-native = "0.104.0" +cranelift-object = "0.104.0" internment = { version = "0.7.4", default-features = false, features = ["arc"] } lalrpop-util = "0.20.0" lazy_static = "1.4.0" @@ -26,10 +26,10 @@ pretty = { version = "0.12.3", features = ["termcolor"] } proptest = "1.4.0" rand = "0.8.5" rustyline = "13.0.0" -target-lexicon = "0.12.12" -tempfile = "3.8.1" -thiserror = "1.0.52" -anyhow = "1.0.77" +target-lexicon = "0.12.13" +tempfile = "3.9.0" +thiserror = "1.0.56" +anyhow = "1.0.79" [build-dependencies] lalrpop = "0.20.0" diff --git a/examples/basic/generated0001.ngr b/examples/basic/generated0001.ngr new file mode 100644 index 0000000..ae46375 --- /dev/null +++ b/examples/basic/generated0001.ngr @@ -0,0 +1,3 @@ +x = 4u64; +function f(y) (x + y) +print x; \ No newline at end of file diff --git a/examples/basic/generated0002.ngr b/examples/basic/generated0002.ngr new file mode 100644 index 0000000..cc51a9e --- /dev/null +++ b/examples/basic/generated0002.ngr @@ -0,0 +1,7 @@ +b = -7662558304906888395i64; +z = 1030390794u32; +v = z; +q = z; +s = -2115098981i32; +t = s; +print t; \ No newline at end of file diff --git a/examples/basic/generated0003.ngr b/examples/basic/generated0003.ngr new file mode 100644 index 0000000..616bba6 --- /dev/null +++ b/examples/basic/generated0003.ngr @@ -0,0 +1,4 @@ +n = (49u8 + 155u8); +q = n; +function u (b) n + b +v = n; \ No newline at end of file diff --git a/examples/basic/test1.ngr b/examples/basic/test1.ngr index b8d66e1..decdcbd 100644 --- a/examples/basic/test1.ngr +++ b/examples/basic/test1.ngr @@ -1,4 +1,4 @@ x = 5; y = 4*x + 3; print x; -print y; +print y; \ No newline at end of file diff --git a/runtime/rts.c b/runtime/rts.c index a999bae..d8eba53 100644 --- a/runtime/rts.c +++ b/runtime/rts.c @@ -28,6 +28,11 @@ void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) { case /* I64 = */ 23: printf("%s = %" PRIi64 "i64\n", variable_name, value); break; + case /* void = */ 255: + printf("%s = \n", variable_name); + break; + default: + printf("%s = UNKNOWN VTYPE %d\n", variable_name, vtype); } } diff --git a/src/backend.rs b/src/backend.rs index c6cc556..b2fee29 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -60,6 +60,7 @@ pub struct Backend { data_ctx: DataDescription, runtime_functions: RuntimeFunctions, defined_strings: HashMap, + defined_functions: HashMap, FuncId>, defined_symbols: HashMap, (DataId, ConstantType)>, output_buffer: Option, platform: Triple, @@ -92,6 +93,7 @@ impl Backend { data_ctx: DataDescription::new(), runtime_functions, defined_strings: HashMap::new(), + defined_functions: HashMap::new(), defined_symbols: HashMap::new(), output_buffer, platform: Triple::host(), @@ -132,6 +134,7 @@ impl Backend { data_ctx: DataDescription::new(), runtime_functions, defined_strings: HashMap::new(), + defined_functions: HashMap::new(), defined_symbols: HashMap::new(), output_buffer: None, platform, diff --git a/src/backend/error.rs b/src/backend/error.rs index ad32cff..e5e2a1e 100644 --- a/src/backend/error.rs +++ b/src/backend/error.rs @@ -55,6 +55,7 @@ impl From for Diagnostic { match value { BackendError::Cranelift(me) => { Diagnostic::error().with_message(format!("Internal cranelift error: {}", me)) + .with_notes(vec![format!("{:?}", me)]) } BackendError::BuiltinError(me) => { Diagnostic::error().with_message(format!("Internal runtime function error: {}", me)) diff --git a/src/backend/eval.rs b/src/backend/eval.rs index f169e6f..3c639b1 100644 --- a/src/backend/eval.rs +++ b/src/backend/eval.rs @@ -26,34 +26,7 @@ impl Backend { /// of the built-in test systems.) pub fn eval(program: Program) -> Result>> { let mut jitter = Backend::jit(Some(String::new()))?; - let mut function_map = HashMap::new(); - let mut main_function_body = vec![]; - - for item in program.items { - match item { - TopLevel::Function(name, args, rettype, body) => { - let function_id = - jitter.compile_function(name.as_str(), args.as_slice(), rettype, body)?; - function_map.insert(name, function_id); - } - - TopLevel::Statement(stmt) => { - main_function_body.push(stmt); - } - } - } - - let main_function_body = Expression::Block( - Location::manufactured(), - Type::Primitive(crate::eval::PrimitiveType::Void), - main_function_body, - ); - let function_id = jitter.compile_function( - "___test_jit_eval___", - &[], - Type::Primitive(crate::eval::PrimitiveType::Void), - main_function_body, - )?; + let function_id = jitter.compile_program("___test_jit_eval___", program)?; jitter.module.finalize_definitions()?; let compiled_bytes = jitter.bytes(function_id); let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; @@ -89,8 +62,13 @@ impl Backend { for item in program.items { match item { TopLevel::Function(name, args, rettype, body) => { - let function_id = - backend.compile_function(name.as_str(), args.as_slice(), rettype, body)?; + let function_id = backend.compile_function( + &mut HashMap::new(), + name.as_str(), + args.as_slice(), + rettype, + body, + )?; function_map.insert(name, function_id); } @@ -111,6 +89,7 @@ impl Backend { let executable_path = my_directory.path().join("test_executable"); backend.compile_function( + &mut HashMap::new(), "gogogo", &[], Type::Primitive(crate::eval::PrimitiveType::Void), @@ -154,6 +133,7 @@ impl Backend { .join("runtime") .join("rts.c"), ) + .arg("-Wl,-ld_classic") .arg(object_file) .arg("-o") .arg(executable_path) @@ -206,16 +186,15 @@ proptest::proptest! { #[test] fn jit_backend(program in Program::arbitrary()) { use crate::eval::PrimOpError; - use pretty::{DocAllocator, Pretty}; - let allocator = pretty::BoxAllocator; - allocator - .text("---------------") - .append(allocator.hardline()) - .append(program.pretty(&allocator)) - .1 - .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) - .expect("rendering works"); - + // use pretty::{DocAllocator, Pretty}; + // let allocator = pretty::BoxAllocator; + // allocator + // .text("---------------") + // .append(allocator.hardline()) + // .append(program.pretty(&allocator)) + // .1 + // .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) + // .expect("rendering works"); let basic_result = program.eval().map(|(_,x)| x); diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 8db56bc..81a9142 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -1,24 +1,23 @@ +use crate::backend::error::BackendError; +use crate::backend::Backend; use crate::eval::PrimitiveType; use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; use crate::syntax::{ConstantType, Location}; -use crate::util::scoped_map::ScopedMap; use cranelift_codegen::ir::{ - self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, Signature, UserFuncName, + self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName }; use cranelift_codegen::isa::CallConv; use cranelift_codegen::Context; use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext}; -use cranelift_module::{FuncId, Linkage, Module}; +use cranelift_module::{DataDescription, FuncId, Linkage, Module}; use internment::ArcIntern; - -use crate::backend::error::BackendError; -use crate::backend::Backend; +use std::collections::{hash_map, HashMap}; /// When we're talking about variables, it's handy to just have a table that points /// from a variable to "what to do if you want to reference this variable", which is /// agnostic about whether the variable is local, global, an argument, etc. Since /// the type of that function is a little bit annoying, we summarize it here. -enum ReferenceBuilder { +pub enum ReferenceBuilder { Global(ConstantType, GlobalValue), Local(ConstantType, cranelift_frontend::Variable), Argument(ConstantType, entities::Value), @@ -29,7 +28,8 @@ impl ReferenceBuilder { match self { ReferenceBuilder::Global(ty, gv) => { let cranelift_type = ir::Type::from(*ty); - let value = builder.ins().symbol_value(cranelift_type, *gv); + let ptr_value = builder.ins().symbol_value(types::I64, *gv); + let value = builder.ins().load(cranelift_type, MemFlags::new(), ptr_value, 0); (value, *ty) } @@ -81,11 +81,44 @@ impl Backend { program: Program, ) -> Result { let mut generated_body = vec![]; + let mut variables = HashMap::new(); + + for (top_level_name, top_level_type) in program.get_top_level_variables() { + match top_level_type { + Type::Function(argument_types, return_type) => { + let func_id = self.declare_function( + top_level_name.as_str(), + Linkage::Export, + argument_types, + *return_type, + )?; + self.defined_functions.insert(top_level_name, func_id); + } + + Type::Primitive(pt) => { + let data_id = self.module.declare_data( + top_level_name.as_str(), + Linkage::Export, + true, + false, + )?; + self.module.define_data(data_id, &pt.blank_data())?; + self.defined_symbols + .insert(top_level_name, (data_id, pt.into())); + } + } + } + + let void = Type::Primitive(PrimitiveType::Void); + let main_func_id = + self.declare_function(function_name, Linkage::Export, vec![], void.clone())?; + self.defined_functions + .insert(ArcIntern::new(function_name.to_string()), main_func_id); for item in program.items { match item { TopLevel::Function(name, args, rettype, body) => { - self.compile_function(name.as_str(), &args, rettype, body)?; + self.compile_function(&mut variables, name.as_str(), &args, rettype, body)?; } TopLevel::Statement(stmt) => { @@ -94,8 +127,8 @@ impl Backend { } } - let void = Type::Primitive(PrimitiveType::Void); self.compile_function( + &mut variables, function_name, &[], void.clone(), @@ -103,14 +136,47 @@ impl Backend { ) } + fn declare_function( + &mut self, + name: &str, + linkage: Linkage, + argument_types: Vec, + return_type: Type, + ) -> Result { + let basic_signature = Signature { + params: argument_types + .iter() + .map(|t| self.translate_type(t)) + .collect(), + returns: if return_type == Type::Primitive(PrimitiveType::Void) { + vec![] + } else { + vec![self.translate_type(&return_type)] + }, + call_conv: CallConv::triple_default(&self.platform), + }; + + // this generates the handle for the function that we'll eventually want to + // return to the user. For now, we declare all functions defined by this + // function as public/global/exported, although we may want to reconsider + // this decision later. + self.module + .declare_function(name, linkage, &basic_signature) + } + /// Compile the given function. pub fn compile_function( &mut self, + variables: &mut HashMap, function_name: &str, arguments: &[(Variable, Type)], return_type: Type, body: Expression, ) -> Result { + // reset the next variable counter. this value shouldn't matter; hopefully + // we won't be using close to 2^32 variables! + self.reset_local_variable_tracker(); + let basic_signature = Signature { params: arguments .iter() @@ -124,17 +190,23 @@ impl Backend { call_conv: CallConv::triple_default(&self.platform), }; - // reset the next variable counter. this value shouldn't matter; hopefully - // we won't be using close to 2^32 variables! - self.reset_local_variable_tracker(); - // this generates the handle for the function that we'll eventually want to // return to the user. For now, we declare all functions defined by this // function as public/global/exported, although we may want to reconsider // this decision later. - let func_id = - self.module - .declare_function(function_name, Linkage::Export, &basic_signature)?; + let interned_name = ArcIntern::new(function_name.to_string()); + let func_id = match self.defined_functions.entry(interned_name) { + hash_map::Entry::Occupied(entry) => *entry.get(), + hash_map::Entry::Vacant(vac) => { + let func_id = self.module.declare_function( + function_name, + Linkage::Export, + &basic_signature, + )?; + vac.insert(func_id); + func_id + } + }; // Next we have to generate the compilation context for the rest of this // function. Currently, we generate a fresh context for every function. @@ -145,14 +217,6 @@ impl Backend { let user_func_name = UserFuncName::user(0, func_id.as_u32()); ctx.func = Function::with_name_signature(user_func_name, basic_signature); - // Let's start creating the variable table we'll use when we're dereferencing - // them later. This table is a little interesting because instead of pointing - // from data to data, we're going to point from data (the variable) to an - // action to take if we encounter that variable at some later point. This - // makes it nice and easy to have many different ways to access data, such - // as globals, function arguments, etc. - let mut variables: ScopedMap, ReferenceBuilder> = ScopedMap::new(); - // At the outer-most scope of things, we'll put global variables we've defined // elsewhere in the program. for (name, (data_id, ty)) in self.defined_symbols.iter() { @@ -160,12 +224,6 @@ impl Backend { variables.insert(name.clone(), ReferenceBuilder::Global(*ty, local_data)); } - // Once we have these, we're going to actually push a level of scope and - // add our arguments. We push scope because if there happen to be any with - // the same name (their shouldn't be, but just in case), we want the arguments - // to win. - variables.new_scope(); - // Finally (!), we generate the function builder that we're going to use to // make this function! let mut fctx = FunctionBuilderContext::new(); @@ -192,7 +250,7 @@ impl Backend { builder.switch_to_block(main_block); - let (value, _) = self.compile_expression(body, &mut variables, &mut builder)?; + let (value, _) = self.compile_expression(body, variables, &mut builder)?; // Now that we're done, inject a return function (one with no actual value; basically // the equivalent of Rust's `return;`). We then seal the block (which lets Cranelift @@ -221,7 +279,7 @@ impl Backend { fn compile_expression( &mut self, expr: Expression, - variables: &mut ScopedMap, + variables: &mut HashMap, builder: &mut FunctionBuilder, ) -> Result<(entities::Value, ConstantType), BackendError> { match expr { @@ -282,6 +340,29 @@ impl Backend { (ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)), + (ConstantType::Void, Type::Primitive(PrimitiveType::Void)) => { + Ok((val, val_type)) + } + + (ConstantType::U8, Type::Primitive(PrimitiveType::I16)) => { + Ok((builder.ins().uextend(types::I16, val), ConstantType::I16)) + } + (ConstantType::U8, Type::Primitive(PrimitiveType::I32)) => { + Ok((builder.ins().uextend(types::I32, val), ConstantType::I32)) + } + (ConstantType::U8, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::I64)) + } + (ConstantType::U16, Type::Primitive(PrimitiveType::I32)) => { + Ok((builder.ins().uextend(types::I32, val), ConstantType::I32)) + } + (ConstantType::U16, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::I64)) + } + (ConstantType::U32, Type::Primitive(PrimitiveType::I64)) => { + Ok((builder.ins().uextend(types::I64, val), ConstantType::I64)) + } + _ => Err(BackendError::InvalidTypeCast { from: val_type.into(), to: target_type, @@ -327,7 +408,7 @@ impl Backend { } Expression::Block(_, _, mut exprs) => match exprs.pop() { - None => Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)), + None => Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)), Some(last) => { for inner in exprs { // we can ignore all of these return values and such, because we @@ -354,7 +435,7 @@ impl Backend { // Look up the value for the variable. Because this might be a // global variable (and that requires special logic), we just turn // this into an `Expression` and re-use the logic in that implementation. - let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var); + let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var.clone()); let (val, vtype) = self.compile_value_or_ref(fake_ref, variables, builder)?; let vtype_repr = builder.ins().iconst(types::I64, vtype as i64); @@ -379,7 +460,7 @@ impl Backend { print_func_ref, &[buffer_ptr, name_ptr, vtype_repr, casted_val], ); - Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)) + Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)) } Expression::Bind(_, name, _, expr) => { @@ -390,7 +471,7 @@ impl Backend { builder.declare_var(variable, ir_type); builder.def_var(variable, value); variables.insert(name, ReferenceBuilder::Local(value_type, variable)); - Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)) + Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)) } } } @@ -400,7 +481,7 @@ impl Backend { fn compile_value_or_ref( &self, valref: ValueOrRef, - variables: &ScopedMap, + variables: &HashMap, builder: &mut FunctionBuilder, ) -> Result<(entities::Value, ConstantType), BackendError> { match valref { @@ -453,3 +534,23 @@ impl Backend { } } } + +impl PrimitiveType { + fn blank_data(&self) -> DataDescription { + let (size, alignment) = match self { + PrimitiveType::Void => (8, 8), + PrimitiveType::U8 => (1, 1), + PrimitiveType::U16 => (2, 2), + PrimitiveType::U32 => (4, 4), + PrimitiveType::U64 => (4, 4), + PrimitiveType::I8 => (1, 1), + PrimitiveType::I16 => (2, 2), + PrimitiveType::I32 => (4, 4), + PrimitiveType::I64 => (4, 4), + }; + let mut result = DataDescription::new(); + result.define_zeroinit(size); + result.set_align(alignment); + result + } +} diff --git a/src/backend/runtime.rs b/src/backend/runtime.rs index 1df45a3..6362176 100644 --- a/src/backend/runtime.rs +++ b/src/backend/runtime.rs @@ -119,7 +119,7 @@ extern "C" fn runtime_print( Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16), Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32), Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64), - Err(_) => format!("{} = {}", reconstituted, value), + Err(_) => format!("{} = {}", reconstituted, value, vtype_repr), }; if let Some(output_buffer) = unsafe { output_buffer.as_mut() } { diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index fed57f3..e86ae69 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -191,6 +191,13 @@ impl PrimitiveType { (PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)), (PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)), + (PrimitiveType::I16, Value::U8(x)) => Ok(Value::I16(*x as i16)), + (PrimitiveType::I32, Value::U8(x)) => Ok(Value::I32(*x as i32)), + (PrimitiveType::I64, Value::U8(x)) => Ok(Value::I64(*x as i64)), + (PrimitiveType::I32, Value::U16(x)) => Ok(Value::I32(*x as i32)), + (PrimitiveType::I64, Value::U16(x)) => Ok(Value::I64(*x as i64)), + (PrimitiveType::I64, Value::U32(x)) => Ok(Value::I64(*x as i64)), + (PrimitiveType::Void, Value::Void) => Ok(Value::Void), _ => Err(PrimOpError::UnsafeCast { diff --git a/src/ir.rs b/src/ir.rs index bb4d0bd..e058e66 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -16,5 +16,6 @@ mod arbitrary; pub mod ast; mod eval; mod strings; +mod top_level; pub use ast::*; diff --git a/src/ir/arbitrary.rs b/src/ir/arbitrary.rs index 47aed0b..acf3cc8 100644 --- a/src/ir/arbitrary.rs +++ b/src/ir/arbitrary.rs @@ -1,5 +1,5 @@ use crate::eval::PrimitiveType; -use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; +use crate::ir::{Expression, Primitive, Program, TopLevel, Type, TypeWithVoid, Value, ValueOrRef, Variable}; use crate::syntax::Location; use crate::util::scoped_map::ScopedMap; use proptest::strategy::{NewTree, Strategy, ValueTree}; @@ -288,14 +288,25 @@ fn generate_random_expression( ExpressionType::Block => { let num_stmts = BLOCK_LENGTH_DISTRIBUTION.sample(rng); let mut stmts = Vec::new(); - let mut last_type = Type::Primitive(PrimitiveType::Void); + + if num_stmts == 0 { + return Expression::Block(Location::manufactured(), Type::void(), stmts); + } env.new_scope(); - for _ in 0..num_stmts { - let next = generate_random_expression(rng, env); - last_type = next.type_of(); + for _ in 1..num_stmts { + let mut next = generate_random_expression(rng, env); + let next_type = next.type_of(); + if !next_type.is_void() { + let name = generate_random_name(rng); + env.insert(name.clone(), next_type.clone()); + next = Expression::Bind(Location::manufactured(), name, next_type, Box::new(next)); + } stmts.push(next); } + let last_expr = generate_random_expression(rng, env); + let last_type = last_expr.type_of(); + stmts.push(last_expr); env.release_scope(); Expression::Block(Location::manufactured(), last_type, stmts) diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 78c6998..c065414 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -6,7 +6,10 @@ use crate::{ use internment::ArcIntern; use pretty::{BoxAllocator, DocAllocator, Pretty}; use proptest::arbitrary::Arbitrary; -use std::{fmt, str::FromStr, sync::atomic::AtomicUsize}; +use std::convert::TryFrom; +use std::fmt; +use std::str::FromStr; +use std::sync::atomic::AtomicUsize; use super::arbitrary::ProgramGenerator; @@ -97,6 +100,19 @@ pub enum TopLevel { Function(Variable, Vec<(Variable, Type)>, Type, Expression), } +impl TopLevel { + /// Return the type of the item, as inferred or recently + /// computed. + pub fn type_of(&self) -> T { + match self { + TopLevel::Statement(expr) => expr.type_of(), + TopLevel::Function(_, args, ret, _) => { + T::build_function_type(args.iter().map(|(_, t)| t.clone()).collect(), ret.clone()) + } + } + } +} + impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b TopLevel where A: 'a, @@ -148,7 +164,7 @@ pub enum Expression { } impl Expression { - /// Return a reference to the type of the expression, as inferred or recently + /// Return the type of the expression, as inferred or recently /// computed. pub fn type_of(&self) -> Type { match self { @@ -242,6 +258,16 @@ where } } +impl Expression { + pub fn to_pretty(&self) -> String { + let arena = pretty::Arena::<()>::new(); + let doc = self.pretty(&arena); + let mut output_bytes = Vec::new(); + doc.render(72, &mut output_bytes).unwrap(); + String::from_utf8(output_bytes).expect("pretty generates valid utf-8") + } +} + /// A type representing the primitives allowed in the language. /// /// Having this as an enumeration avoids a lot of "this should not happen" @@ -565,27 +591,56 @@ impl TypeOrVar { } } +impl PartialEq for TypeOrVar { + fn eq(&self, other: &Type) -> bool { + match other { + Type::Function(a, b) => match self { + TypeOrVar::Function(x, y) => x == a && y.as_ref() == b.as_ref(), + _ => false, + }, + + Type::Primitive(a) => match self { + TypeOrVar::Primitive(x) => a == x, + _ => false, + }, + } + } +} + pub trait TypeWithVoid { fn void() -> Self; + fn is_void(&self) -> bool; } impl TypeWithVoid for Type { fn void() -> Self { Type::Primitive(PrimitiveType::Void) } + + fn is_void(&self) -> bool { + self == &Type::Primitive(PrimitiveType::Void) + } } impl TypeWithVoid for TypeOrVar { fn void() -> Self { TypeOrVar::Primitive(PrimitiveType::Void) } + + fn is_void(&self) -> bool { + self == &TypeOrVar::Primitive(PrimitiveType::Void) + } } -//impl From for TypeOrVar { -// fn from(value: Type) -> Self { -// TypeOrVar::Type(value) -// } -//} +pub trait TypeWithFunction: Sized { + fn build_function_type(arg_types: Vec, ret_type: Self) -> Self; +} + +impl TypeWithFunction for Type { + fn build_function_type(arg_types: Vec, ret_type: Self) -> Self { + Type::Function(arg_types, Box::new(ret_type)) + } +} impl> From for TypeOrVar { fn from(value: T) -> Self { @@ -598,3 +653,24 @@ impl> From for TypeOrVar { } } } + +impl TryFrom for Type { + type Error = TypeOrVar; + + fn try_from(value: TypeOrVar) -> Result { + match value { + TypeOrVar::Function(args, ret) => { + let args = args + .into_iter() + .map(Type::try_from) + .collect::>()?; + let ret = Type::try_from(*ret)?; + + Ok(Type::Function(args, Box::new(ret))) + } + + TypeOrVar::Primitive(t) => Ok(Type::Primitive(t)), + _ => Err(value), + } + } +} diff --git a/src/ir/top_level.rs b/src/ir/top_level.rs new file mode 100644 index 0000000..c3459fb --- /dev/null +++ b/src/ir/top_level.rs @@ -0,0 +1,45 @@ +use crate::ir::{Expression, Program, TopLevel, TypeWithFunction, TypeWithVoid, Variable}; +use std::collections::HashMap; + +impl Program { + /// Retrieve the complete set of variables that are defined at the top level of + /// this program. + pub fn get_top_level_variables(&self) -> HashMap { + let mut result = HashMap::new(); + + for item in self.items.iter() { + result.extend(item.get_top_level_variables()); + } + + result + } +} + +impl TopLevel { + /// Retrieve the complete set of variables that are defined at the top level of + /// this top-level item. + /// + /// For functions, this is the function name. For expressions this can be a little + /// bit more complicated, as it sort of depends on the block structuring. + pub fn get_top_level_variables(&self) -> HashMap { + match self { + TopLevel::Function(name, _, _, _) => HashMap::from([(name.clone(), self.type_of())]), + TopLevel::Statement(expr) => expr.get_top_level_variables(), + } + } +} + +impl Expression { + /// Retrieve the complete set of variables that are defined at the top level of + /// this expression. Basically, returns the variable named in bind. + pub fn get_top_level_variables(&self) -> HashMap { + match self { + Expression::Bind(_, name, ty, expr) => { + let mut tlvs = expr.get_top_level_variables(); + tlvs.insert(name.clone(), ty.clone()); + tlvs + }, + _ => HashMap::new(), + } + } +} diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index 2ec0894..de005ec 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -154,6 +154,19 @@ impl PartialEq for Expression { } } +impl Expression { + /// Get the location of the expression in the source file (if there is one). + pub fn location(&self) -> &Location { + match self { + Expression::Value(loc, _) => loc, + Expression::Reference(loc, _) => loc, + Expression::Cast(loc, _, _) => loc, + Expression::Primitive(loc, _, _) => loc, + Expression::Block(loc, _) => loc, + } + } +} + /// A value from the source syntax #[derive(Clone, Debug, PartialEq, Eq)] pub enum Value { diff --git a/src/syntax/tokens.rs b/src/syntax/tokens.rs index 5a2bb99..48c1c52 100644 --- a/src/syntax/tokens.rs +++ b/src/syntax/tokens.rs @@ -270,6 +270,7 @@ impl TryFrom for ConstantType { 21 => Ok(ConstantType::I16), 22 => Ok(ConstantType::I32), 23 => Ok(ConstantType::I64), + 255 => Ok(ConstantType::Void), _ => Err(InvalidConstantType::Value(value)), } } diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 4a1800b..7e06013 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -4,6 +4,7 @@ use crate::syntax::{self, ConstantType}; use crate::type_infer::solve::Constraint; use crate::util::scoped_map::ScopedMap; use internment::ArcIntern; +use std::collections::HashMap; use std::str::FromStr; /// This function takes a syntactic program and converts it into the IR version of the @@ -19,7 +20,7 @@ pub fn convert_program( let mut constraint_db = Vec::new(); let mut items = Vec::new(); let mut renames = ScopedMap::new(); - let mut bindings = ScopedMap::new(); + let mut bindings = HashMap::new(); for item in program.items.drain(..) { items.push(convert_top_level( @@ -40,43 +41,68 @@ pub fn convert_top_level( top_level: syntax::TopLevel, constraint_db: &mut Vec, renames: &mut ScopedMap, ArcIntern>, - bindings: &mut ScopedMap, ir::TypeOrVar>, + bindings: &mut HashMap, ir::TypeOrVar>, ) -> ir::TopLevel { match top_level { syntax::TopLevel::Function(name, args, expr) => { - // First, let us figure out what we're going to name this function. If the user + // First, at some point we're going to want to know a location for this function, + // which should either be the name if we have one, or the body if we don't. + let function_location = match name { + None => expr.location().clone(), + Some(ref name) => name.location.clone(), + }; + // Next, let us figure out what we're going to name this function. If the user // didn't provide one, we'll just call it "function:" for them. (We'll // want a name for this function, eventually, so we might as well do it now.) // // If they did provide a name, see if we're shadowed. IF we are, then we'll have // to specialize the name a bit. Otherwise we'll stick with their name. - let funname = match name { + let function_name = match name { None => ir::gensym("function"), Some(unbound) => finalize_name(bindings, renames, unbound), }; - // Now we manufacture types for the inputs and outputs, and then a type for the - // function itself. We're not going to make any claims on these types, yet; they're - // all just unknown type variables we need to work out. - let argtypes: Vec = args.iter().map(|_| ir::TypeOrVar::new()).collect(); + // This function is going to have a type. We don't know what it is, but it'll have + // one. + let function_type = ir::TypeOrVar::new(); + bindings.insert(function_name.clone(), function_type.clone()); + + // Then, let's figure out what to do with the argument names, which similarly + // may need to be renamed. We'll also generate some new type variables to associate + // with all of them. + // + // Note that we want to do all this in a new renaming scope, so that we shadow + // appropriately. + renames.new_scope(); + let arginfo = args + .iter() + .map(|name| { + let new_type = ir::TypeOrVar::new(); + constraint_db.push(Constraint::IsSomething( + name.location.clone(), + new_type.clone(), + )); + let new_name = finalize_name(bindings, renames, name.clone()); + bindings.insert(new_name.clone(), new_type.clone()); + (new_name, new_type) + }) + .collect::>(); + + // Now we manufacture types for the outputs and then a type for the function itself. + // We're not going to make any claims on these types, yet; they're all just unknown + // type variables we need to work out. let rettype = ir::TypeOrVar::new(); - let funtype = ir::TypeOrVar::Function(argtypes.clone(), Box::new(rettype.clone())); - - // Now let's bind these types into the environment. First, we bind our function - // namae to the function type we just generated. - bindings.insert(funname.clone(), funtype); - // And then we attach the argument names to the argument types. (We have to go - // convert all the names, first.) - let iargs: Vec> = - args.iter().map(|x| ArcIntern::new(x.to_string())).collect(); - assert_eq!(argtypes.len(), iargs.len()); - let mut function_args = vec![]; - for ((arg_name, arg_type), orig_name) in iargs.iter().zip(argtypes).zip(args) { - bindings.insert(arg_name.clone(), arg_type.clone()); - function_args.push((arg_name.clone(), arg_type.clone())); - constraint_db.push(Constraint::IsSomething(orig_name.location, arg_type)); - } + let actual_function_type = ir::TypeOrVar::Function( + arginfo.iter().map(|x| x.1.clone()).collect(), + Box::new(rettype.clone()), + ); + constraint_db.push(Constraint::Equivalent( + function_location, + function_type, + actual_function_type, + )); + // Now let's convert the body over to the new IR. let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings); constraint_db.push(Constraint::Equivalent( expr.location().clone(), @@ -84,7 +110,10 @@ pub fn convert_top_level( ty, )); - ir::TopLevel::Function(funname, function_args, rettype, expr) + // Remember to exit this scoping level! + renames.release_scope(); + + ir::TopLevel::Function(function_name, arginfo, rettype, expr) } syntax::TopLevel::Statement(stmt) => { @@ -108,7 +137,7 @@ fn convert_statement( statement: syntax::Statement, constraint_db: &mut Vec, renames: &mut ScopedMap, ArcIntern>, - bindings: &mut ScopedMap, ir::TypeOrVar>, + bindings: &mut HashMap, ir::TypeOrVar>, ) -> ir::Expression { match statement { syntax::Statement::Print(loc, name) => { @@ -152,7 +181,7 @@ fn convert_expression( expression: syntax::Expression, constraint_db: &mut Vec, renames: &mut ScopedMap, ArcIntern>, - bindings: &mut ScopedMap, ir::TypeOrVar>, + bindings: &mut HashMap, ir::TypeOrVar>, ) -> (ir::Expression, ir::TypeOrVar) { match expression { // converting values is mostly tedious, because there's so many cases @@ -339,7 +368,7 @@ fn finalize_expression( } fn finalize_name( - bindings: &ScopedMap, ir::TypeOrVar>, + bindings: &HashMap, ir::TypeOrVar>, renames: &mut ScopedMap, ArcIntern>, name: syntax::Name, ) -> ArcIntern { diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index e2af988..1b9e703 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -105,7 +105,7 @@ fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type { TypeOrVar::Primitive(x) => Type::Primitive(x), TypeOrVar::Variable(_, tvar) => match resolutions.get(&tvar) { None => panic!("Did not resolve type for type variable {}", tvar), - Some(pt) => Type::Primitive(*pt), + Some(pt) => pt.clone(), }, TypeOrVar::Function(mut args, ret) => Type::Function( args.drain(..) diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index 34127de..0b1b7f8 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -1,5 +1,5 @@ use crate::eval::PrimitiveType; -use crate::ir::{Primitive, TypeOrVar}; +use crate::ir::{Primitive, Type, TypeOrVar}; use crate::syntax::Location; use codespan_reporting::diagnostic::Diagnostic; use internment::ArcIntern; @@ -50,7 +50,7 @@ impl fmt::Display for Constraint { } } -pub type TypeResolutions = HashMap, PrimitiveType>; +pub type TypeResolutions = HashMap, Type>; /// The results of type inference; like [`Result`], but with a bit more information. /// @@ -257,18 +257,21 @@ pub fn solve_constraints( ) -> TypeInferenceResult { let mut errors = vec![]; let mut warnings = vec![]; - let mut resolutions = HashMap::new(); + let mut resolutions: HashMap, Type> = HashMap::new(); let mut changed_something = true; - println!("CONSTRAINTS:"); - for constraint in constraint_db.iter() { - println!("{}", constraint); - } - // We want to run this inference endlessly, until either we have solved all of our // constraints. Internal to the loop, we have a check that will make sure that we // do (eventually) stop. while changed_something && !constraint_db.is_empty() { + println!("CONSTRAINT:"); + for constraint in constraint_db.iter() { + println!(" {}", constraint); + } + println!("RESOLUTIONS:"); + for (name, ty) in resolutions.iter() { + println!(" {} = {}", name, ty); + } // Set this to false at the top of the loop. We'll set this to true if we make // progress in any way further down, but having this here prevents us from going // into an infinite look when we can't figure stuff out. @@ -292,9 +295,13 @@ pub fn solve_constraints( Constraint::IsSomething(_, TypeOrVar::Function(_, _)) | Constraint::IsSomething(_, TypeOrVar::Primitive(_)) => changed_something = true, - // Otherwise, we'll keep looking for it. - Constraint::IsSomething(_, TypeOrVar::Variable(_, _)) => { - constraint_db.push(constraint); + // Otherwise, see if we've resolved this variable to anything. If not, add it + // back. + Constraint::IsSomething(_, TypeOrVar::Variable(_, ref name)) => { + if resolutions.get(name).is_none() { + constraint_db.push(constraint); + } + changed_something = true; } // Case #1a: We have two primitive types. If they're equal, we've discharged this @@ -311,35 +318,7 @@ pub fn solve_constraints( changed_something = true; } - // Case #2: One of the two constraints is a primitive, and the other is a variable. - // In this case, we'll check to see if we've resolved the variable, and check for - // equivalence if we have. If we haven't, we'll set that variable to be primitive - // type. - Constraint::Equivalent( - loc, - TypeOrVar::Primitive(t), - TypeOrVar::Variable(_, name), - ) - | Constraint::Equivalent( - loc, - TypeOrVar::Variable(_, name), - TypeOrVar::Primitive(t), - ) => { - match resolutions.get(&name) { - None => { - resolutions.insert(name, t); - } - Some(t2) if &t == t2 => {} - Some(t2) => errors.push(TypeInferenceError::NotEquivalent( - loc, - TypeOrVar::Primitive(t), - TypeOrVar::Primitive(*t2), - )), - } - changed_something = true; - } - - // Case #3: They're both variables. In which case, we'll have to do much the same + // Case #2: They're both variables. In which case, we'll have to do much the same // check, but now on their resolutions. Constraint::Equivalent( ref loc, @@ -350,11 +329,11 @@ pub fn solve_constraints( constraint_db.push(constraint); } (Some(pt), None) => { - resolutions.insert(name2.clone(), *pt); + resolutions.insert(name2.clone(), pt.clone()); changed_something = true; } (None, Some(pt)) => { - resolutions.insert(name1.clone(), *pt); + resolutions.insert(name1.clone(), pt.clone()); changed_something = true; } (Some(pt1), Some(pt2)) if pt1 == pt2 => { @@ -363,13 +342,43 @@ pub fn solve_constraints( (Some(pt1), Some(pt2)) => { errors.push(TypeInferenceError::NotEquivalent( loc.clone(), - TypeOrVar::Primitive(*pt1), - TypeOrVar::Primitive(*pt2), + pt1.clone().into(), + pt2.clone().into(), )); changed_something = true; } }, + // Case #3: One of the two constraints is a primitive, and the other is a variable. + // In this case, we'll check to see if we've resolved the variable, and check for + // equivalence if we have. If we haven't, we'll set that variable to be primitive + // type. + Constraint::Equivalent(loc, t, TypeOrVar::Variable(vloc, name)) + | Constraint::Equivalent(loc, TypeOrVar::Variable(vloc, name), t) => { + match resolutions.get(&name) { + None => match t.try_into() { + Ok(real_type) => { + resolutions.insert(name, real_type); + } + Err(variable_type) => { + constraint_db.push(Constraint::Equivalent( + loc, + variable_type, + TypeOrVar::Variable(vloc, name), + )); + continue; + } + }, + Some(t2) if &t == t2 => {} + Some(t2) => errors.push(TypeInferenceError::NotEquivalent( + loc, + t, + t2.clone().into(), + )), + } + changed_something = true; + } + // Case #4: Like primitives, but for function types. This is a little complicated, because // we first want to resolve all the type variables in the two types, and then see if they're // equivalent. Fortunately, though, we can cheat a bit. What we're going to do is first see @@ -445,7 +454,7 @@ pub fn solve_constraints( Some(nt) => { constraint_db.push(Constraint::FitsInNumType( loc, - TypeOrVar::Primitive(*nt), + nt.clone().into(), val, )); changed_something = true; @@ -474,7 +483,7 @@ pub fn solve_constraints( Some(nt) => { constraint_db.push(Constraint::CanCastTo( loc, - TypeOrVar::Primitive(*nt), + nt.clone().into(), to_type, )); changed_something = true; @@ -494,7 +503,7 @@ pub fn solve_constraints( constraint_db.push(Constraint::CanCastTo( loc, from_type, - TypeOrVar::Primitive(*nt), + nt.clone().into(), )); changed_something = true; } @@ -560,8 +569,7 @@ pub fn solve_constraints( None => constraint_db .push(Constraint::NumericType(loc, TypeOrVar::Variable(vloc, var))), Some(nt) => { - constraint_db - .push(Constraint::NumericType(loc, TypeOrVar::Primitive(*nt))); + constraint_db.push(Constraint::NumericType(loc, nt.clone().into())); changed_something = true; } } @@ -592,10 +600,8 @@ pub fn solve_constraints( TypeOrVar::Variable(vloc, var), )), Some(nt) => { - constraint_db.push(Constraint::ConstantNumericType( - loc, - TypeOrVar::Primitive(*nt), - )); + constraint_db + .push(Constraint::ConstantNumericType(loc, nt.clone().into())); changed_something = true; } }