λ Support functions! #5

Open
acw wants to merge 59 commits from awick/functions into develop
21 changed files with 477 additions and 185 deletions
Showing only changes of commit 4ba196d2a6 - Show all commits

View File

@@ -9,15 +9,15 @@ name = "ngr"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
clap = { version = "4.4.11", features = ["derive"] } clap = { version = "4.4.18", features = ["derive"] }
codespan = "0.11.1" codespan = "0.11.1"
codespan-reporting = "0.11.1" codespan-reporting = "0.11.1"
cranelift-codegen = "0.103.0" cranelift-codegen = "0.104.0"
cranelift-jit = "0.103.0" cranelift-jit = "0.104.0"
cranelift-frontend = "0.103.0" cranelift-frontend = "0.104.0"
cranelift-module = "0.103.0" cranelift-module = "0.104.0"
cranelift-native = "0.103.0" cranelift-native = "0.104.0"
cranelift-object = "0.103.0" cranelift-object = "0.104.0"
internment = { version = "0.7.4", default-features = false, features = ["arc"] } internment = { version = "0.7.4", default-features = false, features = ["arc"] }
lalrpop-util = "0.20.0" lalrpop-util = "0.20.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
@@ -26,10 +26,10 @@ pretty = { version = "0.12.3", features = ["termcolor"] }
proptest = "1.4.0" proptest = "1.4.0"
rand = "0.8.5" rand = "0.8.5"
rustyline = "13.0.0" rustyline = "13.0.0"
target-lexicon = "0.12.12" target-lexicon = "0.12.13"
tempfile = "3.8.1" tempfile = "3.9.0"
thiserror = "1.0.52" thiserror = "1.0.56"
anyhow = "1.0.77" anyhow = "1.0.79"
[build-dependencies] [build-dependencies]
lalrpop = "0.20.0" lalrpop = "0.20.0"

View File

@@ -0,0 +1,3 @@
x = 4u64;
function f(y) (x + y)
print x;

View File

@@ -0,0 +1,7 @@
b = -7662558304906888395i64;
z = 1030390794u32;
v = z;
q = <i64>z;
s = -2115098981i32;
t = <i32>s;
print t;

View File

@@ -0,0 +1,4 @@
n = (49u8 + 155u8);
q = n;
function u (b) n + b
v = n;

View File

@@ -1,4 +1,4 @@
x = 5; x = 5;
y = 4*x + 3; y = 4*x + 3;
print x; print x;
print y; print y;

View File

@@ -28,6 +28,11 @@ void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) {
case /* I64 = */ 23: case /* I64 = */ 23:
printf("%s = %" PRIi64 "i64\n", variable_name, value); printf("%s = %" PRIi64 "i64\n", variable_name, value);
break; break;
case /* void = */ 255:
printf("%s = <void>\n", variable_name);
break;
default:
printf("%s = UNKNOWN VTYPE %d\n", variable_name, vtype);
} }
} }

View File

@@ -60,6 +60,7 @@ pub struct Backend<M: Module> {
data_ctx: DataDescription, data_ctx: DataDescription,
runtime_functions: RuntimeFunctions, runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>, defined_strings: HashMap<String, DataId>,
defined_functions: HashMap<ArcIntern<String>, FuncId>,
defined_symbols: HashMap<ArcIntern<String>, (DataId, ConstantType)>, defined_symbols: HashMap<ArcIntern<String>, (DataId, ConstantType)>,
output_buffer: Option<String>, output_buffer: Option<String>,
platform: Triple, platform: Triple,
@@ -92,6 +93,7 @@ impl Backend<JITModule> {
data_ctx: DataDescription::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_functions: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer, output_buffer,
platform: Triple::host(), platform: Triple::host(),
@@ -132,6 +134,7 @@ impl Backend<ObjectModule> {
data_ctx: DataDescription::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_functions: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer: None, output_buffer: None,
platform, platform,

View File

@@ -55,6 +55,7 @@ impl From<BackendError> for Diagnostic<usize> {
match value { match value {
BackendError::Cranelift(me) => { BackendError::Cranelift(me) => {
Diagnostic::error().with_message(format!("Internal cranelift error: {}", me)) Diagnostic::error().with_message(format!("Internal cranelift error: {}", me))
.with_notes(vec![format!("{:?}", me)])
} }
BackendError::BuiltinError(me) => { BackendError::BuiltinError(me) => {
Diagnostic::error().with_message(format!("Internal runtime function error: {}", me)) Diagnostic::error().with_message(format!("Internal runtime function error: {}", me))

View File

@@ -26,34 +26,7 @@ impl Backend<JITModule> {
/// of the built-in test systems.) /// of the built-in test systems.)
pub fn eval(program: Program<Type>) -> Result<String, EvalError<Expression<Type>>> { pub fn eval(program: Program<Type>) -> Result<String, EvalError<Expression<Type>>> {
let mut jitter = Backend::jit(Some(String::new()))?; let mut jitter = Backend::jit(Some(String::new()))?;
let mut function_map = HashMap::new(); let function_id = jitter.compile_program("___test_jit_eval___", program)?;
let mut main_function_body = vec![];
for item in program.items {
match item {
TopLevel::Function(name, args, rettype, body) => {
let function_id =
jitter.compile_function(name.as_str(), args.as_slice(), rettype, body)?;
function_map.insert(name, function_id);
}
TopLevel::Statement(stmt) => {
main_function_body.push(stmt);
}
}
}
let main_function_body = Expression::Block(
Location::manufactured(),
Type::Primitive(crate::eval::PrimitiveType::Void),
main_function_body,
);
let function_id = jitter.compile_function(
"___test_jit_eval___",
&[],
Type::Primitive(crate::eval::PrimitiveType::Void),
main_function_body,
)?;
jitter.module.finalize_definitions()?; jitter.module.finalize_definitions()?;
let compiled_bytes = jitter.bytes(function_id); let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
@@ -89,8 +62,13 @@ impl Backend<ObjectModule> {
for item in program.items { for item in program.items {
match item { match item {
TopLevel::Function(name, args, rettype, body) => { TopLevel::Function(name, args, rettype, body) => {
let function_id = let function_id = backend.compile_function(
backend.compile_function(name.as_str(), args.as_slice(), rettype, body)?; &mut HashMap::new(),
name.as_str(),
args.as_slice(),
rettype,
body,
)?;
function_map.insert(name, function_id); function_map.insert(name, function_id);
} }
@@ -111,6 +89,7 @@ impl Backend<ObjectModule> {
let executable_path = my_directory.path().join("test_executable"); let executable_path = my_directory.path().join("test_executable");
backend.compile_function( backend.compile_function(
&mut HashMap::new(),
"gogogo", "gogogo",
&[], &[],
Type::Primitive(crate::eval::PrimitiveType::Void), Type::Primitive(crate::eval::PrimitiveType::Void),
@@ -154,6 +133,7 @@ impl Backend<ObjectModule> {
.join("runtime") .join("runtime")
.join("rts.c"), .join("rts.c"),
) )
.arg("-Wl,-ld_classic")
.arg(object_file) .arg(object_file)
.arg("-o") .arg("-o")
.arg(executable_path) .arg(executable_path)
@@ -206,16 +186,15 @@ proptest::proptest! {
#[test] #[test]
fn jit_backend(program in Program::arbitrary()) { fn jit_backend(program in Program::arbitrary()) {
use crate::eval::PrimOpError; use crate::eval::PrimOpError;
use pretty::{DocAllocator, Pretty}; // use pretty::{DocAllocator, Pretty};
let allocator = pretty::BoxAllocator; // let allocator = pretty::BoxAllocator;
allocator // allocator
.text("---------------") // .text("---------------")
.append(allocator.hardline()) // .append(allocator.hardline())
.append(program.pretty(&allocator)) // .append(program.pretty(&allocator))
.1 // .1
.render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) // .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
.expect("rendering works"); // .expect("rendering works");
let basic_result = program.eval().map(|(_,x)| x); let basic_result = program.eval().map(|(_,x)| x);

View File

@@ -1,24 +1,23 @@
use crate::backend::error::BackendError;
use crate::backend::Backend;
use crate::eval::PrimitiveType; use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable};
use crate::syntax::{ConstantType, Location}; use crate::syntax::{ConstantType, Location};
use crate::util::scoped_map::ScopedMap;
use cranelift_codegen::ir::{ use cranelift_codegen::ir::{
self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, Signature, UserFuncName, self, entities, types, AbiParam, Function, GlobalValue, InstBuilder, MemFlags, Signature, UserFuncName
}; };
use cranelift_codegen::isa::CallConv; use cranelift_codegen::isa::CallConv;
use cranelift_codegen::Context; use cranelift_codegen::Context;
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext}; use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
use cranelift_module::{FuncId, Linkage, Module}; use cranelift_module::{DataDescription, FuncId, Linkage, Module};
use internment::ArcIntern; use internment::ArcIntern;
use std::collections::{hash_map, HashMap};
use crate::backend::error::BackendError;
use crate::backend::Backend;
/// When we're talking about variables, it's handy to just have a table that points /// When we're talking about variables, it's handy to just have a table that points
/// from a variable to "what to do if you want to reference this variable", which is /// from a variable to "what to do if you want to reference this variable", which is
/// agnostic about whether the variable is local, global, an argument, etc. Since /// agnostic about whether the variable is local, global, an argument, etc. Since
/// the type of that function is a little bit annoying, we summarize it here. /// the type of that function is a little bit annoying, we summarize it here.
enum ReferenceBuilder { pub enum ReferenceBuilder {
Global(ConstantType, GlobalValue), Global(ConstantType, GlobalValue),
Local(ConstantType, cranelift_frontend::Variable), Local(ConstantType, cranelift_frontend::Variable),
Argument(ConstantType, entities::Value), Argument(ConstantType, entities::Value),
@@ -29,7 +28,8 @@ impl ReferenceBuilder {
match self { match self {
ReferenceBuilder::Global(ty, gv) => { ReferenceBuilder::Global(ty, gv) => {
let cranelift_type = ir::Type::from(*ty); let cranelift_type = ir::Type::from(*ty);
let value = builder.ins().symbol_value(cranelift_type, *gv); let ptr_value = builder.ins().symbol_value(types::I64, *gv);
let value = builder.ins().load(cranelift_type, MemFlags::new(), ptr_value, 0);
(value, *ty) (value, *ty)
} }
@@ -81,11 +81,44 @@ impl<M: Module> Backend<M> {
program: Program<Type>, program: Program<Type>,
) -> Result<FuncId, BackendError> { ) -> Result<FuncId, BackendError> {
let mut generated_body = vec![]; let mut generated_body = vec![];
let mut variables = HashMap::new();
for (top_level_name, top_level_type) in program.get_top_level_variables() {
match top_level_type {
Type::Function(argument_types, return_type) => {
let func_id = self.declare_function(
top_level_name.as_str(),
Linkage::Export,
argument_types,
*return_type,
)?;
self.defined_functions.insert(top_level_name, func_id);
}
Type::Primitive(pt) => {
let data_id = self.module.declare_data(
top_level_name.as_str(),
Linkage::Export,
true,
false,
)?;
self.module.define_data(data_id, &pt.blank_data())?;
self.defined_symbols
.insert(top_level_name, (data_id, pt.into()));
}
}
}
let void = Type::Primitive(PrimitiveType::Void);
let main_func_id =
self.declare_function(function_name, Linkage::Export, vec![], void.clone())?;
self.defined_functions
.insert(ArcIntern::new(function_name.to_string()), main_func_id);
for item in program.items { for item in program.items {
match item { match item {
TopLevel::Function(name, args, rettype, body) => { TopLevel::Function(name, args, rettype, body) => {
self.compile_function(name.as_str(), &args, rettype, body)?; self.compile_function(&mut variables, name.as_str(), &args, rettype, body)?;
} }
TopLevel::Statement(stmt) => { TopLevel::Statement(stmt) => {
@@ -94,8 +127,8 @@ impl<M: Module> Backend<M> {
} }
} }
let void = Type::Primitive(PrimitiveType::Void);
self.compile_function( self.compile_function(
&mut variables,
function_name, function_name,
&[], &[],
void.clone(), void.clone(),
@@ -103,14 +136,47 @@ impl<M: Module> Backend<M> {
) )
} }
fn declare_function(
&mut self,
name: &str,
linkage: Linkage,
argument_types: Vec<Type>,
return_type: Type,
) -> Result<FuncId, cranelift_module::ModuleError> {
let basic_signature = Signature {
params: argument_types
.iter()
.map(|t| self.translate_type(t))
.collect(),
returns: if return_type == Type::Primitive(PrimitiveType::Void) {
vec![]
} else {
vec![self.translate_type(&return_type)]
},
call_conv: CallConv::triple_default(&self.platform),
};
// this generates the handle for the function that we'll eventually want to
// return to the user. For now, we declare all functions defined by this
// function as public/global/exported, although we may want to reconsider
// this decision later.
self.module
.declare_function(name, linkage, &basic_signature)
}
/// Compile the given function. /// Compile the given function.
pub fn compile_function( pub fn compile_function(
&mut self, &mut self,
variables: &mut HashMap<Variable, ReferenceBuilder>,
function_name: &str, function_name: &str,
arguments: &[(Variable, Type)], arguments: &[(Variable, Type)],
return_type: Type, return_type: Type,
body: Expression<Type>, body: Expression<Type>,
) -> Result<FuncId, BackendError> { ) -> Result<FuncId, BackendError> {
// reset the next variable counter. this value shouldn't matter; hopefully
// we won't be using close to 2^32 variables!
self.reset_local_variable_tracker();
let basic_signature = Signature { let basic_signature = Signature {
params: arguments params: arguments
.iter() .iter()
@@ -124,17 +190,23 @@ impl<M: Module> Backend<M> {
call_conv: CallConv::triple_default(&self.platform), call_conv: CallConv::triple_default(&self.platform),
}; };
// reset the next variable counter. this value shouldn't matter; hopefully
// we won't be using close to 2^32 variables!
self.reset_local_variable_tracker();
// this generates the handle for the function that we'll eventually want to // this generates the handle for the function that we'll eventually want to
// return to the user. For now, we declare all functions defined by this // return to the user. For now, we declare all functions defined by this
// function as public/global/exported, although we may want to reconsider // function as public/global/exported, although we may want to reconsider
// this decision later. // this decision later.
let func_id = let interned_name = ArcIntern::new(function_name.to_string());
self.module let func_id = match self.defined_functions.entry(interned_name) {
.declare_function(function_name, Linkage::Export, &basic_signature)?; hash_map::Entry::Occupied(entry) => *entry.get(),
hash_map::Entry::Vacant(vac) => {
let func_id = self.module.declare_function(
function_name,
Linkage::Export,
&basic_signature,
)?;
vac.insert(func_id);
func_id
}
};
// Next we have to generate the compilation context for the rest of this // Next we have to generate the compilation context for the rest of this
// function. Currently, we generate a fresh context for every function. // function. Currently, we generate a fresh context for every function.
@@ -145,14 +217,6 @@ impl<M: Module> Backend<M> {
let user_func_name = UserFuncName::user(0, func_id.as_u32()); let user_func_name = UserFuncName::user(0, func_id.as_u32());
ctx.func = Function::with_name_signature(user_func_name, basic_signature); ctx.func = Function::with_name_signature(user_func_name, basic_signature);
// Let's start creating the variable table we'll use when we're dereferencing
// them later. This table is a little interesting because instead of pointing
// from data to data, we're going to point from data (the variable) to an
// action to take if we encounter that variable at some later point. This
// makes it nice and easy to have many different ways to access data, such
// as globals, function arguments, etc.
let mut variables: ScopedMap<ArcIntern<String>, ReferenceBuilder> = ScopedMap::new();
// At the outer-most scope of things, we'll put global variables we've defined // At the outer-most scope of things, we'll put global variables we've defined
// elsewhere in the program. // elsewhere in the program.
for (name, (data_id, ty)) in self.defined_symbols.iter() { for (name, (data_id, ty)) in self.defined_symbols.iter() {
@@ -160,12 +224,6 @@ impl<M: Module> Backend<M> {
variables.insert(name.clone(), ReferenceBuilder::Global(*ty, local_data)); variables.insert(name.clone(), ReferenceBuilder::Global(*ty, local_data));
} }
// Once we have these, we're going to actually push a level of scope and
// add our arguments. We push scope because if there happen to be any with
// the same name (their shouldn't be, but just in case), we want the arguments
// to win.
variables.new_scope();
// Finally (!), we generate the function builder that we're going to use to // Finally (!), we generate the function builder that we're going to use to
// make this function! // make this function!
let mut fctx = FunctionBuilderContext::new(); let mut fctx = FunctionBuilderContext::new();
@@ -192,7 +250,7 @@ impl<M: Module> Backend<M> {
builder.switch_to_block(main_block); builder.switch_to_block(main_block);
let (value, _) = self.compile_expression(body, &mut variables, &mut builder)?; let (value, _) = self.compile_expression(body, variables, &mut builder)?;
// Now that we're done, inject a return function (one with no actual value; basically // Now that we're done, inject a return function (one with no actual value; basically
// the equivalent of Rust's `return;`). We then seal the block (which lets Cranelift // the equivalent of Rust's `return;`). We then seal the block (which lets Cranelift
@@ -221,7 +279,7 @@ impl<M: Module> Backend<M> {
fn compile_expression( fn compile_expression(
&mut self, &mut self,
expr: Expression<Type>, expr: Expression<Type>,
variables: &mut ScopedMap<Variable, ReferenceBuilder>, variables: &mut HashMap<Variable, ReferenceBuilder>,
builder: &mut FunctionBuilder, builder: &mut FunctionBuilder,
) -> Result<(entities::Value, ConstantType), BackendError> { ) -> Result<(entities::Value, ConstantType), BackendError> {
match expr { match expr {
@@ -282,6 +340,29 @@ impl<M: Module> Backend<M> {
(ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)), (ConstantType::U64, Type::Primitive(PrimitiveType::U64)) => Ok((val, val_type)),
(ConstantType::Void, Type::Primitive(PrimitiveType::Void)) => {
Ok((val, val_type))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::I16)) => {
Ok((builder.ins().uextend(types::I16, val), ConstantType::I16))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::I32))
}
(ConstantType::U8, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::I64))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::I32)) => {
Ok((builder.ins().uextend(types::I32, val), ConstantType::I32))
}
(ConstantType::U16, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::I64))
}
(ConstantType::U32, Type::Primitive(PrimitiveType::I64)) => {
Ok((builder.ins().uextend(types::I64, val), ConstantType::I64))
}
_ => Err(BackendError::InvalidTypeCast { _ => Err(BackendError::InvalidTypeCast {
from: val_type.into(), from: val_type.into(),
to: target_type, to: target_type,
@@ -327,7 +408,7 @@ impl<M: Module> Backend<M> {
} }
Expression::Block(_, _, mut exprs) => match exprs.pop() { Expression::Block(_, _, mut exprs) => match exprs.pop() {
None => Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)), None => Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)),
Some(last) => { Some(last) => {
for inner in exprs { for inner in exprs {
// we can ignore all of these return values and such, because we // we can ignore all of these return values and such, because we
@@ -354,7 +435,7 @@ impl<M: Module> Backend<M> {
// Look up the value for the variable. Because this might be a // Look up the value for the variable. Because this might be a
// global variable (and that requires special logic), we just turn // global variable (and that requires special logic), we just turn
// this into an `Expression` and re-use the logic in that implementation. // this into an `Expression` and re-use the logic in that implementation.
let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var); let fake_ref = ValueOrRef::Ref(ann, Type::Primitive(PrimitiveType::U8), var.clone());
let (val, vtype) = self.compile_value_or_ref(fake_ref, variables, builder)?; let (val, vtype) = self.compile_value_or_ref(fake_ref, variables, builder)?;
let vtype_repr = builder.ins().iconst(types::I64, vtype as i64); let vtype_repr = builder.ins().iconst(types::I64, vtype as i64);
@@ -379,7 +460,7 @@ impl<M: Module> Backend<M> {
print_func_ref, print_func_ref,
&[buffer_ptr, name_ptr, vtype_repr, casted_val], &[buffer_ptr, name_ptr, vtype_repr, casted_val],
); );
Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)) Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void))
} }
Expression::Bind(_, name, _, expr) => { Expression::Bind(_, name, _, expr) => {
@@ -390,7 +471,7 @@ impl<M: Module> Backend<M> {
builder.declare_var(variable, ir_type); builder.declare_var(variable, ir_type);
builder.def_var(variable, value); builder.def_var(variable, value);
variables.insert(name, ReferenceBuilder::Local(value_type, variable)); variables.insert(name, ReferenceBuilder::Local(value_type, variable));
Ok((builder.ins().iconst(types::I8, 0), ConstantType::I8)) Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void))
} }
} }
} }
@@ -400,7 +481,7 @@ impl<M: Module> Backend<M> {
fn compile_value_or_ref( fn compile_value_or_ref(
&self, &self,
valref: ValueOrRef<Type>, valref: ValueOrRef<Type>,
variables: &ScopedMap<Variable, ReferenceBuilder>, variables: &HashMap<Variable, ReferenceBuilder>,
builder: &mut FunctionBuilder, builder: &mut FunctionBuilder,
) -> Result<(entities::Value, ConstantType), BackendError> { ) -> Result<(entities::Value, ConstantType), BackendError> {
match valref { match valref {
@@ -453,3 +534,23 @@ impl<M: Module> Backend<M> {
} }
} }
} }
impl PrimitiveType {
fn blank_data(&self) -> DataDescription {
let (size, alignment) = match self {
PrimitiveType::Void => (8, 8),
PrimitiveType::U8 => (1, 1),
PrimitiveType::U16 => (2, 2),
PrimitiveType::U32 => (4, 4),
PrimitiveType::U64 => (4, 4),
PrimitiveType::I8 => (1, 1),
PrimitiveType::I16 => (2, 2),
PrimitiveType::I32 => (4, 4),
PrimitiveType::I64 => (4, 4),
};
let mut result = DataDescription::new();
result.define_zeroinit(size);
result.set_align(alignment);
result
}
}

View File

@@ -119,7 +119,7 @@ extern "C" fn runtime_print(
Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16), Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16),
Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32), Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32),
Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64), Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64),
Err(_) => format!("{} = {}<unknown type>", reconstituted, value), Err(_) => format!("{} = {}<unknown type {}>", reconstituted, value, vtype_repr),
}; };
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } { if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {

View File

@@ -191,6 +191,13 @@ impl PrimitiveType {
(PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)), (PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)), (PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)),
(PrimitiveType::I16, Value::U8(x)) => Ok(Value::I16(*x as i16)),
(PrimitiveType::I32, Value::U8(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I64, Value::U8(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I32, Value::U16(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I64, Value::U16(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::U32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::Void, Value::Void) => Ok(Value::Void), (PrimitiveType::Void, Value::Void) => Ok(Value::Void),
_ => Err(PrimOpError::UnsafeCast { _ => Err(PrimOpError::UnsafeCast {

View File

@@ -16,5 +16,6 @@ mod arbitrary;
pub mod ast; pub mod ast;
mod eval; mod eval;
mod strings; mod strings;
mod top_level;
pub use ast::*; pub use ast::*;

View File

@@ -1,5 +1,5 @@
use crate::eval::PrimitiveType; use crate::eval::PrimitiveType;
use crate::ir::{Expression, Primitive, Program, TopLevel, Type, Value, ValueOrRef, Variable}; use crate::ir::{Expression, Primitive, Program, TopLevel, Type, TypeWithVoid, Value, ValueOrRef, Variable};
use crate::syntax::Location; use crate::syntax::Location;
use crate::util::scoped_map::ScopedMap; use crate::util::scoped_map::ScopedMap;
use proptest::strategy::{NewTree, Strategy, ValueTree}; use proptest::strategy::{NewTree, Strategy, ValueTree};
@@ -288,14 +288,25 @@ fn generate_random_expression(
ExpressionType::Block => { ExpressionType::Block => {
let num_stmts = BLOCK_LENGTH_DISTRIBUTION.sample(rng); let num_stmts = BLOCK_LENGTH_DISTRIBUTION.sample(rng);
let mut stmts = Vec::new(); let mut stmts = Vec::new();
let mut last_type = Type::Primitive(PrimitiveType::Void);
if num_stmts == 0 {
return Expression::Block(Location::manufactured(), Type::void(), stmts);
}
env.new_scope(); env.new_scope();
for _ in 0..num_stmts { for _ in 1..num_stmts {
let next = generate_random_expression(rng, env); let mut next = generate_random_expression(rng, env);
last_type = next.type_of(); let next_type = next.type_of();
if !next_type.is_void() {
let name = generate_random_name(rng);
env.insert(name.clone(), next_type.clone());
next = Expression::Bind(Location::manufactured(), name, next_type, Box::new(next));
}
stmts.push(next); stmts.push(next);
} }
let last_expr = generate_random_expression(rng, env);
let last_type = last_expr.type_of();
stmts.push(last_expr);
env.release_scope(); env.release_scope();
Expression::Block(Location::manufactured(), last_type, stmts) Expression::Block(Location::manufactured(), last_type, stmts)

View File

@@ -6,7 +6,10 @@ use crate::{
use internment::ArcIntern; use internment::ArcIntern;
use pretty::{BoxAllocator, DocAllocator, Pretty}; use pretty::{BoxAllocator, DocAllocator, Pretty};
use proptest::arbitrary::Arbitrary; use proptest::arbitrary::Arbitrary;
use std::{fmt, str::FromStr, sync::atomic::AtomicUsize}; use std::convert::TryFrom;
use std::fmt;
use std::str::FromStr;
use std::sync::atomic::AtomicUsize;
use super::arbitrary::ProgramGenerator; use super::arbitrary::ProgramGenerator;
@@ -97,6 +100,19 @@ pub enum TopLevel<Type> {
Function(Variable, Vec<(Variable, Type)>, Type, Expression<Type>), Function(Variable, Vec<(Variable, Type)>, Type, Expression<Type>),
} }
impl<T: Clone + TypeWithVoid + TypeWithFunction> TopLevel<T> {
/// Return the type of the item, as inferred or recently
/// computed.
pub fn type_of(&self) -> T {
match self {
TopLevel::Statement(expr) => expr.type_of(),
TopLevel::Function(_, args, ret, _) => {
T::build_function_type(args.iter().map(|(_, t)| t.clone()).collect(), ret.clone())
}
}
}
}
impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b TopLevel<Type> impl<'a, 'b, D, A, Type> Pretty<'a, D, A> for &'b TopLevel<Type>
where where
A: 'a, A: 'a,
@@ -148,7 +164,7 @@ pub enum Expression<Type> {
} }
impl<Type: Clone + TypeWithVoid> Expression<Type> { impl<Type: Clone + TypeWithVoid> Expression<Type> {
/// Return a reference to the type of the expression, as inferred or recently /// Return the type of the expression, as inferred or recently
/// computed. /// computed.
pub fn type_of(&self) -> Type { pub fn type_of(&self) -> Type {
match self { match self {
@@ -242,6 +258,16 @@ where
} }
} }
impl Expression<Type> {
pub fn to_pretty(&self) -> String {
let arena = pretty::Arena::<()>::new();
let doc = self.pretty(&arena);
let mut output_bytes = Vec::new();
doc.render(72, &mut output_bytes).unwrap();
String::from_utf8(output_bytes).expect("pretty generates valid utf-8")
}
}
/// A type representing the primitives allowed in the language. /// A type representing the primitives allowed in the language.
/// ///
/// Having this as an enumeration avoids a lot of "this should not happen" /// Having this as an enumeration avoids a lot of "this should not happen"
@@ -565,27 +591,56 @@ impl TypeOrVar {
} }
} }
impl PartialEq<Type> for TypeOrVar {
fn eq(&self, other: &Type) -> bool {
match other {
Type::Function(a, b) => match self {
TypeOrVar::Function(x, y) => x == a && y.as_ref() == b.as_ref(),
_ => false,
},
Type::Primitive(a) => match self {
TypeOrVar::Primitive(x) => a == x,
_ => false,
},
}
}
}
pub trait TypeWithVoid { pub trait TypeWithVoid {
fn void() -> Self; fn void() -> Self;
fn is_void(&self) -> bool;
} }
impl TypeWithVoid for Type { impl TypeWithVoid for Type {
fn void() -> Self { fn void() -> Self {
Type::Primitive(PrimitiveType::Void) Type::Primitive(PrimitiveType::Void)
} }
fn is_void(&self) -> bool {
self == &Type::Primitive(PrimitiveType::Void)
}
} }
impl TypeWithVoid for TypeOrVar { impl TypeWithVoid for TypeOrVar {
fn void() -> Self { fn void() -> Self {
TypeOrVar::Primitive(PrimitiveType::Void) TypeOrVar::Primitive(PrimitiveType::Void)
} }
fn is_void(&self) -> bool {
self == &TypeOrVar::Primitive(PrimitiveType::Void)
}
} }
//impl From<Type> for TypeOrVar { pub trait TypeWithFunction: Sized {
// fn from(value: Type) -> Self { fn build_function_type(arg_types: Vec<Self>, ret_type: Self) -> Self;
// TypeOrVar::Type(value) }
// }
//} impl TypeWithFunction for Type {
fn build_function_type(arg_types: Vec<Self>, ret_type: Self) -> Self {
Type::Function(arg_types, Box::new(ret_type))
}
}
impl<T: Into<Type>> From<T> for TypeOrVar { impl<T: Into<Type>> From<T> for TypeOrVar {
fn from(value: T) -> Self { fn from(value: T) -> Self {
@@ -598,3 +653,24 @@ impl<T: Into<Type>> From<T> for TypeOrVar {
} }
} }
} }
impl TryFrom<TypeOrVar> for Type {
type Error = TypeOrVar;
fn try_from(value: TypeOrVar) -> Result<Self, Self::Error> {
match value {
TypeOrVar::Function(args, ret) => {
let args = args
.into_iter()
.map(Type::try_from)
.collect::<Result<_, _>>()?;
let ret = Type::try_from(*ret)?;
Ok(Type::Function(args, Box::new(ret)))
}
TypeOrVar::Primitive(t) => Ok(Type::Primitive(t)),
_ => Err(value),
}
}
}

45
src/ir/top_level.rs Normal file
View File

@@ -0,0 +1,45 @@
use crate::ir::{Expression, Program, TopLevel, TypeWithFunction, TypeWithVoid, Variable};
use std::collections::HashMap;
impl<T: Clone + TypeWithVoid + TypeWithFunction> Program<T> {
/// Retrieve the complete set of variables that are defined at the top level of
/// this program.
pub fn get_top_level_variables(&self) -> HashMap<Variable, T> {
let mut result = HashMap::new();
for item in self.items.iter() {
result.extend(item.get_top_level_variables());
}
result
}
}
impl<T: Clone + TypeWithVoid + TypeWithFunction> TopLevel<T> {
/// Retrieve the complete set of variables that are defined at the top level of
/// this top-level item.
///
/// For functions, this is the function name. For expressions this can be a little
/// bit more complicated, as it sort of depends on the block structuring.
pub fn get_top_level_variables(&self) -> HashMap<Variable, T> {
match self {
TopLevel::Function(name, _, _, _) => HashMap::from([(name.clone(), self.type_of())]),
TopLevel::Statement(expr) => expr.get_top_level_variables(),
}
}
}
impl<T: Clone> Expression<T> {
/// Retrieve the complete set of variables that are defined at the top level of
/// this expression. Basically, returns the variable named in bind.
pub fn get_top_level_variables(&self) -> HashMap<Variable, T> {
match self {
Expression::Bind(_, name, ty, expr) => {
let mut tlvs = expr.get_top_level_variables();
tlvs.insert(name.clone(), ty.clone());
tlvs
},
_ => HashMap::new(),
}
}
}

View File

@@ -154,6 +154,19 @@ impl PartialEq for Expression {
} }
} }
impl Expression {
/// Get the location of the expression in the source file (if there is one).
pub fn location(&self) -> &Location {
match self {
Expression::Value(loc, _) => loc,
Expression::Reference(loc, _) => loc,
Expression::Cast(loc, _, _) => loc,
Expression::Primitive(loc, _, _) => loc,
Expression::Block(loc, _) => loc,
}
}
}
/// A value from the source syntax /// A value from the source syntax
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum Value { pub enum Value {

View File

@@ -270,6 +270,7 @@ impl TryFrom<i64> for ConstantType {
21 => Ok(ConstantType::I16), 21 => Ok(ConstantType::I16),
22 => Ok(ConstantType::I32), 22 => Ok(ConstantType::I32),
23 => Ok(ConstantType::I64), 23 => Ok(ConstantType::I64),
255 => Ok(ConstantType::Void),
_ => Err(InvalidConstantType::Value(value)), _ => Err(InvalidConstantType::Value(value)),
} }
} }

View File

@@ -4,6 +4,7 @@ use crate::syntax::{self, ConstantType};
use crate::type_infer::solve::Constraint; use crate::type_infer::solve::Constraint;
use crate::util::scoped_map::ScopedMap; use crate::util::scoped_map::ScopedMap;
use internment::ArcIntern; use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr; use std::str::FromStr;
/// This function takes a syntactic program and converts it into the IR version of the /// This function takes a syntactic program and converts it into the IR version of the
@@ -19,7 +20,7 @@ pub fn convert_program(
let mut constraint_db = Vec::new(); let mut constraint_db = Vec::new();
let mut items = Vec::new(); let mut items = Vec::new();
let mut renames = ScopedMap::new(); let mut renames = ScopedMap::new();
let mut bindings = ScopedMap::new(); let mut bindings = HashMap::new();
for item in program.items.drain(..) { for item in program.items.drain(..) {
items.push(convert_top_level( items.push(convert_top_level(
@@ -40,43 +41,68 @@ pub fn convert_top_level(
top_level: syntax::TopLevel, top_level: syntax::TopLevel,
constraint_db: &mut Vec<Constraint>, constraint_db: &mut Vec<Constraint>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>, renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut ScopedMap<ArcIntern<String>, ir::TypeOrVar>, bindings: &mut HashMap<ArcIntern<String>, ir::TypeOrVar>,
) -> ir::TopLevel<ir::TypeOrVar> { ) -> ir::TopLevel<ir::TypeOrVar> {
match top_level { match top_level {
syntax::TopLevel::Function(name, args, expr) => { syntax::TopLevel::Function(name, args, expr) => {
// First, let us figure out what we're going to name this function. If the user // First, at some point we're going to want to know a location for this function,
// which should either be the name if we have one, or the body if we don't.
let function_location = match name {
None => expr.location().clone(),
Some(ref name) => name.location.clone(),
};
// Next, let us figure out what we're going to name this function. If the user
// didn't provide one, we'll just call it "function:<something>" for them. (We'll // didn't provide one, we'll just call it "function:<something>" for them. (We'll
// want a name for this function, eventually, so we might as well do it now.) // want a name for this function, eventually, so we might as well do it now.)
// //
// If they did provide a name, see if we're shadowed. IF we are, then we'll have // If they did provide a name, see if we're shadowed. IF we are, then we'll have
// to specialize the name a bit. Otherwise we'll stick with their name. // to specialize the name a bit. Otherwise we'll stick with their name.
let funname = match name { let function_name = match name {
None => ir::gensym("function"), None => ir::gensym("function"),
Some(unbound) => finalize_name(bindings, renames, unbound), Some(unbound) => finalize_name(bindings, renames, unbound),
}; };
// Now we manufacture types for the inputs and outputs, and then a type for the // This function is going to have a type. We don't know what it is, but it'll have
// function itself. We're not going to make any claims on these types, yet; they're // one.
// all just unknown type variables we need to work out. let function_type = ir::TypeOrVar::new();
let argtypes: Vec<ir::TypeOrVar> = args.iter().map(|_| ir::TypeOrVar::new()).collect(); bindings.insert(function_name.clone(), function_type.clone());
// Then, let's figure out what to do with the argument names, which similarly
// may need to be renamed. We'll also generate some new type variables to associate
// with all of them.
//
// Note that we want to do all this in a new renaming scope, so that we shadow
// appropriately.
renames.new_scope();
let arginfo = args
.iter()
.map(|name| {
let new_type = ir::TypeOrVar::new();
constraint_db.push(Constraint::IsSomething(
name.location.clone(),
new_type.clone(),
));
let new_name = finalize_name(bindings, renames, name.clone());
bindings.insert(new_name.clone(), new_type.clone());
(new_name, new_type)
})
.collect::<Vec<_>>();
// Now we manufacture types for the outputs and then a type for the function itself.
// We're not going to make any claims on these types, yet; they're all just unknown
// type variables we need to work out.
let rettype = ir::TypeOrVar::new(); let rettype = ir::TypeOrVar::new();
let funtype = ir::TypeOrVar::Function(argtypes.clone(), Box::new(rettype.clone())); let actual_function_type = ir::TypeOrVar::Function(
arginfo.iter().map(|x| x.1.clone()).collect(),
// Now let's bind these types into the environment. First, we bind our function Box::new(rettype.clone()),
// namae to the function type we just generated. );
bindings.insert(funname.clone(), funtype); constraint_db.push(Constraint::Equivalent(
// And then we attach the argument names to the argument types. (We have to go function_location,
// convert all the names, first.) function_type,
let iargs: Vec<ArcIntern<String>> = actual_function_type,
args.iter().map(|x| ArcIntern::new(x.to_string())).collect(); ));
assert_eq!(argtypes.len(), iargs.len());
let mut function_args = vec![];
for ((arg_name, arg_type), orig_name) in iargs.iter().zip(argtypes).zip(args) {
bindings.insert(arg_name.clone(), arg_type.clone());
function_args.push((arg_name.clone(), arg_type.clone()));
constraint_db.push(Constraint::IsSomething(orig_name.location, arg_type));
}
// Now let's convert the body over to the new IR.
let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings); let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings);
constraint_db.push(Constraint::Equivalent( constraint_db.push(Constraint::Equivalent(
expr.location().clone(), expr.location().clone(),
@@ -84,7 +110,10 @@ pub fn convert_top_level(
ty, ty,
)); ));
ir::TopLevel::Function(funname, function_args, rettype, expr) // Remember to exit this scoping level!
renames.release_scope();
ir::TopLevel::Function(function_name, arginfo, rettype, expr)
} }
syntax::TopLevel::Statement(stmt) => { syntax::TopLevel::Statement(stmt) => {
@@ -108,7 +137,7 @@ fn convert_statement(
statement: syntax::Statement, statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>, constraint_db: &mut Vec<Constraint>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>, renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut ScopedMap<ArcIntern<String>, ir::TypeOrVar>, bindings: &mut HashMap<ArcIntern<String>, ir::TypeOrVar>,
) -> ir::Expression<ir::TypeOrVar> { ) -> ir::Expression<ir::TypeOrVar> {
match statement { match statement {
syntax::Statement::Print(loc, name) => { syntax::Statement::Print(loc, name) => {
@@ -152,7 +181,7 @@ fn convert_expression(
expression: syntax::Expression, expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>, constraint_db: &mut Vec<Constraint>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>, renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut ScopedMap<ArcIntern<String>, ir::TypeOrVar>, bindings: &mut HashMap<ArcIntern<String>, ir::TypeOrVar>,
) -> (ir::Expression<ir::TypeOrVar>, ir::TypeOrVar) { ) -> (ir::Expression<ir::TypeOrVar>, ir::TypeOrVar) {
match expression { match expression {
// converting values is mostly tedious, because there's so many cases // converting values is mostly tedious, because there's so many cases
@@ -339,7 +368,7 @@ fn finalize_expression(
} }
fn finalize_name( fn finalize_name(
bindings: &ScopedMap<ArcIntern<String>, ir::TypeOrVar>, bindings: &HashMap<ArcIntern<String>, ir::TypeOrVar>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>, renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
name: syntax::Name, name: syntax::Name,
) -> ArcIntern<String> { ) -> ArcIntern<String> {

View File

@@ -105,7 +105,7 @@ fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type {
TypeOrVar::Primitive(x) => Type::Primitive(x), TypeOrVar::Primitive(x) => Type::Primitive(x),
TypeOrVar::Variable(_, tvar) => match resolutions.get(&tvar) { TypeOrVar::Variable(_, tvar) => match resolutions.get(&tvar) {
None => panic!("Did not resolve type for type variable {}", tvar), None => panic!("Did not resolve type for type variable {}", tvar),
Some(pt) => Type::Primitive(*pt), Some(pt) => pt.clone(),
}, },
TypeOrVar::Function(mut args, ret) => Type::Function( TypeOrVar::Function(mut args, ret) => Type::Function(
args.drain(..) args.drain(..)

View File

@@ -1,5 +1,5 @@
use crate::eval::PrimitiveType; use crate::eval::PrimitiveType;
use crate::ir::{Primitive, TypeOrVar}; use crate::ir::{Primitive, Type, TypeOrVar};
use crate::syntax::Location; use crate::syntax::Location;
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use internment::ArcIntern; use internment::ArcIntern;
@@ -50,7 +50,7 @@ impl fmt::Display for Constraint {
} }
} }
pub type TypeResolutions = HashMap<ArcIntern<String>, PrimitiveType>; pub type TypeResolutions = HashMap<ArcIntern<String>, Type>;
/// The results of type inference; like [`Result`], but with a bit more information. /// The results of type inference; like [`Result`], but with a bit more information.
/// ///
@@ -257,18 +257,21 @@ pub fn solve_constraints(
) -> TypeInferenceResult<TypeResolutions> { ) -> TypeInferenceResult<TypeResolutions> {
let mut errors = vec![]; let mut errors = vec![];
let mut warnings = vec![]; let mut warnings = vec![];
let mut resolutions = HashMap::new(); let mut resolutions: HashMap<ArcIntern<String>, Type> = HashMap::new();
let mut changed_something = true; let mut changed_something = true;
println!("CONSTRAINTS:");
for constraint in constraint_db.iter() {
println!("{}", constraint);
}
// We want to run this inference endlessly, until either we have solved all of our // We want to run this inference endlessly, until either we have solved all of our
// constraints. Internal to the loop, we have a check that will make sure that we // constraints. Internal to the loop, we have a check that will make sure that we
// do (eventually) stop. // do (eventually) stop.
while changed_something && !constraint_db.is_empty() { while changed_something && !constraint_db.is_empty() {
println!("CONSTRAINT:");
for constraint in constraint_db.iter() {
println!(" {}", constraint);
}
println!("RESOLUTIONS:");
for (name, ty) in resolutions.iter() {
println!(" {} = {}", name, ty);
}
// Set this to false at the top of the loop. We'll set this to true if we make // Set this to false at the top of the loop. We'll set this to true if we make
// progress in any way further down, but having this here prevents us from going // progress in any way further down, but having this here prevents us from going
// into an infinite look when we can't figure stuff out. // into an infinite look when we can't figure stuff out.
@@ -292,9 +295,13 @@ pub fn solve_constraints(
Constraint::IsSomething(_, TypeOrVar::Function(_, _)) Constraint::IsSomething(_, TypeOrVar::Function(_, _))
| Constraint::IsSomething(_, TypeOrVar::Primitive(_)) => changed_something = true, | Constraint::IsSomething(_, TypeOrVar::Primitive(_)) => changed_something = true,
// Otherwise, we'll keep looking for it. // Otherwise, see if we've resolved this variable to anything. If not, add it
Constraint::IsSomething(_, TypeOrVar::Variable(_, _)) => { // back.
constraint_db.push(constraint); Constraint::IsSomething(_, TypeOrVar::Variable(_, ref name)) => {
if resolutions.get(name).is_none() {
constraint_db.push(constraint);
}
changed_something = true;
} }
// Case #1a: We have two primitive types. If they're equal, we've discharged this // Case #1a: We have two primitive types. If they're equal, we've discharged this
@@ -311,35 +318,7 @@ pub fn solve_constraints(
changed_something = true; changed_something = true;
} }
// Case #2: One of the two constraints is a primitive, and the other is a variable. // Case #2: They're both variables. In which case, we'll have to do much the same
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(
loc,
TypeOrVar::Primitive(t),
TypeOrVar::Variable(_, name),
)
| Constraint::Equivalent(
loc,
TypeOrVar::Variable(_, name),
TypeOrVar::Primitive(t),
) => {
match resolutions.get(&name) {
None => {
resolutions.insert(name, t);
}
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(
loc,
TypeOrVar::Primitive(t),
TypeOrVar::Primitive(*t2),
)),
}
changed_something = true;
}
// Case #3: They're both variables. In which case, we'll have to do much the same
// check, but now on their resolutions. // check, but now on their resolutions.
Constraint::Equivalent( Constraint::Equivalent(
ref loc, ref loc,
@@ -350,11 +329,11 @@ pub fn solve_constraints(
constraint_db.push(constraint); constraint_db.push(constraint);
} }
(Some(pt), None) => { (Some(pt), None) => {
resolutions.insert(name2.clone(), *pt); resolutions.insert(name2.clone(), pt.clone());
changed_something = true; changed_something = true;
} }
(None, Some(pt)) => { (None, Some(pt)) => {
resolutions.insert(name1.clone(), *pt); resolutions.insert(name1.clone(), pt.clone());
changed_something = true; changed_something = true;
} }
(Some(pt1), Some(pt2)) if pt1 == pt2 => { (Some(pt1), Some(pt2)) if pt1 == pt2 => {
@@ -363,13 +342,43 @@ pub fn solve_constraints(
(Some(pt1), Some(pt2)) => { (Some(pt1), Some(pt2)) => {
errors.push(TypeInferenceError::NotEquivalent( errors.push(TypeInferenceError::NotEquivalent(
loc.clone(), loc.clone(),
TypeOrVar::Primitive(*pt1), pt1.clone().into(),
TypeOrVar::Primitive(*pt2), pt2.clone().into(),
)); ));
changed_something = true; changed_something = true;
} }
}, },
// Case #3: One of the two constraints is a primitive, and the other is a variable.
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(loc, t, TypeOrVar::Variable(vloc, name))
| Constraint::Equivalent(loc, TypeOrVar::Variable(vloc, name), t) => {
match resolutions.get(&name) {
None => match t.try_into() {
Ok(real_type) => {
resolutions.insert(name, real_type);
}
Err(variable_type) => {
constraint_db.push(Constraint::Equivalent(
loc,
variable_type,
TypeOrVar::Variable(vloc, name),
));
continue;
}
},
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(
loc,
t,
t2.clone().into(),
)),
}
changed_something = true;
}
// Case #4: Like primitives, but for function types. This is a little complicated, because // Case #4: Like primitives, but for function types. This is a little complicated, because
// we first want to resolve all the type variables in the two types, and then see if they're // we first want to resolve all the type variables in the two types, and then see if they're
// equivalent. Fortunately, though, we can cheat a bit. What we're going to do is first see // equivalent. Fortunately, though, we can cheat a bit. What we're going to do is first see
@@ -445,7 +454,7 @@ pub fn solve_constraints(
Some(nt) => { Some(nt) => {
constraint_db.push(Constraint::FitsInNumType( constraint_db.push(Constraint::FitsInNumType(
loc, loc,
TypeOrVar::Primitive(*nt), nt.clone().into(),
val, val,
)); ));
changed_something = true; changed_something = true;
@@ -474,7 +483,7 @@ pub fn solve_constraints(
Some(nt) => { Some(nt) => {
constraint_db.push(Constraint::CanCastTo( constraint_db.push(Constraint::CanCastTo(
loc, loc,
TypeOrVar::Primitive(*nt), nt.clone().into(),
to_type, to_type,
)); ));
changed_something = true; changed_something = true;
@@ -494,7 +503,7 @@ pub fn solve_constraints(
constraint_db.push(Constraint::CanCastTo( constraint_db.push(Constraint::CanCastTo(
loc, loc,
from_type, from_type,
TypeOrVar::Primitive(*nt), nt.clone().into(),
)); ));
changed_something = true; changed_something = true;
} }
@@ -560,8 +569,7 @@ pub fn solve_constraints(
None => constraint_db None => constraint_db
.push(Constraint::NumericType(loc, TypeOrVar::Variable(vloc, var))), .push(Constraint::NumericType(loc, TypeOrVar::Variable(vloc, var))),
Some(nt) => { Some(nt) => {
constraint_db constraint_db.push(Constraint::NumericType(loc, nt.clone().into()));
.push(Constraint::NumericType(loc, TypeOrVar::Primitive(*nt)));
changed_something = true; changed_something = true;
} }
} }
@@ -592,10 +600,8 @@ pub fn solve_constraints(
TypeOrVar::Variable(vloc, var), TypeOrVar::Variable(vloc, var),
)), )),
Some(nt) => { Some(nt) => {
constraint_db.push(Constraint::ConstantNumericType( constraint_db
loc, .push(Constraint::ConstantNumericType(loc, nt.clone().into()));
TypeOrVar::Primitive(*nt),
));
changed_something = true; changed_something = true;
} }
} }