From 52d5c9252be35b021f2ecc3dedd71b8de423ceef Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Mon, 22 Apr 2024 20:49:44 -0700 Subject: [PATCH] checkpoint --- examples/basic/function0004.ngr | 10 + src/ir/ast.rs | 4 +- src/syntax/ast.rs | 20 +- src/syntax/eval.rs | 35 +- src/syntax/parser.lalrpop | 13 +- src/syntax/pretty.rs | 82 +- src/syntax/validate.rs | 27 +- src/type_infer.rs | 62 +- src/type_infer/constraint.rs | 79 ++ src/type_infer/convert.rs | 945 +++++++++++----------- src/type_infer/error.rs | 146 ++++ src/type_infer/finalize.rs | 86 +- src/type_infer/result.rs | 50 ++ src/type_infer/solve.rs | 1306 ++++++++++++------------------- src/type_infer/warning.rs | 21 + 15 files changed, 1528 insertions(+), 1358 deletions(-) create mode 100644 examples/basic/function0004.ngr create mode 100644 src/type_infer/constraint.rs create mode 100644 src/type_infer/error.rs create mode 100644 src/type_infer/result.rs create mode 100644 src/type_infer/warning.rs diff --git a/examples/basic/function0004.ngr b/examples/basic/function0004.ngr new file mode 100644 index 0000000..6803641 --- /dev/null +++ b/examples/basic/function0004.ngr @@ -0,0 +1,10 @@ +function make_adder(x) + function (y) + x + y; + +add1 = make_adder(1); +add2 = make_adder(2); +one_plus_one = add1(1); +one_plus_three = add1(3); +print one_plus_one; +print one_plus_three; \ No newline at end of file diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 24d7c41..bb88f4a 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -74,6 +74,8 @@ impl Arbitrary for Program { #[derive(Clone, Debug)] pub enum TopLevel { Statement(Expression), + // FIXME: Is the return type actually necessary, given we can infer it from + // the expression type? Function(Variable, Vec<(Variable, Type)>, Type, Expression), } @@ -500,4 +502,4 @@ fn struct_sizes_are_rational() { assert_eq!(272, std::mem::size_of::>()); assert_eq!(72, std::mem::size_of::>()); assert_eq!(72, std::mem::size_of::>()); -} \ No newline at end of file +} diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index 726ad5a..e41d443 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -27,12 +27,6 @@ pub struct Program { #[derive(Clone, Debug, PartialEq)] pub enum TopLevel { Expression(Expression), - Function( - Option, - Vec<(Name, Option)>, - Option, - Expression, - ), Structure(Location, Name, Vec<(Name, Type)>), } @@ -103,6 +97,13 @@ pub enum Expression { Call(Location, Box, Vec), Block(Location, Vec), Binding(Location, Name, Box), + Function( + Location, + Option, + Vec<(Name, Option)>, + Option, + Box, + ), } impl Expression { @@ -157,6 +158,12 @@ impl PartialEq for Expression { Expression::Binding(_, name2, expr2) => name1 == name2 && expr1 == expr2, _ => false, }, + Expression::Function(_, mname1, args1, mret1, body1) => match other { + Expression::Function(_, mname2, args2, mret2, body2) => { + mname1 == mname2 && args1 == args2 && mret1 == mret2 && body1 == body2 + } + _ => false, + }, } } } @@ -174,6 +181,7 @@ impl Expression { Expression::Call(loc, _, _) => loc, Expression::Block(loc, _) => loc, Expression::Binding(loc, _, _) => loc, + Expression::Function(loc, _, _, _, _) => loc, } } } diff --git a/src/syntax/eval.rs b/src/syntax/eval.rs index d6185ec..ae45ed8 100644 --- a/src/syntax/eval.rs +++ b/src/syntax/eval.rs @@ -24,22 +24,6 @@ impl Program { for stmt in self.items.iter() { match stmt { - TopLevel::Function(name, arg_names, _, body) => { - last_result = Value::Closure( - name.clone().map(Name::intern), - env.clone(), - arg_names - .iter() - .cloned() - .map(|(x, _)| Name::intern(x)) - .collect(), - body.clone(), - ); - if let Some(name) = name { - env.insert(name.clone().intern(), last_result.clone()); - } - } - TopLevel::Expression(expr) => last_result = expr.eval(&mut stdout, &mut env)?, TopLevel::Structure(_, _, _) => { @@ -191,6 +175,25 @@ impl Expression { env.insert(name.clone().intern(), actual_value.clone()); Ok(actual_value) } + + Expression::Function(_, name, arg_names, _, body) => { + let result = Value::Closure( + name.clone().map(Name::intern), + env.clone(), + arg_names + .iter() + .cloned() + .map(|(x, _)| Name::intern(x)) + .collect(), + *body.clone(), + ); + + if let Some(name) = name { + env.insert(name.clone().intern(), result.clone()); + } + + Ok(result) + } } } } diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index ed42ca4..4522807 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -79,16 +79,10 @@ ProgramTopLevel: Vec = { } pub TopLevel: TopLevel = { - => f, => s, ";" => TopLevel::Expression(s), } -Function: TopLevel = { - "function" "(" > ")" " Type)?> ";" => - TopLevel::Function(opt_name, args, ret.map(|x| x.1), exp), -} - Argument: (Name, Option) = { "> => (Name::new(v, Location::new(file_idx, name_start..name_end)), t.map(|v| v.1)), @@ -170,6 +164,13 @@ BindingExpression: Expression = { Box::new(e), ), + FunctionExpression, +} + +FunctionExpression: Expression = { + "function" "(" > ")" " Type)?> => + Expression::Function(Location::new(file_idx, s..e), opt_name, args, ret.map(|x| x.1), Box::new(exp)), + PrintExpression, } diff --git a/src/syntax/pretty.rs b/src/syntax/pretty.rs index 81eabe2..a9caf3b 100644 --- a/src/syntax/pretty.rs +++ b/src/syntax/pretty.rs @@ -21,47 +21,6 @@ impl TopLevel { pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> { match self { TopLevel::Expression(expr) => expr.pretty(allocator), - TopLevel::Function(name, args, rettype, body) => allocator - .text("function") - .append(allocator.space()) - .append( - name.as_ref() - .map(|x| allocator.text(x.to_string())) - .unwrap_or_else(|| allocator.nil()), - ) - .append( - allocator - .intersperse( - args.iter().map(|(x, t)| { - allocator.text(x.to_string()).append( - t.as_ref() - .map(|t| { - allocator - .text(":") - .append(allocator.space()) - .append(t.pretty(allocator)) - }) - .unwrap_or_else(|| allocator.nil()), - ) - }), - allocator.text(","), - ) - .parens(), - ) - .append( - rettype - .as_ref() - .map(|rettype| { - allocator - .space() - .append(allocator.text("->")) - .append(allocator.space()) - .append(rettype.pretty(allocator)) - }) - .unwrap_or_else(|| allocator.nil()), - ) - .append(allocator.space()) - .append(body.pretty(allocator)), TopLevel::Structure(_, name, fields) => allocator .text("struct") .append(allocator.space()) @@ -148,6 +107,47 @@ impl Expression { .append(allocator.text("=")) .append(allocator.space()) .append(expr.pretty(allocator)), + Expression::Function(_, name, args, rettype, body) => allocator + .text("function") + .append(allocator.space()) + .append( + name.as_ref() + .map(|x| allocator.text(x.to_string())) + .unwrap_or_else(|| allocator.nil()), + ) + .append( + allocator + .intersperse( + args.iter().map(|(x, t)| { + allocator.text(x.to_string()).append( + t.as_ref() + .map(|t| { + allocator + .text(":") + .append(allocator.space()) + .append(t.pretty(allocator)) + }) + .unwrap_or_else(|| allocator.nil()), + ) + }), + allocator.text(","), + ) + .parens(), + ) + .append( + rettype + .as_ref() + .map(|rettype| { + allocator + .space() + .append(allocator.text("->")) + .append(allocator.space()) + .append(rettype.pretty(allocator)) + }) + .unwrap_or_else(|| allocator.nil()), + ) + .append(allocator.space()) + .append(body.pretty(allocator)), } } } diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index c95b2bc..afbba2c 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -84,9 +84,6 @@ impl Program { let mut warnings = vec![]; for stmt in self.items.iter() { - if let TopLevel::Function(Some(name), _, _, _) = stmt { - bound_variables.insert(name.to_string(), name.location.clone()); - } let (mut new_errors, mut new_warnings) = stmt.validate_with_bindings(bound_variables); errors.append(&mut new_errors); warnings.append(&mut new_warnings); @@ -119,18 +116,6 @@ impl TopLevel { bound_variables: &mut ScopedMap, ) -> (Vec, Vec) { match self { - TopLevel::Function(name, arguments, _, body) => { - bound_variables.new_scope(); - if let Some(name) = name { - bound_variables.insert(name.name.clone(), name.location.clone()); - } - for (arg, _) in arguments.iter() { - bound_variables.insert(arg.name.clone(), arg.location.clone()); - } - let result = body.validate(bound_variables); - bound_variables.release_scope(); - result - } TopLevel::Expression(expr) => expr.validate(bound_variables), TopLevel::Structure(_, _, _) => (vec![], vec![]), } @@ -214,6 +199,18 @@ impl Expression { (errors, warnings) } + Expression::Function(_, name, arguments, _, body) => { + if let Some(name) = name { + variable_map.insert(name.name.clone(), name.location.clone()); + } + variable_map.new_scope(); + for (arg, _) in arguments.iter() { + variable_map.insert(arg.name.clone(), arg.location.clone()); + } + let result = body.validate(variable_map); + variable_map.release_scope(); + result + } } } } diff --git a/src/type_infer.rs b/src/type_infer.rs index 475bf9f..94bb3cb 100644 --- a/src/type_infer.rs +++ b/src/type_infer.rs @@ -10,20 +10,43 @@ //! all the constraints we've generated. If that's successful, in the final phase, we //! do the final conversion to the IR AST, filling in any type information we've learned //! along the way. +mod constraint; mod convert; +mod error; mod finalize; +mod result; mod solve; +mod warning; -use self::convert::convert_program; -use self::finalize::finalize_program; -use self::solve::solve_constraints; -pub use self::solve::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning}; +use self::constraint::Constraint; +use self::error::TypeInferenceError; +pub use self::result::TypeInferenceResult; +use self::warning::TypeInferenceWarning; use crate::ir::ast as ir; use crate::syntax; #[cfg(test)] use crate::syntax::arbitrary::GenerationEnvironment; +use internment::ArcIntern; #[cfg(test)] use proptest::prelude::Arbitrary; +use std::collections::HashMap; + +#[derive(Default)] +struct InferenceEngine { + constraints: Vec, + type_definitions: HashMap, ir::TypeOrVar>, + variable_types: HashMap, ir::TypeOrVar>, + functions: HashMap< + ArcIntern, + ( + Vec<(ArcIntern, ir::TypeOrVar)>, + ir::Expression, + ), + >, + statements: Vec>, + errors: Vec, + warnings: Vec, +} impl syntax::Program { /// Infer the types for the syntactic AST, returning either a type-checked program in @@ -32,10 +55,35 @@ impl syntax::Program { /// You really should have made sure that this program was validated before running /// this method, otherwise you may experience panics during operation. pub fn type_infer(self) -> TypeInferenceResult> { - let (program, constraint_db) = convert_program(self); - let inference_result = solve_constraints(&program.type_definitions, constraint_db); + let mut engine = InferenceEngine::default(); + engine.injest_program(self); + engine.solve_constraints(); - inference_result.map(|resolutions| finalize_program(program, &resolutions)) + if engine.errors.is_empty() { + let resolutions = std::mem::take(&mut engine.constraints) + .into_iter() + .map(|constraint| match constraint { + Constraint::Equivalent(_, ir::TypeOrVar::Variable(_, name), result) => { + match result.try_into() { + Err(e) => panic!("Ended up with complex type {}", e), + Ok(v) => (name, v), + } + } + _ => panic!("Had something that wasn't an equivalence left at the end!"), + }) + .collect(); + let warnings = std::mem::take(&mut engine.warnings); + + TypeInferenceResult::Success { + result: engine.finalize_program(resolutions), + warnings, + } + } else { + TypeInferenceResult::Failure { + errors: engine.errors, + warnings: engine.warnings, + } + } } } diff --git a/src/type_infer/constraint.rs b/src/type_infer/constraint.rs new file mode 100644 index 0000000..e6520fb --- /dev/null +++ b/src/type_infer/constraint.rs @@ -0,0 +1,79 @@ +use crate::ir::TypeOrVar; +use crate::syntax::Location; +use internment::ArcIntern; +use std::fmt; + +/// A type inference constraint that we're going to need to solve. +#[derive(Debug)] +pub enum Constraint { + /// The given type must be printable using the `print` built-in + Printable(Location, TypeOrVar), + /// The provided numeric value fits in the given constant type + FitsInNumType(Location, TypeOrVar, u64), + /// The given type can be casted to the target type safely + CanCastTo(Location, TypeOrVar, TypeOrVar), + /// The given type has the given field in it, and the type of that field + /// is as given. + TypeHasField(Location, TypeOrVar, ArcIntern, TypeOrVar), + /// The given type must be some numeric type, but this is not a constant + /// value, so don't try to default it if we can't figure it out + NumericType(Location, TypeOrVar), + /// The given type is attached to a constant and must be some numeric type. + /// If we can't figure it out, we should warn the user and then just use a + /// default. + ConstantNumericType(Location, TypeOrVar), + /// The two types should be equivalent + Equivalent(Location, TypeOrVar, TypeOrVar), + /// The given type can be resolved to something + IsSomething(Location, TypeOrVar), + /// The given type can be negated + IsSigned(Location, TypeOrVar), + /// Checks to see if the given named type is equivalent to the provided one. + NamedTypeIs(Location, ArcIntern, TypeOrVar), +} + +impl fmt::Display for Constraint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty), + Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty), + Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2), + Constraint::TypeHasField(_, ty1, field, ty2) => { + write!(f, "FIELD {}.{} -> {}", ty1, field, ty2) + } + Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty), + Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty), + Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2), + Constraint::IsSomething(_, ty) => write!(f, "SOMETHING {}", ty), + Constraint::IsSigned(_, ty) => write!(f, "SIGNED {}", ty), + Constraint::NamedTypeIs(_, name, ty) => write!(f, "TYPE_EQUIV {} == {}", name, ty), + } + } +} + +impl Constraint { + /// Replace all instances of the name (anywhere! including on the left hand side of equivalences!) + /// with the given type. + /// + /// Returns whether or not anything was changed in the constraint. + pub fn replace(&mut self, name: &ArcIntern, replace_with: &TypeOrVar) -> bool { + match self { + Constraint::Printable(_, ty) => ty.replace(name, replace_with), + Constraint::FitsInNumType(_, ty, _) => ty.replace(name, replace_with), + Constraint::CanCastTo(_, ty1, ty2) => { + ty1.replace(name, replace_with) || ty2.replace(name, replace_with) + } + Constraint::TypeHasField(_, ty1, _, ty2) => { + ty1.replace(name, replace_with) || ty2.replace(name, replace_with) + } + Constraint::ConstantNumericType(_, ty) => ty.replace(name, replace_with), + Constraint::Equivalent(_, ty1, ty2) => { + ty1.replace(name, replace_with) || ty2.replace(name, replace_with) + } + Constraint::IsSigned(_, ty) => ty.replace(name, replace_with), + Constraint::IsSomething(_, ty) => ty.replace(name, replace_with), + Constraint::NumericType(_, ty) => ty.replace(name, replace_with), + Constraint::NamedTypeIs(_, name, ty) => ty.replace(name, replace_with), + } + } +} diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index a8fe5c0..8266eee 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -1,472 +1,541 @@ +use super::constraint::Constraint; +use super::InferenceEngine; use crate::eval::PrimitiveType; use crate::ir; use crate::syntax::{self, ConstantType}; -use crate::type_infer::solve::Constraint; use crate::util::scoped_map::ScopedMap; use internment::ArcIntern; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::str::FromStr; -enum TopLevelItem { - Type(ArcIntern, ir::TypeOrVar), - Value(ir::TopLevel), +struct ExpressionInfo { + expression: ir::Expression, + result_type: ir::TypeOrVar, + free_variables: HashSet>, + bound_variables: HashSet>, } -/// This function takes a syntactic program and converts it into the IR version of the -/// program, with appropriate type variables introduced and their constraints added to -/// the given database. -/// -/// If the input function has been validated (which it should be), then this should run -/// into no error conditions. However, if you failed to validate the input, then this -/// function can panic. -pub fn convert_program( - mut program: syntax::Program, -) -> (ir::Program, Vec) { - let mut constraint_db = Vec::new(); - let mut items = Vec::new(); - let mut renames = ScopedMap::new(); - let mut bindings = HashMap::new(); - let mut type_definitions = HashMap::new(); +impl ExpressionInfo { + fn simple(expression: ir::Expression, result_type: ir::TypeOrVar) -> Self { + ExpressionInfo { + expression, + result_type, + free_variables: HashSet::new(), + bound_variables: HashSet::new(), + } + } +} - for item in program.items.drain(..) { - let tli = convert_top_level(item, &mut constraint_db, &mut renames, &mut bindings); +impl InferenceEngine { + /// This function takes a syntactic program and converts it into the IR version of the + /// program, with appropriate type variables introduced and their constraints added to + /// the given database. + /// + /// If the input function has been validated (which it should be), then this should run + /// into no error conditions. However, if you failed to validate the input, then this + /// function can panic. + pub fn injest_program(&mut self, program: syntax::Program) { + let mut renames = ScopedMap::new(); - match tli { - TopLevelItem::Value(item) => items.push(item), - TopLevelItem::Type(name, decl) => { - let _ = type_definitions.insert(name, decl); + for item in program.items.into_iter() { + self.convert_top_level(item, &mut renames); + } + } + + /// This function takes a top-level item and converts it into the IR version of the + /// program, with all the appropriate type variables introduced and their constraints + /// added to the given database. + fn convert_top_level( + &mut self, + top_level: syntax::TopLevel, + renames: &mut ScopedMap, ArcIntern>, + ) { + match top_level { + syntax::TopLevel::Expression(expr) => { + let expr_info = self.convert_expression(expr, renames); + self.statements.push(expr_info.expression); + } + + syntax::TopLevel::Structure(_loc, name, fields) => { + let mut updated_fields = ir::Fields::default(); + + for (name, field_type) in fields.into_iter() { + updated_fields.insert(name.intern(), self.convert_type(field_type)); + } + + self.type_definitions + .insert(name.intern(), ir::TypeOrVar::Structure(updated_fields)); } } } - ( - ir::Program { - items, - type_definitions, - }, - constraint_db, - ) -} + /// This function takes a syntactic expression and converts it into a series + /// of IR statements, adding type variables and constraints as necessary. + /// + /// We generate a series of statements because we're going to flatten all + /// incoming expressions so that they are no longer recursive. This will + /// generate a bunch of new bindings for all the subexpressions, which we + /// return as a bundle. + /// + /// See the safety warning on [`convert_program`]! This function assumes that + /// you have run [`Statement::validate`], and will trigger panics in error + /// conditions if you have run that and had it come back clean. + fn convert_expression( + &mut self, + expression: syntax::Expression, + renames: &mut ScopedMap, ArcIntern>, + ) -> ExpressionInfo { + match expression { + // converting values is mostly tedious, because there's so many cases + // involved + syntax::Expression::Value(loc, val) => match val { + syntax::Value::Number(base, mctype, value) => { + let (newval, newtype) = match mctype { + None => { + let newtype = ir::TypeOrVar::new(); + let newval = ir::Value::U64(base, value); -/// This function takes a top-level item and converts it into the IR version of the -/// program, with all the appropriate type variables introduced and their constraints -/// added to the given database. -fn convert_top_level( - top_level: syntax::TopLevel, - constraint_db: &mut Vec, - renames: &mut ScopedMap, ArcIntern>, - bindings: &mut HashMap, ir::TypeOrVar>, -) -> TopLevelItem { - match top_level { - syntax::TopLevel::Function(name, args, _, expr) => { - // First, at some point we're going to want to know a location for this function, - // which should either be the name if we have one, or the body if we don't. - let function_location = match name { - None => expr.location().clone(), - Some(ref name) => name.location.clone(), - }; - // Next, let us figure out what we're going to name this function. If the user - // didn't provide one, we'll just call it "function:" for them. (We'll - // want a name for this function, eventually, so we might as well do it now.) - // - // If they did provide a name, see if we're shadowed. IF we are, then we'll have - // to specialize the name a bit. Otherwise we'll stick with their name. - let function_name = match name { - None => ir::gensym("function"), - Some(unbound) => finalize_name(bindings, renames, unbound), - }; + self.constraints.push(Constraint::ConstantNumericType( + loc.clone(), + newtype.clone(), + )); + (newval, newtype) + } + Some(ConstantType::Void) => ( + ir::Value::Void, + ir::TypeOrVar::Primitive(PrimitiveType::Void), + ), + Some(ConstantType::U8) => ( + ir::Value::U8(base, value as u8), + ir::TypeOrVar::Primitive(PrimitiveType::U8), + ), + Some(ConstantType::U16) => ( + ir::Value::U16(base, value as u16), + ir::TypeOrVar::Primitive(PrimitiveType::U16), + ), + Some(ConstantType::U32) => ( + ir::Value::U32(base, value as u32), + ir::TypeOrVar::Primitive(PrimitiveType::U32), + ), + Some(ConstantType::U64) => ( + ir::Value::U64(base, value), + ir::TypeOrVar::Primitive(PrimitiveType::U64), + ), + Some(ConstantType::I8) => ( + ir::Value::I8(base, value as i8), + ir::TypeOrVar::Primitive(PrimitiveType::I8), + ), + Some(ConstantType::I16) => ( + ir::Value::I16(base, value as i16), + ir::TypeOrVar::Primitive(PrimitiveType::I16), + ), + Some(ConstantType::I32) => ( + ir::Value::I32(base, value as i32), + ir::TypeOrVar::Primitive(PrimitiveType::I32), + ), + Some(ConstantType::I64) => ( + ir::Value::I64(base, value as i64), + ir::TypeOrVar::Primitive(PrimitiveType::I64), + ), + }; - // This function is going to have a type. We don't know what it is, but it'll have - // one. - let function_type = ir::TypeOrVar::new(); - bindings.insert(function_name.clone(), function_type.clone()); - - // Then, let's figure out what to do with the argument names, which similarly - // may need to be renamed. We'll also generate some new type variables to associate - // with all of them. - // - // Note that we want to do all this in a new renaming scope, so that we shadow - // appropriately. - renames.new_scope(); - let arginfo = args - .into_iter() - .map(|(name, mut declared_type)| { - let new_type = ir::TypeOrVar::new(); - constraint_db.push(Constraint::IsSomething( - name.location.clone(), - new_type.clone(), + self.constraints.push(Constraint::FitsInNumType( + loc.clone(), + newtype.clone(), + value, )); - let new_name = finalize_name(bindings, renames, name.clone()); - bindings.insert(new_name.clone(), new_type.clone()); - if let Some(declared_type) = declared_type.take() { - let declared_type = convert_type(declared_type, constraint_db); - constraint_db.push(Constraint::Equivalent( - name.location.clone(), - new_type.clone(), - declared_type, - )); - } + ExpressionInfo::simple( + ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)), + newtype, + ) + } + }, - (new_name, new_type) - }) - .collect::>(); + syntax::Expression::Constructor(loc, name, fields) => { + let mut result_fields = HashMap::new(); + let mut type_fields = ir::Fields::default(); + let mut prereqs = vec![]; + let mut free_variables = HashSet::new(); + let mut bound_variables = HashSet::new(); - // Now we manufacture types for the outputs and then a type for the function itself. - // We're not going to make any claims on these types, yet; they're all just unknown - // type variables we need to work out. - let rettype = ir::TypeOrVar::new(); - let actual_function_type = ir::TypeOrVar::Function( - arginfo.iter().map(|x| x.1.clone()).collect(), - Box::new(rettype.clone()), - ); - constraint_db.push(Constraint::Equivalent( - function_location, - function_type, - actual_function_type, - )); + for (name, syntax_expr) in fields.into_iter() { + let field_expr_info = self.convert_expression(syntax_expr, renames); + type_fields.insert(name.clone().intern(), field_expr_info.result_type); + let (prereq, value) = simplify_expr(field_expr_info.expression); + result_fields.insert(name.clone().intern(), value); + merge_prereq(&mut prereqs, prereq); + free_variables.extend(field_expr_info.free_variables); + bound_variables.extend(field_expr_info.bound_variables); + } - // Now let's convert the body over to the new IR. - let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings); - constraint_db.push(Constraint::Equivalent( - expr.location().clone(), - rettype.clone(), - ty, - )); + let result_type = ir::TypeOrVar::Structure(type_fields); - // Remember to exit this scoping level! - renames.release_scope(); + self.constraints.push(Constraint::NamedTypeIs( + loc.clone(), + name.clone().intern(), + result_type.clone(), + )); + let expression = ir::Expression::Construct( + loc, + result_type.clone(), + name.intern(), + result_fields, + ); - TopLevelItem::Value(ir::TopLevel::Function( - function_name, - arginfo, - rettype, - expr, - )) - } - - syntax::TopLevel::Expression(expr) => TopLevelItem::Value(ir::TopLevel::Statement( - convert_expression(expr, constraint_db, renames, bindings).0, - )), - - syntax::TopLevel::Structure(_loc, name, fields) => { - let mut updated_fields = ir::Fields::default(); - - for (name, field_type) in fields.into_iter() { - updated_fields.insert(name.intern(), convert_type(field_type, constraint_db)); + ExpressionInfo { + expression, + result_type, + free_variables, + bound_variables, + } } - TopLevelItem::Type(name.intern(), ir::TypeOrVar::Structure(updated_fields)) - } - } -} + syntax::Expression::Reference(loc, name) => { + let iname = ArcIntern::new(name); + let final_name = renames.get(&iname).cloned().unwrap_or(iname); + let result_type = self + .variable_types + .get(&final_name) + .cloned() + .expect("variable bound before use"); + let expression = ir::Expression::Atomic(ir::ValueOrRef::Ref( + loc, + result_type.clone(), + final_name.clone(), + )); + let free_variables = HashSet::from([final_name]); -/// This function takes a syntactic expression and converts it into a series -/// of IR statements, adding type variables and constraints as necessary. -/// -/// We generate a series of statements because we're going to flatten all -/// incoming expressions so that they are no longer recursive. This will -/// generate a bunch of new bindings for all the subexpressions, which we -/// return as a bundle. -/// -/// See the safety warning on [`convert_program`]! This function assumes that -/// you have run [`Statement::validate`], and will trigger panics in error -/// conditions if you have run that and had it come back clean. -fn convert_expression( - expression: syntax::Expression, - constraint_db: &mut Vec, - renames: &mut ScopedMap, ArcIntern>, - bindings: &mut HashMap, ir::TypeOrVar>, -) -> (ir::Expression, ir::TypeOrVar) { - match expression { - // converting values is mostly tedious, because there's so many cases - // involved - syntax::Expression::Value(loc, val) => match val { - syntax::Value::Number(base, mctype, value) => { - let (newval, newtype) = match mctype { - None => { - let newtype = ir::TypeOrVar::new(); - let newval = ir::Value::U64(base, value); + ExpressionInfo { + expression, + result_type, + free_variables, + bound_variables: HashSet::new(), + } + } - constraint_db.push(Constraint::ConstantNumericType( - loc.clone(), - newtype.clone(), - )); - (newval, newtype) + syntax::Expression::FieldRef(loc, expr, field) => { + let mut expr_info = self.convert_expression(*expr, renames); + let (prereqs, val_or_ref) = simplify_expr(expr_info.expression); + let result_type = ir::TypeOrVar::new(); + let result = ir::Expression::FieldRef( + loc.clone(), + result_type.clone(), + expr_info.result_type.clone(), + val_or_ref, + field.clone().intern(), + ); + + self.constraints.push(Constraint::TypeHasField( + loc, + expr_info.result_type.clone(), + field.intern(), + result_type.clone(), + )); + + expr_info.expression = finalize_expression(prereqs, result); + expr_info.result_type = result_type; + + expr_info + } + + syntax::Expression::Cast(loc, target, expr) => { + let mut expr_info = self.convert_expression(*expr, renames); + let (prereqs, val_or_ref) = simplify_expr(expr_info.expression); + let target_type: ir::TypeOrVar = PrimitiveType::from_str(&target) + .expect("valid type for cast") + .into(); + let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); + + self.constraints.push(Constraint::CanCastTo( + loc, + expr_info.result_type.clone(), + target_type.clone(), + )); + + expr_info.expression = finalize_expression(prereqs, res); + expr_info.result_type = target_type; + + expr_info + } + + syntax::Expression::Primitive(loc, name) => { + let primop = ir::Primitive::from_str(&name.name).expect("valid primitive"); + + match primop { + ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => { + let numeric_type = ir::TypeOrVar::new_located(loc.clone()); + self.constraints + .push(Constraint::NumericType(loc.clone(), numeric_type.clone())); + let funtype = ir::TypeOrVar::Function( + vec![numeric_type.clone(), numeric_type.clone()], + Box::new(numeric_type.clone()), + ); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) } - Some(ConstantType::Void) => ( - ir::Value::Void, - ir::TypeOrVar::Primitive(PrimitiveType::Void), - ), - Some(ConstantType::U8) => ( - ir::Value::U8(base, value as u8), - ir::TypeOrVar::Primitive(PrimitiveType::U8), - ), - Some(ConstantType::U16) => ( - ir::Value::U16(base, value as u16), - ir::TypeOrVar::Primitive(PrimitiveType::U16), - ), - Some(ConstantType::U32) => ( - ir::Value::U32(base, value as u32), - ir::TypeOrVar::Primitive(PrimitiveType::U32), - ), - Some(ConstantType::U64) => ( - ir::Value::U64(base, value), - ir::TypeOrVar::Primitive(PrimitiveType::U64), - ), - Some(ConstantType::I8) => ( - ir::Value::I8(base, value as i8), - ir::TypeOrVar::Primitive(PrimitiveType::I8), - ), - Some(ConstantType::I16) => ( - ir::Value::I16(base, value as i16), - ir::TypeOrVar::Primitive(PrimitiveType::I16), - ), - Some(ConstantType::I32) => ( - ir::Value::I32(base, value as i32), - ir::TypeOrVar::Primitive(PrimitiveType::I32), - ), - Some(ConstantType::I64) => ( - ir::Value::I64(base, value as i64), - ir::TypeOrVar::Primitive(PrimitiveType::I64), - ), + + ir::Primitive::Minus => { + let numeric_type = ir::TypeOrVar::new_located(loc.clone()); + self.constraints + .push(Constraint::NumericType(loc.clone(), numeric_type.clone())); + let funtype = ir::TypeOrVar::Function( + vec![numeric_type.clone(), numeric_type.clone()], + Box::new(numeric_type.clone()), + ); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) + } + + ir::Primitive::Print => { + let arg_type = ir::TypeOrVar::new_located(loc.clone()); + self.constraints + .push(Constraint::Printable(loc.clone(), arg_type.clone())); + let funtype = ir::TypeOrVar::Function( + vec![arg_type], + Box::new(ir::TypeOrVar::Primitive(PrimitiveType::Void)), + ); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) + } + + ir::Primitive::Negate => { + let arg_type = ir::TypeOrVar::new_located(loc.clone()); + self.constraints + .push(Constraint::NumericType(loc.clone(), arg_type.clone())); + self.constraints + .push(Constraint::IsSigned(loc.clone(), arg_type.clone())); + let funtype = + ir::TypeOrVar::Function(vec![arg_type.clone()], Box::new(arg_type)); + let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); + ExpressionInfo::simple(ir::Expression::Atomic(result_value), funtype) + } + } + } + + syntax::Expression::Call(loc, fun, args) => { + let return_type = ir::TypeOrVar::new(); + let arg_types = args + .iter() + .map(|_| ir::TypeOrVar::new()) + .collect::>(); + + let mut expr_info = self.convert_expression(*fun, renames); + let target_fun_type = + ir::TypeOrVar::Function(arg_types.clone(), Box::new(return_type.clone())); + self.constraints.push(Constraint::Equivalent( + loc.clone(), + expr_info.result_type, + target_fun_type, + )); + let mut prereqs = vec![]; + + let (fun_prereqs, fun) = simplify_expr(expr_info.expression); + merge_prereq(&mut prereqs, fun_prereqs); + + let new_args = args + .into_iter() + .zip(arg_types) + .map(|(arg, target_type)| { + let arg_info = self.convert_expression(arg, renames); + let location = arg_info.expression.location().clone(); + let (arg_prereq, new_valref) = simplify_expr(arg_info.expression); + merge_prereq(&mut prereqs, arg_prereq); + self.constraints.push(Constraint::Equivalent( + location, + arg_info.result_type, + target_type, + )); + expr_info.free_variables.extend(arg_info.free_variables); + expr_info.bound_variables.extend(arg_info.bound_variables); + new_valref + }) + .collect(); + + let last_call = + ir::Expression::Call(loc.clone(), return_type.clone(), fun, new_args); + + expr_info.expression = finalize_expressions(prereqs, last_call); + expr_info.result_type = return_type; + + expr_info + } + + syntax::Expression::Block(loc, stmts) => { + let mut result_type = ir::TypeOrVar::Primitive(PrimitiveType::Void); + let mut exprs = vec![]; + let mut free_variables = HashSet::new(); + let mut bound_variables = HashSet::new(); + + for xpr in stmts.into_iter() { + let expr_info = self.convert_expression(xpr, renames); + result_type = expr_info.result_type; + exprs.push(expr_info.expression); + free_variables.extend( + expr_info + .free_variables + .difference(&bound_variables) + .cloned() + .collect::>(), + ); + bound_variables.extend(expr_info.bound_variables); + } + + ExpressionInfo { + expression: ir::Expression::Block(loc, result_type.clone(), exprs), + result_type, + free_variables, + bound_variables, + } + } + + syntax::Expression::Binding(loc, name, expr) => { + let mut expr_info = self.convert_expression(*expr, renames); + let final_name = self.finalize_name(renames, name); + self.variable_types + .insert(final_name.clone(), expr_info.result_type.clone()); + expr_info.expression = ir::Expression::Bind( + loc, + final_name.clone(), + expr_info.result_type.clone(), + Box::new(expr_info.expression), + ); + expr_info.bound_variables.insert(final_name); + expr_info + } + + syntax::Expression::Function(_, name, args, _, expr) => { + // First, at some point we're going to want to know a location for this function, + // which should either be the name if we have one, or the body if we don't. + let function_location = match name { + None => expr.location().clone(), + Some(ref name) => name.location.clone(), + }; + // Next, let us figure out what we're going to name this function. If the user + // didn't provide one, we'll just call it "function:" for them. (We'll + // want a name for this function, eventually, so we might as well do it now.) + // + // If they did provide a name, see if we're shadowed. IF we are, then we'll have + // to specialize the name a bit. Otherwise we'll stick with their name. + let function_name = match name { + None => ir::gensym("function"), + Some(unbound) => self.finalize_name(renames, unbound), }; - constraint_db.push(Constraint::FitsInNumType( - loc.clone(), - newtype.clone(), - value, + // This function is going to have a type. We don't know what it is, but it'll have + // one. + let function_type = ir::TypeOrVar::new(); + self.variable_types + .insert(function_name.clone(), function_type.clone()); + + // Then, let's figure out what to do with the argument names, which similarly + // may need to be renamed. We'll also generate some new type variables to associate + // with all of them. + // + // Note that we want to do all this in a new renaming scope, so that we shadow + // appropriately. + renames.new_scope(); + let arginfo = args + .into_iter() + .map(|(name, mut declared_type)| { + let new_type = ir::TypeOrVar::new(); + self.constraints.push(Constraint::IsSomething( + name.location.clone(), + new_type.clone(), + )); + let new_name = self.finalize_name(renames, name.clone()); + self.variable_types + .insert(new_name.clone(), new_type.clone()); + + if let Some(declared_type) = declared_type.take() { + let declared_type = self.convert_type(declared_type); + self.constraints.push(Constraint::Equivalent( + name.location.clone(), + new_type.clone(), + declared_type, + )); + } + + (new_name, new_type) + }) + .collect::>(); + + // Now we manufacture types for the outputs and then a type for the function itself. + // We're not going to make any claims on these types, yet; they're all just unknown + // type variables we need to work out. + let rettype = ir::TypeOrVar::new(); + let actual_function_type = ir::TypeOrVar::Function( + arginfo.iter().map(|x| x.1.clone()).collect(), + Box::new(rettype.clone()), + ); + self.constraints.push(Constraint::Equivalent( + function_location, + function_type, + actual_function_type, )); - ( - ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)), - newtype, - ) + + // Now let's convert the body over to the new IR. + let expr_info = self.convert_expression(*expr, renames); + self.constraints.push(Constraint::Equivalent( + expr_info.expression.location().clone(), + rettype.clone(), + expr_info.result_type.clone(), + )); + + // Remember to exit this scoping level! + renames.release_scope(); + + self.functions + .insert(function_name, (arginfo, expr_info.expression.clone())); + + unimplemented!() } - }, - - syntax::Expression::Constructor(loc, name, fields) => { - let mut result_fields = HashMap::new(); - let mut type_fields = ir::Fields::default(); - let mut prereqs = vec![]; - - for (name, syntax_expr) in fields.into_iter() { - let (ir_expr, expr_type) = - convert_expression(syntax_expr, constraint_db, renames, bindings); - type_fields.insert(name.clone().intern(), expr_type); - let (prereq, value) = simplify_expr(ir_expr); - result_fields.insert(name.clone().intern(), value); - merge_prereq(&mut prereqs, prereq); - } - - let result_type = ir::TypeOrVar::Structure(type_fields); - - constraint_db.push(Constraint::NamedTypeIs( - loc.clone(), - name.clone().intern(), - result_type.clone(), - )); - let result = - ir::Expression::Construct(loc, result_type.clone(), name.intern(), result_fields); - - (finalize_expressions(prereqs, result), result_type) - } - - syntax::Expression::Reference(loc, name) => { - let iname = ArcIntern::new(name); - let final_name = renames.get(&iname).cloned().unwrap_or(iname); - let rtype = bindings - .get(&final_name) - .cloned() - .expect("variable bound before use"); - let refexp = - ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name)); - - (refexp, rtype) - } - - syntax::Expression::FieldRef(loc, expr, field) => { - let (nexpr, etype) = convert_expression(*expr, constraint_db, renames, bindings); - let (prereqs, val_or_ref) = simplify_expr(nexpr); - let result_type = ir::TypeOrVar::new(); - let result = ir::Expression::FieldRef( - loc.clone(), - result_type.clone(), - etype.clone(), - val_or_ref, - field.clone().intern(), - ); - - constraint_db.push(Constraint::TypeHasField( - loc, - etype, - field.intern(), - result_type.clone(), - )); - - (finalize_expression(prereqs, result), result_type) - } - - syntax::Expression::Cast(loc, target, expr) => { - let (nexpr, etype) = convert_expression(*expr, constraint_db, renames, bindings); - let (prereqs, val_or_ref) = simplify_expr(nexpr); - let target_type: ir::TypeOrVar = PrimitiveType::from_str(&target) - .expect("valid type for cast") - .into(); - let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); - - constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone())); - - (finalize_expression(prereqs, res), target_type) - } - - syntax::Expression::Primitive(loc, name) => { - let primop = ir::Primitive::from_str(&name.name).expect("valid primitive"); - - match primop { - ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => { - let numeric_type = ir::TypeOrVar::new_located(loc.clone()); - constraint_db.push(Constraint::NumericType(loc.clone(), numeric_type.clone())); - let funtype = ir::TypeOrVar::Function( - vec![numeric_type.clone(), numeric_type.clone()], - Box::new(numeric_type.clone()), - ); - let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - (ir::Expression::Atomic(result_value), funtype) - } - - ir::Primitive::Minus => { - let numeric_type = ir::TypeOrVar::new_located(loc.clone()); - constraint_db.push(Constraint::NumericType(loc.clone(), numeric_type.clone())); - let funtype = ir::TypeOrVar::Function( - vec![numeric_type.clone(), numeric_type.clone()], - Box::new(numeric_type.clone()), - ); - let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - (ir::Expression::Atomic(result_value), funtype) - } - - ir::Primitive::Print => { - let arg_type = ir::TypeOrVar::new_located(loc.clone()); - constraint_db.push(Constraint::Printable(loc.clone(), arg_type.clone())); - let funtype = ir::TypeOrVar::Function( - vec![arg_type], - Box::new(ir::TypeOrVar::Primitive(PrimitiveType::Void)), - ); - let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - (ir::Expression::Atomic(result_value), funtype) - } - - ir::Primitive::Negate => { - let arg_type = ir::TypeOrVar::new_located(loc.clone()); - constraint_db.push(Constraint::NumericType(loc.clone(), arg_type.clone())); - constraint_db.push(Constraint::IsSigned(loc.clone(), arg_type.clone())); - let funtype = - ir::TypeOrVar::Function(vec![arg_type.clone()], Box::new(arg_type)); - let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop); - (ir::Expression::Atomic(result_value), funtype) - } - } - } - - syntax::Expression::Call(loc, fun, args) => { - let return_type = ir::TypeOrVar::new(); - let arg_types = args - .iter() - .map(|_| ir::TypeOrVar::new()) - .collect::>(); - - let (new_fun, new_fun_type) = - convert_expression(*fun, constraint_db, renames, bindings); - let target_fun_type = - ir::TypeOrVar::Function(arg_types.clone(), Box::new(return_type.clone())); - constraint_db.push(Constraint::Equivalent( - loc.clone(), - new_fun_type, - target_fun_type, - )); - let mut prereqs = vec![]; - - let (fun_prereqs, fun) = simplify_expr(new_fun); - merge_prereq(&mut prereqs, fun_prereqs); - - let new_args = args - .into_iter() - .zip(arg_types) - .map(|(arg, target_type)| { - let (new_arg, inferred_type) = - convert_expression(arg, constraint_db, renames, bindings); - let location = new_arg.location().clone(); - let (arg_prereq, new_valref) = simplify_expr(new_arg); - merge_prereq(&mut prereqs, arg_prereq); - constraint_db.push(Constraint::Equivalent( - location, - inferred_type, - target_type, - )); - new_valref - }) - .collect(); - - let last_call = ir::Expression::Call(loc.clone(), return_type.clone(), fun, new_args); - - (finalize_expressions(prereqs, last_call), return_type) - } - - syntax::Expression::Block(loc, stmts) => { - let mut ret_type = ir::TypeOrVar::Primitive(PrimitiveType::Void); - let mut exprs = vec![]; - - for xpr in stmts.into_iter() { - let (expr, expr_type) = convert_expression(xpr, constraint_db, renames, bindings); - - ret_type = expr_type; - exprs.push(expr); - } - - ( - ir::Expression::Block(loc, ret_type.clone(), exprs), - ret_type, - ) - } - - syntax::Expression::Binding(loc, name, expr) => { - let (expr, ty) = convert_expression(*expr, constraint_db, renames, bindings); - let final_name = finalize_name(bindings, renames, name); - bindings.insert(final_name.clone(), ty.clone()); - - ( - ir::Expression::Bind(loc, final_name, ty.clone(), Box::new(expr)), - ty, - ) } } -} -fn convert_type(ty: syntax::Type, constraint_db: &mut Vec) -> ir::TypeOrVar { - match ty { - syntax::Type::Named(x) => match PrimitiveType::from_str(x.name.as_str()) { - Err(_) => { - let retval = ir::TypeOrVar::new_located(x.location.clone()); - constraint_db.push(Constraint::NamedTypeIs( - x.location.clone(), - x.intern(), - retval.clone(), - )); - retval + fn convert_type(&mut self, ty: syntax::Type) -> ir::TypeOrVar { + match ty { + syntax::Type::Named(x) => match PrimitiveType::from_str(x.name.as_str()) { + Err(_) => { + let retval = ir::TypeOrVar::new_located(x.location.clone()); + self.constraints.push(Constraint::NamedTypeIs( + x.location.clone(), + x.intern(), + retval.clone(), + )); + retval + } + Ok(v) => ir::TypeOrVar::Primitive(v), + }, + syntax::Type::Struct(fields) => { + let mut new_fields = ir::Fields::default(); + + for (name, field_type) in fields.into_iter() { + let new_field_type = field_type + .map(|x| self.convert_type(x)) + .unwrap_or_else(ir::TypeOrVar::new); + new_fields.insert(name.intern(), new_field_type); + } + + ir::TypeOrVar::Structure(new_fields) } - Ok(v) => ir::TypeOrVar::Primitive(v), - }, - syntax::Type::Struct(fields) => { - let mut new_fields = ir::Fields::default(); + } + } - for (name, field_type) in fields.into_iter() { - let new_field_type = field_type - .map(|x| convert_type(x, constraint_db)) - .unwrap_or_else(ir::TypeOrVar::new); - new_fields.insert(name.intern(), new_field_type); - } - - ir::TypeOrVar::Structure(new_fields) + fn finalize_name( + &mut self, + renames: &mut ScopedMap, ArcIntern>, + name: syntax::Name, + ) -> ArcIntern { + if self + .variable_types + .contains_key(&ArcIntern::new(name.name.clone())) + { + let new_name = ir::gensym(&name.name); + renames.insert(ArcIntern::new(name.name.to_string()), new_name.clone()); + new_name + } else { + ArcIntern::new(name.to_string()) } } } @@ -520,20 +589,6 @@ fn finalize_expressions( } } -fn finalize_name( - bindings: &HashMap, ir::TypeOrVar>, - renames: &mut ScopedMap, ArcIntern>, - name: syntax::Name, -) -> ArcIntern { - if bindings.contains_key(&ArcIntern::new(name.name.clone())) { - let new_name = ir::gensym(&name.name); - renames.insert(ArcIntern::new(name.name.to_string()), new_name.clone()); - new_name - } else { - ArcIntern::new(name.to_string()) - } -} - fn merge_prereq(left: &mut Vec, prereq: Option) { if let Some(item) = prereq { left.push(item) diff --git a/src/type_infer/error.rs b/src/type_infer/error.rs new file mode 100644 index 0000000..dcd0a6c --- /dev/null +++ b/src/type_infer/error.rs @@ -0,0 +1,146 @@ +use super::constraint::Constraint; +use crate::eval::PrimitiveType; +use crate::ir::{Primitive, TypeOrVar}; +use crate::syntax::Location; +use codespan_reporting::diagnostic::Diagnostic; +use internment::ArcIntern; + +/// The various kinds of errors that can occur while doing type inference. +pub enum TypeInferenceError { + /// The user provide a constant that is too large for its inferred type. + ConstantTooLarge(Location, PrimitiveType, u64), + /// Somehow we're trying to use a non-number as a number + NotANumber(Location, PrimitiveType), + /// The two types needed to be equivalent, but weren't. + NotEquivalent(Location, TypeOrVar, TypeOrVar), + /// We cannot safely cast the first type to the second type. + CannotSafelyCast(Location, PrimitiveType, PrimitiveType), + /// The primitive invocation provided the wrong number of arguments. + WrongPrimitiveArity(Location, Primitive, usize, usize, usize), + /// We cannot cast between the type types, for any number of reasons + CannotCast(Location, TypeOrVar, TypeOrVar), + /// We cannot turn a number into a function. + CannotMakeNumberAFunction(Location, TypeOrVar, Option), + /// We cannot turn a number into a Structure. + CannotMakeNumberAStructure(Location, TypeOrVar, Option), + /// We had a constraint we just couldn't solve. + CouldNotSolve(Constraint), + /// Functions are not printable. + FunctionsAreNotPrintable(Location), + /// The given type isn't signed, and can't be negated + IsNotSigned(Location, TypeOrVar), + /// The given type doesn't have the given field. + NoFieldForType(Location, ArcIntern, TypeOrVar), + /// There is no type with the given name. + UnknownTypeName(Location, ArcIntern), +} + +impl From for Diagnostic { + fn from(value: TypeInferenceError) -> Self { + match value { + TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc + .labelled_error("constant too large for type") + .with_message(format!( + "Type {} has a max value of {}, which is smaller than {}", + primty, + primty.max_value().expect("constant type has max value"), + value + )), + TypeInferenceError::NotANumber(loc, primty) => loc + .labelled_error("not a numeric type") + .with_message(format!( + "For some reason, we're trying to use {} as a numeric type", + primty, + )), + TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc + .labelled_error("type inference error") + .with_message(format!("Expected type {}, received type {}", ty1, ty2)), + TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc + .labelled_error("unsafe type cast") + .with_message(format!("Cannot safely cast {} to {}", ty1, ty2)), + TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc + .labelled_error("wrong number of arguments") + .with_message(format!( + "expected {} for {}, received {}", + if lower == upper && lower > 1 { + format!("{} arguments", lower) + } else if lower == upper { + format!("{} argument", lower) + } else { + format!("{}-{} arguments", lower, upper) + }, + prim, + observed + )), + TypeInferenceError::CannotCast(loc, t1, t2) => loc + .labelled_error("cannot cast between types") + .with_message(format!( + "tried to cast from {} to {}", + t1, t2, + )), + TypeInferenceError::CannotMakeNumberAFunction(loc, t, val) => loc + .labelled_error(if let Some(val) = val { + format!("cannot turn {} into a function", val) + } else { + "cannot use a constant as a function type".to_string() + }) + .with_message(format!("function type was {}", t)), + TypeInferenceError::CannotMakeNumberAStructure(loc, t, val) => loc + .labelled_error(if let Some(val) = val { + format!("cannot turn {} into a function", val) + } else { + "cannot use a constant as a function type".to_string() + }) + .with_message(format!("function type was {}", t)), + TypeInferenceError::FunctionsAreNotPrintable(loc) => loc + .labelled_error("cannot print function values"), + TypeInferenceError::IsNotSigned(loc, pt) => loc + .labelled_error(format!("type {} is not signed", pt)) + .with_message("and so it cannot be negated"), + TypeInferenceError::NoFieldForType(loc, field, t) => loc + .labelled_error(format!("no field {} available for type {}", field, t)), + TypeInferenceError::UnknownTypeName(loc , name) => loc + .labelled_error(format!("unknown type named {}", name)), + TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => { + loc.labelled_error("internal error").with_message(format!( + "could not determine if it was safe to cast from {} to {}", + a, b + )) + } + TypeInferenceError::CouldNotSolve(Constraint::TypeHasField(loc, a, field, _)) => { + loc.labelled_error("internal error") + .with_message(format!("fould not determine if type {} has field {}", a, field)) + } + TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => { + loc.labelled_error("internal error").with_message(format!( + "could not determine if {} and {} were equivalent", + a, b + )) + } + TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => { + loc.labelled_error("internal error").with_message(format!( + "Could not determine if {} could fit in {}", + val, ty + )) + } + TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc + .labelled_error("internal error") + .with_message(format!("Could not determine if {} was a numeric type", ty)), + TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) => + panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty), + TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc + .labelled_error("internal error") + .with_message(format!("Could not determine if type {} was printable", ty)), + TypeInferenceError::CouldNotSolve(Constraint::IsSomething(loc, _)) => { + loc.labelled_error("could not infer type") + .with_message("Could not find *any* type information; is this an unused function argument?") + } + TypeInferenceError::CouldNotSolve(Constraint::IsSigned(loc, t)) => loc + .labelled_error("internal error") + .with_message(format!("could not infer that type {} was signed", t)), + TypeInferenceError::CouldNotSolve(Constraint::NamedTypeIs(loc, name, ty)) => loc + .labelled_error("internal error") + .with_message(format!("could not infer that the name {} refers to {}", name, ty)), + } + } +} diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index 624bbe8..9265953 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -1,41 +1,61 @@ -use super::solve::TypeResolutions; use crate::eval::PrimitiveType; -use crate::ir::{Expression, Program, TopLevel, Type, TypeOrVar, Value, ValueOrRef}; +use crate::ir::{Expression, Program, TopLevel, Type, TypeOrVar, TypeWithVoid, Value, ValueOrRef}; +use crate::syntax::Location; +use internment::ArcIntern; +use std::collections::HashMap; -pub fn finalize_program( - program: Program, - resolutions: &TypeResolutions, -) -> Program { - for (name, ty) in resolutions.iter() { - tracing::debug!(name = %name, resolved_type = %ty, "resolved type variable"); - } +pub type TypeResolutions = HashMap, Type>; - Program { - items: program - .items - .into_iter() - .map(|x| finalize_top_level(x, resolutions)) - .collect(), +impl super::InferenceEngine { + pub fn finalize_program(self, resolutions: TypeResolutions) -> Program { + for (name, ty) in resolutions.iter() { + tracing::debug!(name = %name, resolved_type = %ty, "resolved type variable"); + } - type_definitions: program - .type_definitions - .into_iter() - .map(|(n, t)| (n, finalize_type(t, resolutions))) - .collect(), - } -} + let mut type_definitions = HashMap::new(); + let mut items = Vec::new(); -fn finalize_top_level(item: TopLevel, resolutions: &TypeResolutions) -> TopLevel { - match item { - TopLevel::Function(name, args, rettype, expr) => TopLevel::Function( - name, - args.into_iter() - .map(|(name, t)| (name, finalize_type(t, resolutions))) - .collect(), - finalize_type(rettype, resolutions), - finalize_expression(expr, resolutions), - ), - TopLevel::Statement(expr) => TopLevel::Statement(finalize_expression(expr, resolutions)), + for (name, def) in self.type_definitions.into_iter() { + type_definitions.insert(name, finalize_type(def, &resolutions)); + } + + for (name, (arguments, body)) in self.functions.into_iter() { + let new_body = finalize_expression(body, &resolutions); + let arguments = arguments + .into_iter() + .map(|(name, t)| (name, finalize_type(t, &resolutions))) + .collect(); + items.push(TopLevel::Function( + name, + arguments, + new_body.type_of(), + new_body, + )); + } + + let mut body = vec![]; + let mut last_type = Type::void(); + let mut location = None; + + for expr in self.statements.into_iter() { + let next = finalize_expression(expr, &resolutions); + location = location + .map(|x: Location| x.merge(next.location())) + .unwrap_or_else(|| Some(next.location().clone())); + last_type = next.type_of(); + body.push(next); + } + + items.push(TopLevel::Statement(Expression::Block( + location.unwrap_or_else(Location::manufactured), + last_type, + body, + ))); + + Program { + items, + type_definitions, + } } } diff --git a/src/type_infer/result.rs b/src/type_infer/result.rs new file mode 100644 index 0000000..49c474f --- /dev/null +++ b/src/type_infer/result.rs @@ -0,0 +1,50 @@ +use super::error::TypeInferenceError; +use super::warning::TypeInferenceWarning; + +/// The results of type inference; like [`Result`], but with a bit more information. +/// +/// This result is parameterized, because sometimes it's handy to return slightly +/// different things; there's a [`TypeInferenceResult::map`] function for performing +/// those sorts of conversions. +pub enum TypeInferenceResult { + Success { + result: Result, + warnings: Vec, + }, + Failure { + errors: Vec, + warnings: Vec, + }, +} + +impl TypeInferenceResult { + // If this was a successful type inference, run the function over the result to + // create a new result. + // + // This is the moral equivalent of [`Result::map`], but for type inference results. + pub fn map(self, f: F) -> TypeInferenceResult + where + F: FnOnce(R) -> U, + { + match self { + TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success { + result: f(result), + warnings, + }, + + TypeInferenceResult::Failure { errors, warnings } => { + TypeInferenceResult::Failure { errors, warnings } + } + } + } + + // Return the final result, or panic if it's not a success + pub fn expect(self, msg: &str) -> R { + match self { + TypeInferenceResult::Success { result, .. } => result, + TypeInferenceResult::Failure { .. } => { + panic!("tried to get value from failed type inference: {}", msg) + } + } + } +} diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index 15dc067..74f1c9f 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -1,826 +1,556 @@ +use super::constraint::Constraint; +use super::error::TypeInferenceError; +use super::warning::TypeInferenceWarning; use crate::eval::PrimitiveType; -use crate::ir::{Primitive, Type, TypeOrVar}; -use crate::syntax::Location; -use codespan_reporting::diagnostic::Diagnostic; +use crate::ir::TypeOrVar; use internment::ArcIntern; -use std::{collections::HashMap, fmt}; -/// A type inference constraint that we're going to need to solve. -#[derive(Debug)] -pub enum Constraint { - /// The given type must be printable using the `print` built-in - Printable(Location, TypeOrVar), - /// The provided numeric value fits in the given constant type - FitsInNumType(Location, TypeOrVar, u64), - /// The given type can be casted to the target type safely - CanCastTo(Location, TypeOrVar, TypeOrVar), - /// The given type has the given field in it, and the type of that field - /// is as given. - TypeHasField(Location, TypeOrVar, ArcIntern, TypeOrVar), - /// The given type must be some numeric type, but this is not a constant - /// value, so don't try to default it if we can't figure it out - NumericType(Location, TypeOrVar), - /// The given type is attached to a constant and must be some numeric type. - /// If we can't figure it out, we should warn the user and then just use a - /// default. - ConstantNumericType(Location, TypeOrVar), - /// The two types should be equivalent - Equivalent(Location, TypeOrVar, TypeOrVar), - /// The given type can be resolved to something - IsSomething(Location, TypeOrVar), - /// The given type can be negated - IsSigned(Location, TypeOrVar), - /// Checks to see if the given named type is equivalent to the provided one. - NamedTypeIs(Location, ArcIntern, TypeOrVar), -} - -impl fmt::Display for Constraint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty), - Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty), - Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2), - Constraint::TypeHasField(_, ty1, field, ty2) => { - write!(f, "FIELD {}.{} -> {}", ty1, field, ty2) - } - Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty), - Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty), - Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2), - Constraint::IsSomething(_, ty) => write!(f, "SOMETHING {}", ty), - Constraint::IsSigned(_, ty) => write!(f, "SIGNED {}", ty), - Constraint::NamedTypeIs(_, name, ty) => write!(f, "TYPE_EQUIV {} == {}", name, ty), - } - } -} - -impl Constraint { - /// Replace all instances of the name (anywhere! including on the left hand side of equivalences!) - /// with the given type. +impl super::InferenceEngine { + /// Solve all the constraints in the provided database. /// - /// Returns whether or not anything was changed in the constraint. - fn replace(&mut self, name: &ArcIntern, replace_with: &TypeOrVar) -> bool { - match self { - Constraint::Printable(_, ty) => ty.replace(name, replace_with), - Constraint::FitsInNumType(_, ty, _) => ty.replace(name, replace_with), - Constraint::CanCastTo(_, ty1, ty2) => { - ty1.replace(name, replace_with) || ty2.replace(name, replace_with) + /// This process can take a bit, so you might not want to do it multiple times. Basically, + /// it's going to grind on these constraints until either it figures them out, or it stops + /// making progress. I haven't done the math on the constraints to even figure out if this + /// is guaranteed to halt, though, let alone terminate in some reasonable amount of time. + /// + /// The return value is a type inference result, which pairs some warnings with either a + /// successful set of type resolutions (mappings from type variables to their values), or + /// a series of inference errors. + pub fn solve_constraints(&mut self) { + let mut iteration = 0u64; + + loop { + let mut changed_something = false; + let mut all_constraints_solved = true; + let mut new_constraints = vec![]; + + tracing::debug!(iteration, "Restarting constraint solving loop"); + for constraint in self.constraints.iter() { + tracing::debug!(%constraint, "remaining constraint"); } - Constraint::TypeHasField(_, ty1, _, ty2) => { - ty1.replace(name, replace_with) || ty2.replace(name, replace_with) - } - Constraint::ConstantNumericType(_, ty) => ty.replace(name, replace_with), - Constraint::Equivalent(_, ty1, ty2) => { - ty1.replace(name, replace_with) || ty2.replace(name, replace_with) - } - Constraint::IsSigned(_, ty) => ty.replace(name, replace_with), - Constraint::IsSomething(_, ty) => ty.replace(name, replace_with), - Constraint::NumericType(_, ty) => ty.replace(name, replace_with), - Constraint::NamedTypeIs(_, name, ty) => ty.replace(name, replace_with), - } - } -} + iteration += 1; -pub type TypeResolutions = HashMap, Type>; + while let Some(constraint) = self.constraints.pop() { + match constraint { + // The basic philosophy of this match block is that, for each constraint, we're + // going to start seeing if we can just solve (or abandon) the constraint. Then, + // if we can't, we'll just chuck it back on our list for later. -/// The results of type inference; like [`Result`], but with a bit more information. -/// -/// This result is parameterized, because sometimes it's handy to return slightly -/// different things; there's a [`TypeInferenceResult::map`] function for performing -/// those sorts of conversions. -pub enum TypeInferenceResult { - Success { - result: Result, - warnings: Vec, - }, - Failure { - errors: Vec, - warnings: Vec, - }, -} - -impl TypeInferenceResult { - // If this was a successful type inference, run the function over the result to - // create a new result. - // - // This is the moral equivalent of [`Result::map`], but for type inference results. - pub fn map(self, f: F) -> TypeInferenceResult - where - F: FnOnce(R) -> U, - { - match self { - TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success { - result: f(result), - warnings, - }, - - TypeInferenceResult::Failure { errors, warnings } => { - TypeInferenceResult::Failure { errors, warnings } - } - } - } - - // Return the final result, or panic if it's not a success - pub fn expect(self, msg: &str) -> R { - match self { - TypeInferenceResult::Success { result, .. } => result, - TypeInferenceResult::Failure { .. } => { - panic!("tried to get value from failed type inference: {}", msg) - } - } - } -} - -/// The various kinds of errors that can occur while doing type inference. -pub enum TypeInferenceError { - /// The user provide a constant that is too large for its inferred type. - ConstantTooLarge(Location, PrimitiveType, u64), - /// Somehow we're trying to use a non-number as a number - NotANumber(Location, PrimitiveType), - /// The two types needed to be equivalent, but weren't. - NotEquivalent(Location, TypeOrVar, TypeOrVar), - /// We cannot safely cast the first type to the second type. - CannotSafelyCast(Location, PrimitiveType, PrimitiveType), - /// The primitive invocation provided the wrong number of arguments. - WrongPrimitiveArity(Location, Primitive, usize, usize, usize), - /// We cannot cast between the type types, for any number of reasons - CannotCast(Location, TypeOrVar, TypeOrVar), - /// We cannot turn a number into a function. - CannotMakeNumberAFunction(Location, TypeOrVar, Option), - /// We cannot turn a number into a Structure. - CannotMakeNumberAStructure(Location, TypeOrVar, Option), - /// We had a constraint we just couldn't solve. - CouldNotSolve(Constraint), - /// Functions are not printable. - FunctionsAreNotPrintable(Location), - /// The given type isn't signed, and can't be negated - IsNotSigned(Location, TypeOrVar), - /// The given type doesn't have the given field. - NoFieldForType(Location, ArcIntern, TypeOrVar), - /// There is no type with the given name. - UnknownTypeName(Location, ArcIntern), -} - -impl From for Diagnostic { - fn from(value: TypeInferenceError) -> Self { - match value { - TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc - .labelled_error("constant too large for type") - .with_message(format!( - "Type {} has a max value of {}, which is smaller than {}", - primty, - primty.max_value().expect("constant type has max value"), - value - )), - TypeInferenceError::NotANumber(loc, primty) => loc - .labelled_error("not a numeric type") - .with_message(format!( - "For some reason, we're trying to use {} as a numeric type", - primty, - )), - TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc - .labelled_error("type inference error") - .with_message(format!("Expected type {}, received type {}", ty1, ty2)), - TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc - .labelled_error("unsafe type cast") - .with_message(format!("Cannot safely cast {} to {}", ty1, ty2)), - TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc - .labelled_error("wrong number of arguments") - .with_message(format!( - "expected {} for {}, received {}", - if lower == upper && lower > 1 { - format!("{} arguments", lower) - } else if lower == upper { - format!("{} argument", lower) - } else { - format!("{}-{} arguments", lower, upper) - }, - prim, - observed - )), - TypeInferenceError::CannotCast(loc, t1, t2) => loc - .labelled_error("cannot cast between types") - .with_message(format!( - "tried to cast from {} to {}", - t1, t2, - )), - TypeInferenceError::CannotMakeNumberAFunction(loc, t, val) => loc - .labelled_error(if let Some(val) = val { - format!("cannot turn {} into a function", val) - } else { - "cannot use a constant as a function type".to_string() - }) - .with_message(format!("function type was {}", t)), - TypeInferenceError::CannotMakeNumberAStructure(loc, t, val) => loc - .labelled_error(if let Some(val) = val { - format!("cannot turn {} into a function", val) - } else { - "cannot use a constant as a function type".to_string() - }) - .with_message(format!("function type was {}", t)), - TypeInferenceError::FunctionsAreNotPrintable(loc) => loc - .labelled_error("cannot print function values"), - TypeInferenceError::IsNotSigned(loc, pt) => loc - .labelled_error(format!("type {} is not signed", pt)) - .with_message("and so it cannot be negated"), - TypeInferenceError::NoFieldForType(loc, field, t) => loc - .labelled_error(format!("no field {} available for type {}", field, t)), - TypeInferenceError::UnknownTypeName(loc , name) => loc - .labelled_error(format!("unknown type named {}", name)), - TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => { - loc.labelled_error("internal error").with_message(format!( - "could not determine if it was safe to cast from {} to {}", - a, b - )) - } - TypeInferenceError::CouldNotSolve(Constraint::TypeHasField(loc, a, field, _)) => { - loc.labelled_error("internal error") - .with_message(format!("fould not determine if type {} has field {}", a, field)) - } - TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => { - loc.labelled_error("internal error").with_message(format!( - "could not determine if {} and {} were equivalent", - a, b - )) - } - TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => { - loc.labelled_error("internal error").with_message(format!( - "Could not determine if {} could fit in {}", - val, ty - )) - } - TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc - .labelled_error("internal error") - .with_message(format!("Could not determine if {} was a numeric type", ty)), - TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) => - panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty), - TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc - .labelled_error("internal error") - .with_message(format!("Could not determine if type {} was printable", ty)), - TypeInferenceError::CouldNotSolve(Constraint::IsSomething(loc, _)) => { - loc.labelled_error("could not infer type") - .with_message("Could not find *any* type information; is this an unused function argument?") - } - TypeInferenceError::CouldNotSolve(Constraint::IsSigned(loc, t)) => loc - .labelled_error("internal error") - .with_message(format!("could not infer that type {} was signed", t)), - TypeInferenceError::CouldNotSolve(Constraint::NamedTypeIs(loc, name, ty)) => loc - .labelled_error("internal error") - .with_message(format!("could not infer that the name {} refers to {}", name, ty)), - } - } -} - -/// Warnings that we might want to tell the user about. -/// -/// These are fine, probably, but could indicate some behavior the user might not -/// expect, and so they might want to do something about them. -pub enum TypeInferenceWarning { - DefaultedTo(Location, TypeOrVar), -} - -impl From for Diagnostic { - fn from(value: TypeInferenceWarning) -> Self { - match value { - TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning() - .with_labels(vec![loc.primary_label().with_message("unknown type")]) - .with_message(format!("Defaulted unknown type to {}", ty)), - } - } -} - -/// Solve all the constraints in the provided database. -/// -/// This process can take a bit, so you might not want to do it multiple times. Basically, -/// it's going to grind on these constraints until either it figures them out, or it stops -/// making progress. I haven't done the math on the constraints to even figure out if this -/// is guaranteed to halt, though, let alone terminate in some reasonable amount of time. -/// -/// The return value is a type inference result, which pairs some warnings with either a -/// successful set of type resolutions (mappings from type variables to their values), or -/// a series of inference errors. -pub fn solve_constraints( - known_types: &HashMap, TypeOrVar>, - mut constraint_db: Vec, -) -> TypeInferenceResult { - let mut errors = vec![]; - let mut warnings = vec![]; - let mut iteration = 0u64; - - loop { - let mut changed_something = false; - let mut all_constraints_solved = true; - let mut new_constraints = vec![]; - - tracing::debug!(iteration, "Restarting constraint solving loop"); - for constraint in constraint_db.iter() { - tracing::debug!(%constraint, "remaining constraint"); - } - iteration += 1; - - while let Some(constraint) = constraint_db.pop() { - match constraint { - // The basic philosophy of this match block is that, for each constraint, we're - // going to start seeing if we can just solve (or abandon) the constraint. Then, - // if we can't, we'll just chuck it back on our list for later. - - // Checks on whether we can cast from one thing to another! - Constraint::CanCastTo( - loc, - TypeOrVar::Primitive(from_type), - TypeOrVar::Primitive(to_type), - ) => { - if !from_type.can_cast_to(&to_type) { - errors.push(TypeInferenceError::CannotSafelyCast( - loc, from_type, to_type, - )); - } - tracing::trace!(form = %from_type, to = %to_type, "we can determine if we can do the cast"); - changed_something = true; - } - - Constraint::CanCastTo( - loc, - TypeOrVar::Function(args1, ret1), - TypeOrVar::Function(args2, ret2), - ) => { - if args1.len() == args2.len() { - new_constraints.push(Constraint::Equivalent(loc.clone(), *ret1, *ret2)); - for (arg1, arg2) in args1.into_iter().zip(args2) { - new_constraints.push(Constraint::Equivalent(loc.clone(), arg1, arg2)) - } - all_constraints_solved = false; - } else { - errors.push(TypeInferenceError::CannotCast( - loc, - TypeOrVar::Function(args1, ret1), - TypeOrVar::Function(args2, ret2), - )); - } - tracing::trace!( - "we transferred CanCastTo to equivalence checks for function types" - ); - changed_something = true; - } - - Constraint::CanCastTo( - loc, - st1 @ TypeOrVar::Structure(_), - st2 @ TypeOrVar::Structure(_), - ) => { - tracing::trace!( - "structures can be equivalent, if their fields and types are exactly the same" - ); - new_constraints.push(Constraint::Equivalent(loc, st1, st2)); - changed_something = true; - } - - Constraint::CanCastTo( - loc, - ft @ TypeOrVar::Function(_, _), - ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Structure(_), - ) => { - tracing::trace!(function_type = %ft, other_type = %ot, "we can't cast a function type to a primitive or structure type"); - errors.push(TypeInferenceError::CannotCast(loc, ft, ot)); - changed_something = true; - } - - Constraint::CanCastTo( - loc, - pt @ TypeOrVar::Primitive(_), - ot @ TypeOrVar::Function(_, _) | ot @ TypeOrVar::Structure(_), - ) => { - tracing::trace!(other_type = %ot, primitive_type = %pt, "we can't cast a primitive type to a function or structure type"); - errors.push(TypeInferenceError::CannotCast(loc, pt, ot)); - changed_something = true; - } - - Constraint::CanCastTo( - loc, - st @ TypeOrVar::Structure(_), - ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _), - ) => { - tracing::trace!(structure_type = %st, other_type = %ot, "we can't cast a structure type to a function or primitive type"); - errors.push(TypeInferenceError::CannotCast(loc, st, ot)); - changed_something = true; - } - - Constraint::NamedTypeIs(loc, name, ty) => match known_types.get(&name) { - None => { - tracing::trace!(type_name = %name, "we don't know a type named name"); - errors.push(TypeInferenceError::UnknownTypeName(loc, name)); - changed_something = true; - } - - Some(declared_type) => { - tracing::trace!(type_name = %name, declared = %declared_type, provided = %ty, "validating that named type is equivalent to provided"); - new_constraints.push(Constraint::Equivalent( - loc, - declared_type.clone(), - ty, - )); - changed_something = true; - } - }, - - Constraint::TypeHasField( - loc, - TypeOrVar::Structure(mut fields), - field, - result_type, - ) => match fields.remove_field(&field) { - None => { - let reconstituted = TypeOrVar::Structure(fields); - tracing::trace!(structure_type = %reconstituted, %field, "no field found in type"); - errors.push(TypeInferenceError::NoFieldForType( - loc, - field, - reconstituted, - )); - changed_something = true; - } - - Some(field_subtype) => { - tracing::trace!(%field_subtype, %result_type, %field, "validating that field's subtype matches target result type"); - new_constraints.push(Constraint::Equivalent( - loc, - result_type, - field_subtype, - )); - changed_something = true; - } - }, - - Constraint::TypeHasField( - loc, - ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _), - field, - _, - ) => { - tracing::trace!(other_type = %ot, %field, "can't get field from primitive or function type"); - errors.push(TypeInferenceError::NoFieldForType(loc, field, ot)); - changed_something = true; - } - - // if we're testing if an actual primitive type is numeric, that's pretty easy - Constraint::ConstantNumericType(loc, TypeOrVar::Primitive(pt)) => { - tracing::trace!(primitive_type = %pt, "its easy to tell if a constant number can be a primitive type"); - if pt.max_value().is_none() { - errors.push(TypeInferenceError::NotANumber(loc, pt)) - } - changed_something = true; - } - - // if we're testing if a function type is numeric, then throw a useful warning - Constraint::ConstantNumericType(loc, t @ TypeOrVar::Function(_, _)) => { - tracing::trace!(function_type = %t, "functions can't be constant numbers"); - errors.push(TypeInferenceError::CannotMakeNumberAFunction(loc, t, None)); - changed_something = true; - } - - // if we're testing if a function type is numeric, then throw a useful warning - Constraint::ConstantNumericType(loc, t @ TypeOrVar::Structure(_)) => { - tracing::trace!(structure_type = %t, "structures can't be constant numbers"); - errors.push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None)); - changed_something = true; - } - - // if we're testing if a number can fit into a numeric type, we can just do that! - Constraint::FitsInNumType(loc, TypeOrVar::Primitive(ctype), val) => { - match ctype.max_value() { - None => errors.push(TypeInferenceError::NotANumber(loc, ctype)), - - Some(max_value) if max_value < val => { - errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val)); - } - - Some(_) => {} - } - changed_something = true; - tracing::trace!(primitive_type = %ctype, value = val, "we can test for a value fitting in a primitive type"); - } - - // if we're testing if a function type can fit into a numeric type, that's a problem - Constraint::FitsInNumType(loc, t @ TypeOrVar::Function(_, _), val) => { - tracing::trace!(function_type = %t, "values don't fit in function types"); - errors.push(TypeInferenceError::CannotMakeNumberAFunction( + // Checks on whether we can cast from one thing to another! + Constraint::CanCastTo( loc, - t, - Some(val), - )); - changed_something = true; - } - - // if we're testing if a function type can fit into a numeric type, that's a problem - Constraint::FitsInNumType(loc, t @ TypeOrVar::Structure(_), val) => { - tracing::trace!(function_type = %t, "values don't fit in structure types"); - errors.push(TypeInferenceError::CannotMakeNumberAStructure( - loc, - t, - Some(val), - )); - changed_something = true; - } - - // if we want to know if a type is something, and it is something, then we're done - Constraint::IsSomething(_, t @ TypeOrVar::Function(_, _)) - | Constraint::IsSomething(_, t @ TypeOrVar::Primitive(_)) - | Constraint::IsSomething(_, t @ TypeOrVar::Structure(_)) => { - tracing::trace!(tested_type = %t, "type is definitely something"); - changed_something = true; - } - - // if we want to know if something is signed, we can check its primitive type - Constraint::IsSigned(loc, TypeOrVar::Primitive(pt)) => { - tracing::trace!(primitive_type = %pt, "we can check if a primitive is signed"); - if !pt.valid_operators().contains(&("-", 1)) { - errors.push(TypeInferenceError::IsNotSigned( - loc, - TypeOrVar::Primitive(pt), - )); - } - changed_something = true; - } - - // again with the functions and the numbers - Constraint::IsSigned(loc, t @ TypeOrVar::Function(_, _)) => { - tracing::trace!(function_type = %t, "functions are not signed"); - errors.push(TypeInferenceError::IsNotSigned(loc, t)); - changed_something = true; - } - - // again with the functions and the numbers - Constraint::IsSigned(loc, t @ TypeOrVar::Structure(_)) => { - tracing::trace!(structure_type = %t, "structures are not signed"); - errors.push(TypeInferenceError::IsNotSigned(loc, t)); - changed_something = true; - } - - // if we're testing if an actual primitive type is numeric, that's pretty easy - Constraint::NumericType(loc, TypeOrVar::Primitive(pt)) => { - tracing::trace!(primitive_type = %pt, "its easy to tell if a primitive type is numeric"); - if pt.max_value().is_none() { - errors.push(TypeInferenceError::NotANumber(loc, pt)) - } - changed_something = true; - } - - // if we're testing if a function type is numeric, then throw a useful warning - Constraint::NumericType(loc, t @ TypeOrVar::Function(_, _)) => { - tracing::trace!(function_type = %t, "function types aren't numeric"); - errors.push(TypeInferenceError::CannotMakeNumberAFunction(loc, t, None)); - changed_something = true; - } - - // if we're testing if a structure type is numeric, then throw a useful warning - Constraint::NumericType(loc, t @ TypeOrVar::Structure(_)) => { - tracing::trace!(structure_type = %t, "structure types aren't numeric"); - errors.push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None)); - changed_something = true; - } - - // all of our primitive types are printable - Constraint::Printable(_, TypeOrVar::Primitive(pt)) => { - tracing::trace!(primitive_type = %pt, "primitive types are printable"); - changed_something = true; - } - - // function types are definitely not printable - Constraint::Printable(loc, ft @ TypeOrVar::Function(_, _)) => { - tracing::trace!(function_type = %ft, "function types are not printable"); - errors.push(TypeInferenceError::FunctionsAreNotPrintable(loc)); - changed_something = true; - } - - // structure types are printable if all the types inside them are printable - Constraint::Printable(loc, TypeOrVar::Structure(fields)) => { - tracing::trace!( - "structure types are printable if all their subtypes are printable" - ); - for (_, subtype) in fields.into_iter() { - new_constraints.push(Constraint::Printable(loc.clone(), subtype)); - } - changed_something = true; - } - - // Some equivalences we can/should solve directly - Constraint::Equivalent( - loc, - TypeOrVar::Primitive(pt1), - TypeOrVar::Primitive(pt2), - ) => { - if pt1 != pt2 { - errors.push(TypeInferenceError::NotEquivalent( - loc, - TypeOrVar::Primitive(pt1), - TypeOrVar::Primitive(pt2), - )); - } - changed_something = true; - tracing::trace!(primitive_type1 = %pt1, primitive_type2 = %pt2, "we checked for primitive type equivalence"); - } - - Constraint::Equivalent( - loc, - pt @ TypeOrVar::Primitive(_), - ft @ TypeOrVar::Function(_, _), - ) - | Constraint::Equivalent( - loc, - ft @ TypeOrVar::Function(_, _), - pt @ TypeOrVar::Primitive(_), - ) => { - tracing::trace!(primitive_type = %pt, function_type = %ft, "function and primitive types cannot be equivalent"); - errors.push(TypeInferenceError::NotEquivalent(loc, pt, ft)); - changed_something = true; - } - - Constraint::Equivalent( - loc, - pt @ TypeOrVar::Primitive(_), - st @ TypeOrVar::Structure(_), - ) - | Constraint::Equivalent( - loc, - st @ TypeOrVar::Structure(_), - pt @ TypeOrVar::Primitive(_), - ) => { - tracing::trace!(primitive_type = %pt, structure_type = %st, "structure and primitive types cannot be equivalent"); - errors.push(TypeInferenceError::NotEquivalent(loc, pt, st)); - changed_something = true; - } - - Constraint::Equivalent( - loc, - st @ TypeOrVar::Structure(_), - ft @ TypeOrVar::Function(_, _), - ) - | Constraint::Equivalent( - loc, - ft @ TypeOrVar::Function(_, _), - st @ TypeOrVar::Structure(_), - ) => { - tracing::trace!(structure_type = %st, function_type = %ft, "structure and primitive types cannot be equivalent"); - errors.push(TypeInferenceError::NotEquivalent(loc, st, ft)); - changed_something = true; - } - - Constraint::Equivalent( - _, - TypeOrVar::Variable(_, name1), - TypeOrVar::Variable(_, name2), - ) if name1 == name2 => { - tracing::trace!(name = %name1, "variable is equivalent to itself"); - changed_something = true; - } - - Constraint::Equivalent( - loc, - TypeOrVar::Function(args1, ret1), - TypeOrVar::Function(args2, ret2), - ) => { - if args1.len() != args2.len() { - let t1 = TypeOrVar::Function(args1, ret1); - let t2 = TypeOrVar::Function(args2, ret2); - errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2)); - } else { - for (left, right) in args1.into_iter().zip(args2) { - new_constraints.push(Constraint::Equivalent(loc.clone(), left, right)); - } - new_constraints.push(Constraint::Equivalent(loc, *ret1, *ret2)); - all_constraints_solved = false; - } - - changed_something = true; - tracing::trace!("we checked/rewrote if function types are equivalent"); - } - - Constraint::Equivalent( - loc, - TypeOrVar::Structure(fields1), - TypeOrVar::Structure(mut fields2), - ) => { - if fields1.count() == fields2.count() - && fields1.field_names().all(|x| fields2.has_field(x)) - { - for (name, subtype1) in fields1.into_iter() { - let subtype2 = fields2 - .remove_field(&name) - .expect("can find matching field after equivalence check"); - new_constraints.push(Constraint::Equivalent( - loc.clone(), - subtype1, - subtype2, + TypeOrVar::Primitive(from_type), + TypeOrVar::Primitive(to_type), + ) => { + if !from_type.can_cast_to(&to_type) { + self.errors.push(TypeInferenceError::CannotSafelyCast( + loc, from_type, to_type, )); } - } else { - errors.push(TypeInferenceError::NotEquivalent( - loc, - TypeOrVar::Structure(fields1), - TypeOrVar::Structure(fields2), - )) + tracing::trace!(form = %from_type, to = %to_type, "we can determine if we can do the cast"); + changed_something = true; } - changed_something = true; - tracing::trace!("we checked/rewrote if structures are equivalent"); - } - Constraint::Equivalent(_, TypeOrVar::Variable(_, ref name), ref rhs) => { - changed_something |= replace_variable(&mut constraint_db, name, rhs); - changed_something |= replace_variable(&mut new_constraints, name, rhs); - all_constraints_solved &= rhs.is_resolved(); - if changed_something { - tracing::trace!(%name, new_type = %rhs, "we were able to rewrite name somewhere"); + Constraint::CanCastTo( + loc, + TypeOrVar::Function(args1, ret1), + TypeOrVar::Function(args2, ret2), + ) => { + if args1.len() == args2.len() { + new_constraints.push(Constraint::Equivalent(loc.clone(), *ret1, *ret2)); + for (arg1, arg2) in args1.into_iter().zip(args2) { + new_constraints.push(Constraint::Equivalent( + loc.clone(), + arg1, + arg2, + )) + } + all_constraints_solved = false; + } else { + self.errors.push(TypeInferenceError::CannotCast( + loc, + TypeOrVar::Function(args1, ret1), + TypeOrVar::Function(args2, ret2), + )); + } + tracing::trace!( + "we transferred CanCastTo to equivalence checks for function types" + ); + changed_something = true; } - new_constraints.push(constraint); - } - Constraint::Equivalent(loc, lhs, rhs @ TypeOrVar::Variable(_, _)) => { - tracing::trace!(new_left = %rhs, new_right = %lhs, "we flipped the order on an equivalence"); - new_constraints.push(Constraint::Equivalent(loc, rhs, lhs)); - changed_something = true; - all_constraints_solved = false; - } + Constraint::CanCastTo( + loc, + st1 @ TypeOrVar::Structure(_), + st2 @ TypeOrVar::Structure(_), + ) => { + tracing::trace!( + "structures can be equivalent, if their fields and types are exactly the same" + ); + new_constraints.push(Constraint::Equivalent(loc, st1, st2)); + changed_something = true; + } - Constraint::CanCastTo(_, TypeOrVar::Variable(_, _), _) - | Constraint::CanCastTo(_, _, TypeOrVar::Variable(_, _)) - | Constraint::TypeHasField(_, TypeOrVar::Variable(_, _), _, _) - | Constraint::ConstantNumericType(_, TypeOrVar::Variable(_, _)) - | Constraint::FitsInNumType(_, TypeOrVar::Variable(_, _), _) - | Constraint::IsSomething(_, TypeOrVar::Variable(_, _)) - | Constraint::IsSigned(_, TypeOrVar::Variable(_, _)) - | Constraint::NumericType(_, TypeOrVar::Variable(_, _)) - | Constraint::Printable(_, TypeOrVar::Variable(_, _)) => { - all_constraints_solved = false; - new_constraints.push(constraint); - } - } - } + Constraint::CanCastTo( + loc, + ft @ TypeOrVar::Function(_, _), + ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Structure(_), + ) => { + tracing::trace!(function_type = %ft, other_type = %ot, "we can't cast a function type to a primitive or structure type"); + self.errors + .push(TypeInferenceError::CannotCast(loc, ft, ot)); + changed_something = true; + } - if all_constraints_solved { - let result = new_constraints - .into_iter() - .map(|constraint| match constraint { - Constraint::Equivalent(_, TypeOrVar::Variable(_, name), result) => { - match result.try_into() { - Err(e) => panic!("Ended up with complex type {}", e), - Ok(v) => (name, v), + Constraint::CanCastTo( + loc, + pt @ TypeOrVar::Primitive(_), + ot @ TypeOrVar::Function(_, _) | ot @ TypeOrVar::Structure(_), + ) => { + tracing::trace!(other_type = %ot, primitive_type = %pt, "we can't cast a primitive type to a function or structure type"); + self.errors + .push(TypeInferenceError::CannotCast(loc, pt, ot)); + changed_something = true; + } + + Constraint::CanCastTo( + loc, + st @ TypeOrVar::Structure(_), + ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _), + ) => { + tracing::trace!(structure_type = %st, other_type = %ot, "we can't cast a structure type to a function or primitive type"); + self.errors + .push(TypeInferenceError::CannotCast(loc, st, ot)); + changed_something = true; + } + + Constraint::NamedTypeIs(loc, name, ty) => { + match self.type_definitions.get(&name) { + None => { + tracing::trace!(type_name = %name, "we don't know a type named name"); + self.errors + .push(TypeInferenceError::UnknownTypeName(loc, name)); + changed_something = true; + } + + Some(declared_type) => { + tracing::trace!(type_name = %name, declared = %declared_type, provided = %ty, "validating that named type is equivalent to provided"); + new_constraints.push(Constraint::Equivalent( + loc, + declared_type.clone(), + ty, + )); + changed_something = true; + } } } - _ => panic!("Had something that wasn't an equivalence left at the end!"), - }) - .collect(); - return TypeInferenceResult::Success { result, warnings }; - } - if !changed_something { - let mut addendums = vec![]; + Constraint::TypeHasField( + loc, + TypeOrVar::Structure(mut fields), + field, + result_type, + ) => match fields.remove_field(&field) { + None => { + let reconstituted = TypeOrVar::Structure(fields); + tracing::trace!(structure_type = %reconstituted, %field, "no field found in type"); + self.errors.push(TypeInferenceError::NoFieldForType( + loc, + field, + reconstituted, + )); + changed_something = true; + } - macro_rules! default_type { - ($addendums: ident, $loc: ident, $t: ident) => { - let resty = TypeOrVar::Primitive(PrimitiveType::U64); - $addendums.push(Constraint::Equivalent( - $loc.clone(), - $t.clone(), - resty.clone(), - )); - warnings.push(TypeInferenceWarning::DefaultedTo($loc.clone(), resty)); - tracing::trace!("Adding number equivalence"); - }; + Some(field_subtype) => { + tracing::trace!(%field_subtype, %result_type, %field, "validating that field's subtype matches target result type"); + new_constraints.push(Constraint::Equivalent( + loc, + result_type, + field_subtype, + )); + changed_something = true; + } + }, + + Constraint::TypeHasField( + loc, + ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _), + field, + _, + ) => { + tracing::trace!(other_type = %ot, %field, "can't get field from primitive or function type"); + self.errors + .push(TypeInferenceError::NoFieldForType(loc, field, ot)); + changed_something = true; + } + + // if we're testing if an actual primitive type is numeric, that's pretty easy + Constraint::ConstantNumericType(loc, TypeOrVar::Primitive(pt)) => { + tracing::trace!(primitive_type = %pt, "its easy to tell if a constant number can be a primitive type"); + if pt.max_value().is_none() { + self.errors.push(TypeInferenceError::NotANumber(loc, pt)) + } + changed_something = true; + } + + // if we're testing if a function type is numeric, then throw a useful warning + Constraint::ConstantNumericType(loc, t @ TypeOrVar::Function(_, _)) => { + tracing::trace!(function_type = %t, "functions can't be constant numbers"); + self.errors + .push(TypeInferenceError::CannotMakeNumberAFunction(loc, t, None)); + changed_something = true; + } + + // if we're testing if a function type is numeric, then throw a useful warning + Constraint::ConstantNumericType(loc, t @ TypeOrVar::Structure(_)) => { + tracing::trace!(structure_type = %t, "structures can't be constant numbers"); + self.errors + .push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None)); + changed_something = true; + } + + // if we're testing if a number can fit into a numeric type, we can just do that! + Constraint::FitsInNumType(loc, TypeOrVar::Primitive(ctype), val) => { + match ctype.max_value() { + None => self.errors.push(TypeInferenceError::NotANumber(loc, ctype)), + + Some(max_value) if max_value < val => { + self.errors + .push(TypeInferenceError::ConstantTooLarge(loc, ctype, val)); + } + + Some(_) => {} + } + changed_something = true; + tracing::trace!(primitive_type = %ctype, value = val, "we can test for a value fitting in a primitive type"); + } + + // if we're testing if a function type can fit into a numeric type, that's a problem + Constraint::FitsInNumType(loc, t @ TypeOrVar::Function(_, _), val) => { + tracing::trace!(function_type = %t, "values don't fit in function types"); + self.errors + .push(TypeInferenceError::CannotMakeNumberAFunction( + loc, + t, + Some(val), + )); + changed_something = true; + } + + // if we're testing if a function type can fit into a numeric type, that's a problem + Constraint::FitsInNumType(loc, t @ TypeOrVar::Structure(_), val) => { + tracing::trace!(function_type = %t, "values don't fit in structure types"); + self.errors + .push(TypeInferenceError::CannotMakeNumberAStructure( + loc, + t, + Some(val), + )); + changed_something = true; + } + + // if we want to know if a type is something, and it is something, then we're done + Constraint::IsSomething(_, t @ TypeOrVar::Function(_, _)) + | Constraint::IsSomething(_, t @ TypeOrVar::Primitive(_)) + | Constraint::IsSomething(_, t @ TypeOrVar::Structure(_)) => { + tracing::trace!(tested_type = %t, "type is definitely something"); + changed_something = true; + } + + // if we want to know if something is signed, we can check its primitive type + Constraint::IsSigned(loc, TypeOrVar::Primitive(pt)) => { + tracing::trace!(primitive_type = %pt, "we can check if a primitive is signed"); + if !pt.valid_operators().contains(&("-", 1)) { + self.errors.push(TypeInferenceError::IsNotSigned( + loc, + TypeOrVar::Primitive(pt), + )); + } + changed_something = true; + } + + // again with the functions and the numbers + Constraint::IsSigned(loc, t @ TypeOrVar::Function(_, _)) => { + tracing::trace!(function_type = %t, "functions are not signed"); + self.errors.push(TypeInferenceError::IsNotSigned(loc, t)); + changed_something = true; + } + + // again with the functions and the numbers + Constraint::IsSigned(loc, t @ TypeOrVar::Structure(_)) => { + tracing::trace!(structure_type = %t, "structures are not signed"); + self.errors.push(TypeInferenceError::IsNotSigned(loc, t)); + changed_something = true; + } + + // if we're testing if an actual primitive type is numeric, that's pretty easy + Constraint::NumericType(loc, TypeOrVar::Primitive(pt)) => { + tracing::trace!(primitive_type = %pt, "its easy to tell if a primitive type is numeric"); + if pt.max_value().is_none() { + self.errors.push(TypeInferenceError::NotANumber(loc, pt)) + } + changed_something = true; + } + + // if we're testing if a function type is numeric, then throw a useful warning + Constraint::NumericType(loc, t @ TypeOrVar::Function(_, _)) => { + tracing::trace!(function_type = %t, "function types aren't numeric"); + self.errors + .push(TypeInferenceError::CannotMakeNumberAFunction(loc, t, None)); + changed_something = true; + } + + // if we're testing if a structure type is numeric, then throw a useful warning + Constraint::NumericType(loc, t @ TypeOrVar::Structure(_)) => { + tracing::trace!(structure_type = %t, "structure types aren't numeric"); + self.errors + .push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None)); + changed_something = true; + } + + // all of our primitive types are printable + Constraint::Printable(_, TypeOrVar::Primitive(pt)) => { + tracing::trace!(primitive_type = %pt, "primitive types are printable"); + changed_something = true; + } + + // function types are definitely not printable + Constraint::Printable(loc, ft @ TypeOrVar::Function(_, _)) => { + tracing::trace!(function_type = %ft, "function types are not printable"); + self.errors + .push(TypeInferenceError::FunctionsAreNotPrintable(loc)); + changed_something = true; + } + + // structure types are printable if all the types inside them are printable + Constraint::Printable(loc, TypeOrVar::Structure(fields)) => { + tracing::trace!( + "structure types are printable if all their subtypes are printable" + ); + for (_, subtype) in fields.into_iter() { + new_constraints.push(Constraint::Printable(loc.clone(), subtype)); + } + changed_something = true; + } + + // Some equivalences we can/should solve directly + Constraint::Equivalent( + loc, + TypeOrVar::Primitive(pt1), + TypeOrVar::Primitive(pt2), + ) => { + if pt1 != pt2 { + self.errors.push(TypeInferenceError::NotEquivalent( + loc, + TypeOrVar::Primitive(pt1), + TypeOrVar::Primitive(pt2), + )); + } + changed_something = true; + tracing::trace!(primitive_type1 = %pt1, primitive_type2 = %pt2, "we checked for primitive type equivalence"); + } + + Constraint::Equivalent( + loc, + pt @ TypeOrVar::Primitive(_), + ft @ TypeOrVar::Function(_, _), + ) + | Constraint::Equivalent( + loc, + ft @ TypeOrVar::Function(_, _), + pt @ TypeOrVar::Primitive(_), + ) => { + tracing::trace!(primitive_type = %pt, function_type = %ft, "function and primitive types cannot be equivalent"); + self.errors + .push(TypeInferenceError::NotEquivalent(loc, pt, ft)); + changed_something = true; + } + + Constraint::Equivalent( + loc, + pt @ TypeOrVar::Primitive(_), + st @ TypeOrVar::Structure(_), + ) + | Constraint::Equivalent( + loc, + st @ TypeOrVar::Structure(_), + pt @ TypeOrVar::Primitive(_), + ) => { + tracing::trace!(primitive_type = %pt, structure_type = %st, "structure and primitive types cannot be equivalent"); + self.errors + .push(TypeInferenceError::NotEquivalent(loc, pt, st)); + changed_something = true; + } + + Constraint::Equivalent( + loc, + st @ TypeOrVar::Structure(_), + ft @ TypeOrVar::Function(_, _), + ) + | Constraint::Equivalent( + loc, + ft @ TypeOrVar::Function(_, _), + st @ TypeOrVar::Structure(_), + ) => { + tracing::trace!(structure_type = %st, function_type = %ft, "structure and primitive types cannot be equivalent"); + self.errors + .push(TypeInferenceError::NotEquivalent(loc, st, ft)); + changed_something = true; + } + + Constraint::Equivalent( + _, + TypeOrVar::Variable(_, name1), + TypeOrVar::Variable(_, name2), + ) if name1 == name2 => { + tracing::trace!(name = %name1, "variable is equivalent to itself"); + changed_something = true; + } + + Constraint::Equivalent( + loc, + TypeOrVar::Function(args1, ret1), + TypeOrVar::Function(args2, ret2), + ) => { + if args1.len() != args2.len() { + let t1 = TypeOrVar::Function(args1, ret1); + let t2 = TypeOrVar::Function(args2, ret2); + self.errors + .push(TypeInferenceError::NotEquivalent(loc, t1, t2)); + } else { + for (left, right) in args1.into_iter().zip(args2) { + new_constraints.push(Constraint::Equivalent( + loc.clone(), + left, + right, + )); + } + new_constraints.push(Constraint::Equivalent(loc, *ret1, *ret2)); + all_constraints_solved = false; + } + + changed_something = true; + tracing::trace!("we checked/rewrote if function types are equivalent"); + } + + Constraint::Equivalent( + loc, + TypeOrVar::Structure(fields1), + TypeOrVar::Structure(mut fields2), + ) => { + if fields1.count() == fields2.count() + && fields1.field_names().all(|x| fields2.has_field(x)) + { + for (name, subtype1) in fields1.into_iter() { + let subtype2 = fields2 + .remove_field(&name) + .expect("can find matching field after equivalence check"); + new_constraints.push(Constraint::Equivalent( + loc.clone(), + subtype1, + subtype2, + )); + } + } else { + self.errors.push(TypeInferenceError::NotEquivalent( + loc, + TypeOrVar::Structure(fields1), + TypeOrVar::Structure(fields2), + )) + } + changed_something = true; + tracing::trace!("we checked/rewrote if structures are equivalent"); + } + + Constraint::Equivalent(_, TypeOrVar::Variable(_, ref name), ref rhs) => { + changed_something |= replace_variable(&mut self.constraints, name, rhs); + changed_something |= replace_variable(&mut new_constraints, name, rhs); + all_constraints_solved &= rhs.is_resolved(); + if changed_something { + tracing::trace!(%name, new_type = %rhs, "we were able to rewrite name somewhere"); + } + new_constraints.push(constraint); + } + + Constraint::Equivalent(loc, lhs, rhs @ TypeOrVar::Variable(_, _)) => { + tracing::trace!(new_left = %rhs, new_right = %lhs, "we flipped the order on an equivalence"); + new_constraints.push(Constraint::Equivalent(loc, rhs, lhs)); + changed_something = true; + all_constraints_solved = false; + } + + Constraint::CanCastTo(_, TypeOrVar::Variable(_, _), _) + | Constraint::CanCastTo(_, _, TypeOrVar::Variable(_, _)) + | Constraint::TypeHasField(_, TypeOrVar::Variable(_, _), _, _) + | Constraint::ConstantNumericType(_, TypeOrVar::Variable(_, _)) + | Constraint::FitsInNumType(_, TypeOrVar::Variable(_, _), _) + | Constraint::IsSomething(_, TypeOrVar::Variable(_, _)) + | Constraint::IsSigned(_, TypeOrVar::Variable(_, _)) + | Constraint::NumericType(_, TypeOrVar::Variable(_, _)) + | Constraint::Printable(_, TypeOrVar::Variable(_, _)) => { + all_constraints_solved = false; + new_constraints.push(constraint); + } + } } - new_constraints.retain(|x| { - if let Constraint::ConstantNumericType(loc, t) = x { - default_type!(addendums, loc, t); - false - } else { - true - } - }); + if all_constraints_solved { + return; + } + + if !changed_something { + let mut addendums = vec![]; + + macro_rules! default_type { + ($addendums: ident, $loc: ident, $t: ident) => { + let resty = TypeOrVar::Primitive(PrimitiveType::U64); + $addendums.push(Constraint::Equivalent( + $loc.clone(), + $t.clone(), + resty.clone(), + )); + self.warnings + .push(TypeInferenceWarning::DefaultedTo($loc.clone(), resty)); + tracing::trace!("Adding number equivalence"); + }; + } - if addendums.is_empty() { new_constraints.retain(|x| { - if let Constraint::IsSomething(loc, t) = x { + if let Constraint::ConstantNumericType(loc, t) = x { default_type!(addendums, loc, t); false } else { true } }); - } - if addendums.is_empty() { - if errors.is_empty() { - errors = new_constraints - .into_iter() - .map(TypeInferenceError::CouldNotSolve) - .collect(); + if addendums.is_empty() { + new_constraints.retain(|x| { + if let Constraint::IsSomething(loc, t) = x { + default_type!(addendums, loc, t); + false + } else { + true + } + }); } - return TypeInferenceResult::Failure { errors, warnings }; + + if addendums.is_empty() { + if self.errors.is_empty() { + self.errors = new_constraints + .into_iter() + .map(TypeInferenceError::CouldNotSolve) + .collect(); + } + return; + } + + new_constraints.append(&mut addendums); } - new_constraints.append(&mut addendums); + self.constraints = new_constraints; } - - constraint_db = new_constraints; } } diff --git a/src/type_infer/warning.rs b/src/type_infer/warning.rs new file mode 100644 index 0000000..c5db275 --- /dev/null +++ b/src/type_infer/warning.rs @@ -0,0 +1,21 @@ +use crate::ir::TypeOrVar; +use crate::syntax::Location; +use codespan_reporting::diagnostic::Diagnostic; + +/// Warnings that we might want to tell the user about. +/// +/// These are fine, probably, but could indicate some behavior the user might not +/// expect, and so they might want to do something about them. +pub enum TypeInferenceWarning { + DefaultedTo(Location, TypeOrVar), +} + +impl From for Diagnostic { + fn from(value: TypeInferenceWarning) -> Self { + match value { + TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning() + .with_labels(vec![loc.primary_label().with_message("unknown type")]) + .with_message(format!("Defaulted unknown type to {}", ty)), + } + } +}