From 71228b9e0915fd7ee70da2c3656305aba9da1775 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sat, 14 Oct 2023 16:39:16 -0700 Subject: [PATCH] Checkpoint --- src/backend/into_crane.rs | 2 +- src/eval.rs | 18 ++++- src/eval/primop.rs | 15 +++- src/eval/primtype.rs | 43 +++++++---- src/eval/value.rs | 86 ++++++++++++++++++--- src/ir/ast.rs | 58 +++++++++++--- src/ir/eval.rs | 6 +- src/ir/strings.rs | 7 +- src/syntax.rs | 2 +- src/syntax/location.rs | 4 +- src/syntax/validate.rs | 2 +- src/type_infer.rs | 2 +- src/type_infer/ast.rs | 19 +---- src/type_infer/convert.rs | 4 +- src/type_infer/solve.rs | 158 ++++++++++++++++++++++++++++++++++++-- src/util.rs | 1 + src/util/pretty.rs | 47 ++++++++++++ src/util/scoped_map.rs | 6 ++ 18 files changed, 402 insertions(+), 78 deletions(-) create mode 100644 src/util/pretty.rs diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 4d3d7ea..a9989de 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -122,7 +122,7 @@ impl Backend { // state, it's easier to just include them. for item in program.items.drain(..) { match item { - TopLevel::Function(_, _, _) => unimplemented!(), + TopLevel::Function(_, _, _, _) => unimplemented!(), // Print statements are fairly easy to compile: we just lookup the // output buffer, the address of the string to print, and the value diff --git a/src/eval.rs b/src/eval.rs index a593a63..b45b4d6 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -46,6 +46,8 @@ pub use value::Value; use crate::backend::BackendError; +use self::primtype::UnknownPrimType; + /// All of the errors that can happen trying to evaluate an NGR program. /// /// This is yet another standard [`thiserror::Error`] type, but with the @@ -71,9 +73,13 @@ pub enum EvalError { ExitCode(std::process::ExitStatus), #[error("Unexpected output at runtime: {0}")] RuntimeOutput(String), + #[error("Cannot cast to function type: {0}")] + CastToFunction(String), + #[error(transparent)] + UnknownPrimType(#[from] UnknownPrimType), } -impl PartialEq for EvalError { +impl PartialEq for EvalError { fn eq(&self, other: &Self) -> bool { match self { EvalError::Lookup(a) => match other { @@ -115,6 +121,16 @@ impl PartialEq for EvalError { EvalError::RuntimeOutput(b) => a == b, _ => false, }, + + EvalError::CastToFunction(a) => match other { + EvalError::CastToFunction(b) => a == b, + _ => false, + }, + + EvalError::UnknownPrimType(a) => match other { + EvalError::UnknownPrimType(b) => a == b, + _ => false, + }, } } } diff --git a/src/eval/primop.rs b/src/eval/primop.rs index 01a8dd1..41634a3 100644 --- a/src/eval/primop.rs +++ b/src/eval/primop.rs @@ -1,6 +1,8 @@ use crate::eval::primtype::PrimitiveType; use crate::eval::value::Value; +use super::primtype::{UnknownPrimType, ValuePrimitiveTypeError}; + /// Errors that can occur running primitive operations in the evaluators. #[derive(Clone, Debug, PartialEq, thiserror::Error)] pub enum PrimOpError { @@ -14,7 +16,7 @@ pub enum PrimOpError { /// This variant covers when an operator must take a particular /// type, but the user has provided a different one. #[error("Bad type for operator {0}: {1}")] - BadTypeFor(&'static str, Value), + BadTypeFor(String, Value), /// Probably obvious from the name, but just to be very clear: this /// happens when you pass three arguments to a two argument operator, /// etc. Technically that's a type error of some sort, but we split @@ -28,8 +30,10 @@ pub enum PrimOpError { from: PrimitiveType, to: PrimitiveType, }, - #[error("Unknown primitive type {0}")] - UnknownPrimType(String), + #[error(transparent)] + UnknownPrimType(#[from] UnknownPrimType), + #[error(transparent)] + ValuePrimitiveTypeError(#[from] ValuePrimitiveTypeError), } // Implementing primitives in an interpreter like this is *super* tedious, @@ -63,7 +67,7 @@ impl Value { Value::I16(x) => Ok(Value::I16(x.wrapping_neg())), Value::I32(x) => Ok(Value::I32(x.wrapping_neg())), Value::I64(x) => Ok(Value::I64(x.wrapping_neg())), - _ => Err(PrimOpError::BadTypeFor("-", value.clone())), + _ => Err(PrimOpError::BadTypeFor("-".to_string(), value.clone())), }, _ => Err(PrimOpError::BadArgCount(operation.to_owned(), 1)), } @@ -135,6 +139,9 @@ impl Value { right.clone(), )), }, + Value::Function(_, _) => { + Err(PrimOpError::BadTypeFor(operation.to_string(), left.clone())) + } } } diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index 7690137..de1e6eb 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -31,17 +31,28 @@ impl Display for PrimitiveType { } } -impl<'a> From<&'a Value> for PrimitiveType { - fn from(value: &Value) -> Self { +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum ValuePrimitiveTypeError { + #[error("Could not convert function value to primitive type (possible function name: {0:?}")] + CannotConvertFunction(Option), +} + +impl<'a> TryFrom<&'a Value> for PrimitiveType { + type Error = ValuePrimitiveTypeError; + + fn try_from(value: &'a Value) -> Result { match value { - Value::I8(_) => PrimitiveType::I8, - Value::I16(_) => PrimitiveType::I16, - Value::I32(_) => PrimitiveType::I32, - Value::I64(_) => PrimitiveType::I64, - Value::U8(_) => PrimitiveType::U8, - Value::U16(_) => PrimitiveType::U16, - Value::U32(_) => PrimitiveType::U32, - Value::U64(_) => PrimitiveType::U64, + Value::I8(_) => Ok(PrimitiveType::I8), + Value::I16(_) => Ok(PrimitiveType::I16), + Value::I32(_) => Ok(PrimitiveType::I32), + Value::I64(_) => Ok(PrimitiveType::I64), + Value::U8(_) => Ok(PrimitiveType::U8), + Value::U16(_) => Ok(PrimitiveType::U16), + Value::U32(_) => Ok(PrimitiveType::U32), + Value::U64(_) => Ok(PrimitiveType::U64), + Value::Function(name, _) => { + Err(ValuePrimitiveTypeError::CannotConvertFunction(name.clone())) + } } } } @@ -61,8 +72,14 @@ impl From for PrimitiveType { } } +#[derive(thiserror::Error, Debug, Clone, PartialEq)] +pub enum UnknownPrimType { + #[error("Could not convert '{0}' into a primitive type")] + UnknownPrimType(String), +} + impl FromStr for PrimitiveType { - type Err = PrimOpError; + type Err = UnknownPrimType; fn from_str(s: &str) -> Result { match s { @@ -74,7 +91,7 @@ impl FromStr for PrimitiveType { "u16" => Ok(PrimitiveType::U16), "u32" => Ok(PrimitiveType::U32), "u64" => Ok(PrimitiveType::U64), - _ => Err(PrimOpError::UnknownPrimType(s.to_string())), + _ => Err(UnknownPrimType::UnknownPrimType(s.to_owned())), } } } @@ -152,7 +169,7 @@ impl PrimitiveType { (PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)), _ => Err(PrimOpError::UnsafeCast { - from: source.into(), + from: PrimitiveType::try_from(source)?, to: *self, }), } diff --git a/src/eval/value.rs b/src/eval/value.rs index 12f3746..3634a2f 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -1,11 +1,13 @@ -use std::fmt::Display; +use super::EvalError; +use std::fmt; +use std::rc::Rc; /// Values in the interpreter. /// /// Yes, this is yet another definition of a structure called `Value`, which /// are almost entirely identical. However, it's nice to have them separated /// by type so that we don't mix them up. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone)] pub enum Value { I8(i8), I16(i16), @@ -15,19 +17,79 @@ pub enum Value { U16(u16), U32(u32), U64(u64), + Function( + Option, + Rc) -> Result>, + ), } -impl Display for Value { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +fn format_value(value: &Value, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match value { + Value::I8(x) => write!(f, "{}i8", x), + Value::I16(x) => write!(f, "{}i16", x), + Value::I32(x) => write!(f, "{}i32", x), + Value::I64(x) => write!(f, "{}i64", x), + Value::U8(x) => write!(f, "{}u8", x), + Value::U16(x) => write!(f, "{}u16", x), + Value::U32(x) => write!(f, "{}u32", x), + Value::U64(x) => write!(f, "{}u64", x), + Value::Function(Some(name), _) => write!(f, "", name), + Value::Function(None, _) => write!(f, ""), + } +} + +impl fmt::Debug for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_value(self, f) + } +} + +impl fmt::Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_value(self, f) + } +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { match self { - Value::I8(x) => write!(f, "{}i8", x), - Value::I16(x) => write!(f, "{}i16", x), - Value::I32(x) => write!(f, "{}i32", x), - Value::I64(x) => write!(f, "{}i64", x), - Value::U8(x) => write!(f, "{}u8", x), - Value::U16(x) => write!(f, "{}u16", x), - Value::U32(x) => write!(f, "{}u32", x), - Value::U64(x) => write!(f, "{}u64", x), + Value::I8(x) => match other { + Value::I8(y) => x == y, + _ => false, + }, + Value::I16(x) => match other { + Value::I16(y) => x == y, + _ => false, + }, + Value::I32(x) => match other { + Value::I32(y) => x == y, + _ => false, + }, + Value::I64(x) => match other { + Value::I64(y) => x == y, + _ => false, + }, + Value::U8(x) => match other { + Value::U8(y) => x == y, + _ => false, + }, + Value::U16(x) => match other { + Value::U16(y) => x == y, + _ => false, + }, + Value::U32(x) => match other { + Value::U32(y) => x == y, + _ => false, + }, + Value::U64(x) => match other { + Value::U64(y) => x == y, + _ => false, + }, + Value::Function(Some(x), _) => match other { + Value::Function(Some(y), _) => x == y, + _ => false, + }, + Value::Function(None, _) => false, } } } diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 6a7b62d..c63f5db 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -1,6 +1,7 @@ use crate::{ eval::PrimitiveType, syntax::{self, ConstantType, Location}, + util::pretty::{pretty_comma_separated, PrettySymbol}, }; use internment::ArcIntern; use pretty::{BoxAllocator, DocAllocator, Pretty}; @@ -87,20 +88,33 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { - TopLevel::Function(name, args, stmts, body) => allocator - .text("function") - .append(allocator.space()) - .append(allocator.text(name.as_ref().to_string())) - .append( - allocator - .intersperse( - args.iter().map(|x| allocator.text(x.as_ref().to_string())), - ", ", + TopLevel::Function(name, args, stmts, expr) => { + let base = allocator + .text("function") + .append(allocator.space()) + .append(allocator.text(name.as_ref().to_string())) + .append(allocator.space()) + .append( + pretty_comma_separated( + allocator, + &args.iter().map(PrettySymbol::from).collect(), ) .parens(), - ) - .append(allocator.space()) - .append(body.pretty(allocator)), + ) + .append(allocator.space()); + + let mut body = allocator.nil(); + for stmt in stmts { + body = body + .append(stmt.pretty(allocator)) + .append(allocator.text(";")) + .append(allocator.hardline()); + } + body = body.append(expr.pretty(allocator)); + body = body.append(allocator.hardline()); + body = body.braces(); + base.append(body) + } TopLevel::Statement(stmt) => stmt.pretty(allocator), } @@ -389,6 +403,14 @@ where fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { match self { Type::Primitive(pt) => allocator.text(format!("{}", pt)), + Type::Function(args, rettype) => { + pretty_comma_separated(allocator, &args.iter().collect()) + .parens() + .append(allocator.space()) + .append(allocator.text("->")) + .append(allocator.space()) + .append(rettype.pretty(allocator)) + } } } } @@ -397,6 +419,18 @@ impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Type::Primitive(pt) => pt.fmt(f), + Type::Function(args, ret) => { + write!(f, "(")?; + let mut argiter = args.iter().peekable(); + while let Some(arg) = argiter.next() { + arg.fmt(f)?; + if argiter.peek().is_some() { + write!(f, ",")?; + } + } + write!(f, "->")?; + ret.fmt(f) + } } } } diff --git a/src/ir/eval.rs b/src/ir/eval.rs index 46123f3..93d56b0 100644 --- a/src/ir/eval.rs +++ b/src/ir/eval.rs @@ -1,8 +1,7 @@ +use super::{Primitive, Type, ValueOrRef}; use crate::eval::{EvalEnvironment, EvalError, Value}; use crate::ir::{Expression, Program, Statement, TopLevel}; -use super::{Primitive, Type, ValueOrRef}; - impl Program { /// Evaluate the program, returning either an error or a string containing everything /// the program printed out. @@ -14,7 +13,7 @@ impl Program { for stmt in self.items.iter() { match stmt { - TopLevel::Function(_, _, _) => unimplemented!(), + TopLevel::Function(_, _, _, _) => unimplemented!(), TopLevel::Statement(Statement::Binding(_, name, _, value)) => { let actual_value = value.eval(&env)?; @@ -43,6 +42,7 @@ impl Expression { match t { Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), + Type::Function(_, _) => Err(EvalError::CastToFunction(t.to_string())), } } diff --git a/src/ir/strings.rs b/src/ir/strings.rs index 6326672..dc05432 100644 --- a/src/ir/strings.rs +++ b/src/ir/strings.rs @@ -21,7 +21,12 @@ impl Program { impl TopLevel { fn register_strings(&self, string_set: &mut HashSet>) { match self { - TopLevel::Function(_, _, body) => body.register_strings(string_set), + TopLevel::Function(_, _, stmts, body) => { + for stmt in stmts.iter() { + stmt.register_strings(string_set); + } + body.register_strings(string_set); + } TopLevel::Statement(stmt) => stmt.register_strings(string_set), } } diff --git a/src/syntax.rs b/src/syntax.rs index 2a8acc8..a33119b 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -29,7 +29,7 @@ use logos::Logos; pub mod arbitrary; mod ast; -mod eval; +pub mod eval; mod location; mod tokens; lalrpop_mod!( diff --git a/src/syntax/location.rs b/src/syntax/location.rs index 5acd070..09e6e0c 100644 --- a/src/syntax/location.rs +++ b/src/syntax/location.rs @@ -85,9 +85,9 @@ impl Location { /// the user with some guidance. That being said, you still might want to add /// even more information to ut, using [`Diagnostic::with_labels`], /// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`]. - pub fn labelled_error(&self, msg: &str) -> Diagnostic { + pub fn labelled_error>(&self, msg: T) -> Diagnostic { Diagnostic::error().with_labels(vec![ - Label::primary(self.file_idx, self.location.clone()).with_message(msg) + Label::primary(self.file_idx, self.location.clone()).with_message(msg.as_ref()) ]) } diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index f175eed..46c471c 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -123,7 +123,7 @@ impl TopLevel { for arg in arguments.iter() { bound_variables.insert(arg.name.clone(), arg.location.clone()); } - let result = body.validate(&bound_variables); + let result = body.validate(bound_variables); bound_variables.release_scope(); result } diff --git a/src/type_infer.rs b/src/type_infer.rs index 0402998..08cd2da 100644 --- a/src/type_infer.rs +++ b/src/type_infer.rs @@ -47,6 +47,6 @@ proptest::proptest! { let syntax_result = input.eval(); let ir = input.type_infer().expect("arbitrary should generate type-safe programs"); let ir_result = ir.eval(); - proptest::prop_assert_eq!(syntax_result, ir_result); + proptest::prop_assert!(syntax_result.eq(&ir_result)); } } diff --git a/src/type_infer/ast.rs b/src/type_infer/ast.rs index daffbf7..21eb2fc 100644 --- a/src/type_infer/ast.rs +++ b/src/type_infer/ast.rs @@ -9,6 +9,7 @@ pub use crate::ir::ast::Primitive; use crate::{ eval::PrimitiveType, syntax::{self, ConstantType, Location}, + util::pretty::{pretty_comma_separated, PrettySymbol}, }; use internment::ArcIntern; use pretty::{DocAllocator, Pretty}; @@ -87,10 +88,7 @@ where .append( pretty_comma_separated( allocator, - &args - .iter() - .map(|x| allocator.text(x.as_ref().to_string())) - .collect(), + &args.iter().map(PrettySymbol::from).collect(), ) .parens(), ) @@ -408,16 +406,3 @@ pub fn gentype() -> Type { Type::Variable(Location::manufactured(), name) } - -fn pretty_comma_separated<'a, D, A, P>( - allocator: &'a D, - args: &Vec

, -) -> pretty::DocBuilder<'a, D, A> -where - A: 'a, - D: ?Sized + DocAllocator<'a, A>, - P: Pretty<'a, D, A>, -{ - let individuals = args.iter().map(|x| x.pretty(allocator)); - allocator.intersperse(individuals, ", ") -} diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index bff8c2a..e6dc161 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -65,7 +65,7 @@ pub fn convert_top_level( // Now let's bind these types into the environment. First, we bind our function // namae to the function type we just generated. - bindings.insert(funname, funtype); + bindings.insert(funname.clone(), funtype); // And then we attach the argument names to the argument types. (We have to go // convert all the names, first.) let iargs: Vec> = @@ -291,7 +291,7 @@ fn finalize_name( renames: &mut HashMap, ArcIntern>, name: syntax::Name, ) -> ArcIntern { - if bindings.contains_key(&ArcIntern::new(name.name)) { + if bindings.contains_key(&ArcIntern::new(name.name.clone())) { let new_name = ir::gensym(&name.name); renames.insert(ArcIntern::new(name.name.to_string()), new_name.clone()); new_name diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index ec33384..76d072b 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -102,11 +102,19 @@ pub enum TypeInferenceError { /// The user provide a constant that is too large for its inferred type. ConstantTooLarge(Location, PrimitiveType, u64), /// The two types needed to be equivalent, but weren't. - NotEquivalent(Location, PrimitiveType, PrimitiveType), + NotEquivalent(Location, Type, Type), /// We cannot safely cast the first type to the second type. CannotSafelyCast(Location, PrimitiveType, PrimitiveType), /// The primitive invocation provided the wrong number of arguments. WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize), + /// We cannot cast between function types at the moment. + CannotCastBetweenFunctinoTypes(Location, Type, Type), + /// We cannot cast from a function type to something else. + CannotCastFromFunctionType(Location, Type), + /// We cannot cast to a function type from something else. + CannotCastToFunctionType(Location, Type), + /// We cannot turn a number into a function. + CannotMakeNumberAFunction(Location, Type, Option), /// We had a constraint we just couldn't solve. CouldNotSolve(Constraint), } @@ -142,6 +150,29 @@ impl From for Diagnostic { prim, observed )), + TypeInferenceError::CannotCastBetweenFunctinoTypes(loc, t1, t2) => loc + .labelled_error("cannot cast between function types") + .with_message(format!( + "tried to cast from {} to {}", + t1, t2, + )), + TypeInferenceError::CannotCastFromFunctionType(loc, t) => loc + .labelled_error("cannot cast from a function type to anything else") + .with_message(format!( + "function type was {}", t, + )), + TypeInferenceError::CannotCastToFunctionType(loc, t) => loc + .labelled_error("cannot cast to a function type") + .with_message(format!( + "function type was {}", t, + )), + TypeInferenceError::CannotMakeNumberAFunction(loc, t, val) => loc + .labelled_error(if let Some(val) = val { + format!("cannot turn {} into a function", val) + } else { + "cannot use a constant as a function type".to_string() + }) + .with_message(format!("function type was {}", t)), TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => { loc.labelled_error("internal error").with_message(format!( "could not determine if it was safe to cast from {} to {:#?}", @@ -236,12 +267,12 @@ pub fn solve_constraints( // Currently, all of our types are printable Constraint::Printable(_loc, _ty) => changed_something = true, - // Case #1: We have two primitive types. If they're equal, we've discharged this + // Case #1a: We have two primitive types. If they're equal, we've discharged this // constraint! We can just continue. If they're not equal, add an error and then // see what else we come up with. - Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => { - if t1 != t2 { - errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2)); + Constraint::Equivalent(loc, a @ Type::Primitive(_), b @ Type::Primitive(_)) => { + if a != b { + errors.push(TypeInferenceError::NotEquivalent(loc, a, b)); } changed_something = true; } @@ -257,7 +288,11 @@ pub fn solve_constraints( resolutions.insert(name, t); } Some(t2) if &t == t2 => {} - Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)), + Some(t2) => errors.push(TypeInferenceError::NotEquivalent( + loc, + Type::Primitive(t), + Type::Primitive(*t2), + )), } changed_something = true; } @@ -284,11 +319,61 @@ pub fn solve_constraints( changed_something = true; } (Some(pt1), Some(pt2)) => { - errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2)); + errors.push(TypeInferenceError::NotEquivalent( + loc.clone(), + Type::Primitive(*pt1), + Type::Primitive(*pt2), + )); changed_something = true; } }, + // Case #4: Like primitives, but for function types. This is a little complicated, because + // we first want to resolve all the type variables in the two types, and then see if they're + // equivalent. Fortunately, though, we can cheat a bit. What we're going to do is first see + // if the two types have the same arity (the same number of arguments). If not, we know the + // types don't match. If they do, then we're going to just turn this into a bunch of different + // equivalence constraints by matching up each of the arguments as well as the return type, and + // then restarting the type checking loop. That will cause any of those type variables to be + // handled appropriately. This even works recursively, so we can support arbitrarily nested + // function types. + Constraint::Equivalent( + loc, + ref a @ Type::Function(ref args1, ref ret1), + ref b @ Type::Function(ref args2, ref ret2), + ) => { + if args1.len() != args2.len() { + errors.push(TypeInferenceError::NotEquivalent( + loc.clone(), + a.clone(), + b.clone(), + )); + } else { + for (left, right) in args1.iter().zip(args2.iter()) { + constraint_db.push(Constraint::Equivalent( + loc.clone(), + left.clone(), + right.clone(), + )); + } + } + + constraint_db.push(Constraint::Equivalent( + loc, + ret1.as_ref().clone(), + ret2.as_ref().clone(), + )); + + changed_something = true; + } + + // Case #5: They're just totally the wrong types. In which case, we're done with + // this one; emit the error and drop the constraint. + Constraint::Equivalent(loc, a, b) => { + errors.push(TypeInferenceError::NotEquivalent(loc, a, b)); + changed_something = true; + } + // Make sure that the provided number fits within the provided constant type. For the // moment, we're going to call an error here a failure, although this could be a // warning in the future. @@ -320,6 +405,15 @@ pub fn solve_constraints( } } + // Function types definitely do not fit in numeric types + Constraint::FitsInNumType(loc, t @ Type::Function(_, _), val) => { + errors.push(TypeInferenceError::CannotMakeNumberAFunction( + loc, + t.clone(), + Some(val), + )); + } + // If the left type in a "can cast to" check is a variable, let's see if we can advance // it into something more tangible Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => { @@ -373,6 +467,36 @@ pub fn solve_constraints( changed_something = true; } + // If either type is a function type, then we can only cast if the two types + // are equivalent. + Constraint::CanCastTo( + loc, + t1 @ Type::Function(_, _), + t2 @ Type::Function(_, _), + ) => { + if t1 != t2 { + errors.push(TypeInferenceError::CannotCastBetweenFunctinoTypes( + loc, + t1.clone(), + t2.clone(), + )); + } + changed_something = true; + } + + Constraint::CanCastTo(loc, t @ Type::Function(_, _), Type::Primitive(_)) => { + errors.push(TypeInferenceError::CannotCastFromFunctionType( + loc, + t.clone(), + )); + changed_something = true; + } + + Constraint::CanCastTo(loc, Type::Primitive(_), t @ Type::Function(_, _)) => { + errors.push(TypeInferenceError::CannotCastToFunctionType(loc, t.clone())); + changed_something = true; + } + // As per usual, if we're trying to test if a type variable is numeric, first // we try to advance it to a primitive Constraint::NumericType(loc, Type::Variable(vloc, var)) => { @@ -392,6 +516,16 @@ pub fn solve_constraints( changed_something = true; } + // But functions are definitely not numbers + Constraint::NumericType(loc, t @ Type::Function(_, _)) => { + errors.push(TypeInferenceError::CannotMakeNumberAFunction( + loc, + t.clone(), + None, + )); + changed_something = true; + } + // As per usual, if we're trying to test if a type variable is numeric, first // we try to advance it to a primitive Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => { @@ -414,6 +548,16 @@ pub fn solve_constraints( changed_something = true; } + // But functions are definitely not numbers + Constraint::ConstantNumericType(loc, t @ Type::Function(_, _)) => { + errors.push(TypeInferenceError::CannotMakeNumberAFunction( + loc, + t.clone(), + None, + )); + changed_something = true; + } + // OK, this one could be a little tricky if we tried to do it all at once, but // instead what we're going to do here is just use this constraint to generate // a bunch more constraints, and then go have the engine solve those. The only diff --git a/src/util.rs b/src/util.rs index b7271cc..5e15869 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1 +1,2 @@ +pub mod pretty; pub mod scoped_map; diff --git a/src/util/pretty.rs b/src/util/pretty.rs new file mode 100644 index 0000000..e5c478a --- /dev/null +++ b/src/util/pretty.rs @@ -0,0 +1,47 @@ +use internment::ArcIntern; +use pretty::{DocAllocator, Pretty}; + +#[derive(Clone)] +pub struct PrettySymbol { + name: ArcIntern, +} + +impl<'a> From<&'a ArcIntern> for PrettySymbol { + fn from(value: &'a ArcIntern) -> Self { + PrettySymbol { name: value.clone() } + } +} + +impl<'a, D, A> Pretty<'a, D, A> for PrettySymbol +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + allocator.text(self.name.as_ref().to_string()) + } +} + +impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b PrettySymbol +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { + allocator.text(self.name.as_ref().to_string()) + } +} + +#[allow(clippy::ptr_arg)] +pub fn pretty_comma_separated<'a, D, A, P>( + allocator: &'a D, + args: &Vec

, +) -> pretty::DocBuilder<'a, D, A> +where + A: 'a, + D: ?Sized + DocAllocator<'a, A>, + P: Pretty<'a, D, A> + Clone, +{ + let individuals = args.iter().map(|x| x.clone().pretty(allocator)); + allocator.intersperse(individuals, ", ") +} diff --git a/src/util/scoped_map.rs b/src/util/scoped_map.rs index dc8a9a1..69f442c 100644 --- a/src/util/scoped_map.rs +++ b/src/util/scoped_map.rs @@ -5,6 +5,12 @@ pub struct ScopedMap { scopes: Vec>, } +impl Default for ScopedMap { + fn default() -> Self { + ScopedMap::new() + } +} + impl ScopedMap { /// Generate a new scoped map. ///