diff --git a/examples/basic/type_checker1.ngr b/examples/basic/type_checker1.ngr new file mode 100644 index 0000000..aa0a44a --- /dev/null +++ b/examples/basic/type_checker1.ngr @@ -0,0 +1,6 @@ +x = 1 + 1u16; +print x; +y = 1u16 + 1; +print y; +z = 1 + 1 + 1; +print z; \ No newline at end of file diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index 59ec352..7690137 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -4,7 +4,7 @@ use crate::{ }; use std::{fmt::Display, str::FromStr}; -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum PrimitiveType { U8, U16, @@ -83,56 +83,39 @@ impl PrimitiveType { /// Return true if this type can be safely cast into the target type. pub fn can_cast_to(&self, target: &PrimitiveType) -> bool { match self { - PrimitiveType::U8 => match target { - PrimitiveType::U8 => true, - PrimitiveType::U16 => true, - PrimitiveType::U32 => true, - PrimitiveType::U64 => true, - PrimitiveType::I16 => true, - PrimitiveType::I32 => true, - PrimitiveType::I64 => true, - _ => false, - }, - PrimitiveType::U16 => match target { - PrimitiveType::U16 => true, - PrimitiveType::U32 => true, - PrimitiveType::U64 => true, - PrimitiveType::I32 => true, - PrimitiveType::I64 => true, - _ => false, - }, - PrimitiveType::U32 => match target { - PrimitiveType::U32 => true, - PrimitiveType::U64 => true, - PrimitiveType::I64 => true, - _ => false, - }, - PrimitiveType::U64 => match target { - PrimitiveType::U64 => true, - _ => false, - }, - PrimitiveType::I8 => match target { - PrimitiveType::I8 => true, - PrimitiveType::I16 => true, - PrimitiveType::I32 => true, - PrimitiveType::I64 => true, - _ => false, - }, - PrimitiveType::I16 => match target { - PrimitiveType::I16 => true, - PrimitiveType::I32 => true, - PrimitiveType::I64 => true, - _ => false, - }, - PrimitiveType::I32 => match target { - PrimitiveType::I32 => true, - PrimitiveType::I64 => true, - _ => false, - }, - PrimitiveType::I64 
=> match target { - PrimitiveType::I64 => true, - _ => false, - }, + PrimitiveType::U8 => matches!( + target, + PrimitiveType::U8 + | PrimitiveType::U16 + | PrimitiveType::U32 + | PrimitiveType::U64 + | PrimitiveType::I16 + | PrimitiveType::I32 + | PrimitiveType::I64 + ), + PrimitiveType::U16 => matches!( + target, + PrimitiveType::U16 + | PrimitiveType::U32 + | PrimitiveType::U64 + | PrimitiveType::I32 + | PrimitiveType::I64 + ), + PrimitiveType::U32 => matches!( + target, + PrimitiveType::U32 | PrimitiveType::U64 | PrimitiveType::I64 + ), + PrimitiveType::U64 => target == &PrimitiveType::U64, + PrimitiveType::I8 => matches!( + target, + PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64 + ), + PrimitiveType::I16 => matches!( + target, + PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64 + ), + PrimitiveType::I32 => matches!(target, PrimitiveType::I32 | PrimitiveType::I64), + PrimitiveType::I64 => target == &PrimitiveType::I64, } } @@ -174,4 +157,17 @@ impl PrimitiveType { }), } } + + pub fn max_value(&self) -> u64 { + match self { + PrimitiveType::U8 => u8::MAX as u64, + PrimitiveType::U16 => u16::MAX as u64, + PrimitiveType::U32 => u32::MAX as u64, + PrimitiveType::U64 => u64::MAX, + PrimitiveType::I8 => i8::MAX as u64, + PrimitiveType::I16 => i16::MAX as u64, + PrimitiveType::I32 => i32::MAX as u64, + PrimitiveType::I64 => i64::MAX as u64, + } + } } diff --git a/src/ir/ast.rs b/src/ir/ast.rs index 037b51f..9360de1 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -3,7 +3,7 @@ use crate::{ syntax::{self, ConstantType, Location}, }; use internment::ArcIntern; -use pretty::{DocAllocator, Pretty}; +use pretty::{BoxAllocator, DocAllocator, Pretty}; use proptest::{ prelude::Arbitrary, strategy::{BoxedStrategy, Strategy}, @@ -224,6 +224,12 @@ where } } +impl fmt::Display for Primitive { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <&Primitive as Pretty<'_, BoxAllocator, ()>>::pretty(self, 
&BoxAllocator).render_fmt(72, f) + } +} + /// An expression that is always either a value or a reference. /// /// This is the type used to guarantee that we don't nest expressions diff --git a/src/ir/type_infer.rs b/src/ir/type_infer.rs index 341e931..172d44e 100644 --- a/src/ir/type_infer.rs +++ b/src/ir/type_infer.rs @@ -21,13 +21,23 @@ impl syntax::Program { let program = convert_program(self, &mut constraint_db); let inference_result = solve_constraints(constraint_db); - inference_result.map(|type_renames| finalize_program(program, type_renames)) + inference_result.map(|resolutions| finalize_program(program, &resolutions)) } } proptest::proptest! { #[test] fn translation_maintains_semantics(input: syntax::Program) { + use pretty::{DocAllocator, Pretty}; + let allocator = pretty::BoxAllocator; + allocator + .text("---------------") + .append(allocator.hardline()) + .append(input.pretty(&allocator)) + .1 + .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) + .expect("rendering works"); + let syntax_result = input.eval(); let ir = input.type_infer().expect("arbitrary should generate type-safe programs"); let ir_result = ir.eval(); diff --git a/src/ir/type_infer/ast.rs b/src/ir/type_infer/ast.rs index 87d3cbc..ff91a00 100644 --- a/src/ir/type_infer/ast.rs +++ b/src/ir/type_infer/ast.rs @@ -109,7 +109,7 @@ where /// a primitive), any subexpressions have been bound to variables so /// that the referenced data will always either be a constant or a /// variable reference. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Expression { Atomic(ValueOrRef), Cast(Location, Type, ValueOrRef), @@ -177,7 +177,7 @@ where /// This is the type used to guarantee that we don't nest expressions /// at this level. Instead, expressions that take arguments take one /// of these, which can only be a constant or a reference. 
-#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum ValueOrRef { Value(Location, Type, Value), Ref(Location, Type, ArcIntern), @@ -208,7 +208,7 @@ impl From for Expression { /// user to input the number. By retaining it, we can ensure that if we need /// to print the number back out, we can do so in the form that the user /// entered it. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum Value { Unknown(Option, u64), I8(Option, i8), @@ -273,6 +273,12 @@ pub enum Type { Primitive(PrimitiveType), } +impl Type { + pub fn is_concrete(&self) -> bool { + !matches!(self, Type::Variable(_, _)) + } +} + impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type where A: 'a, @@ -321,10 +327,10 @@ pub fn gensym(name: &str) -> ArcIntern { pub fn gentype() -> Type { static COUNTER: AtomicUsize = AtomicUsize::new(0); - let new_name = ArcIntern::new(format!( + let name = ArcIntern::new(format!( "t<{}>", COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst) )); - Type::Variable(Location::manufactured(), new_name) + Type::Variable(Location::manufactured(), name) } diff --git a/src/ir/type_infer/convert.rs b/src/ir/type_infer/convert.rs index 8b61584..8db9753 100644 --- a/src/ir/type_infer/convert.rs +++ b/src/ir/type_infer/convert.rs @@ -149,7 +149,11 @@ fn convert_expression( ), }; - constraint_db.push(Constraint::FitsInNumType(loc.clone(), newtype.clone(), value)); + constraint_db.push(Constraint::FitsInNumType( + loc.clone(), + newtype.clone(), + value, + )); ( vec![], ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)), @@ -230,3 +234,126 @@ fn simplify_expr(expr: ir::Expression, stmts: &mut Vec) -> ir::Va } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::syntax::Location; + + fn one() -> syntax::Expression { + syntax::Expression::Value( + Location::manufactured(), + syntax::Value::Number(None, None, 1), + ) + } + + fn vec_contains bool>(x: &[T], f: F) -> bool { + for x in x.iter() { + if f(x) { + return true; 
+ } + } + false + } + + fn infer_expression( + x: syntax::Expression, + ) -> (ir::Expression, Vec, Vec, Type) { + let mut constraints = Vec::new(); + let renames = HashMap::new(); + let mut bindings = HashMap::new(); + let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings); + (expr, stmts, constraints, ty) + } + + fn infer_statement(x: syntax::Statement) -> (Vec, Vec) { + let mut constraints = Vec::new(); + let mut renames = HashMap::new(); + let mut bindings = HashMap::new(); + let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings); + (res, constraints) + } + + #[test] + fn constant_one() { + let (expr, stmts, constraints, ty) = infer_expression(one()); + assert!(stmts.is_empty()); + assert!(matches!( + expr, + ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1))) + )); + assert!(vec_contains(&constraints, |x| matches!( + x, + Constraint::FitsInNumType(_, _, 1) + ))); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::NumericType(_, t) if t == &ty) + )); + } + + #[test] + fn one_plus_one() { + let opo = syntax::Expression::Primitive( + Location::manufactured(), + "+".to_string(), + vec![one(), one()], + ); + let (expr, stmts, constraints, ty) = infer_expression(opo); + assert!(stmts.is_empty()); + assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty)); + assert!(vec_contains(&constraints, |x| matches!( + x, + Constraint::FitsInNumType(_, _, 1) + ))); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::NumericType(_, t) if t != &ty) + )); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty) + )); + } + + #[test] + fn one_plus_one_plus_one() { + let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse"); + let (stmts, constraints) = infer_statement(stmt); + assert_eq!(stmts.len(), 2); + 
let ir::Statement::Binding(_args, name1, temp_ty1, ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1)) = stmts.get(0).expect("item two") else { + panic!("Failed to match first statement"); + }; + let ir::Statement::Binding(_args, name2, temp_ty2, ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2)) = stmts.get(1).expect("item two") else { + panic!("Failed to match second statement"); + }; + let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] = &primargs1[..] else { + panic!("Failed to match first arguments"); + }; + let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] = &primargs2[..] else { + panic!("Failed to match first arguments"); + }; + assert_ne!(name1, name2); + assert_ne!(temp_ty1, temp_ty2); + assert_ne!(primty1, primty2); + assert_eq!(name1, left2name); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::NumericType(_, t) if t == left1ty) + )); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::NumericType(_, t) if t == right1ty) + )); + assert!(vec_contains( + &constraints, + |x| matches!(x, Constraint::NumericType(_, t) if t == right2ty) + )); + for (i, s) in stmts.iter().enumerate() { + println!("{}: {:?}", i, s); + } + for (i, c) in constraints.iter().enumerate() { + println!("{}: {:?}", i, c); + } + } +} diff --git a/src/ir/type_infer/finalize.rs b/src/ir/type_infer/finalize.rs index cc8c5de..9c70944 100644 --- a/src/ir/type_infer/finalize.rs +++ b/src/ir/type_infer/finalize.rs @@ -1,88 +1,80 @@ -use super::ast as input; +use super::{ast as input, solve::TypeResolutions}; use crate::{eval::PrimitiveType, ir as output}; -use internment::ArcIntern; -use std::collections::HashMap; pub fn finalize_program( mut program: input::Program, - type_renames: HashMap, input::Type>, + resolutions: &TypeResolutions, ) -> output::Program { output::Program { statements: program .statements .drain(..) 
- .map(|x| finalize_statement(x, &type_renames)) + .map(|x| finalize_statement(x, resolutions)) .collect(), } } fn finalize_statement( statement: input::Statement, - type_renames: &HashMap, input::Type>, + resolutions: &TypeResolutions, ) -> output::Statement { match statement { input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding( loc, var, - finalize_type(ty, type_renames), - finalize_expression(expr, type_renames), + finalize_type(ty, resolutions), + finalize_expression(expr, resolutions), ), input::Statement::Print(loc, ty, var) => { - output::Statement::Print(loc, finalize_type(ty, type_renames), var) + output::Statement::Print(loc, finalize_type(ty, resolutions), var) } } } fn finalize_expression( expression: input::Expression, - type_renames: &HashMap, input::Type>, + resolutions: &TypeResolutions, ) -> output::Expression { match expression { input::Expression::Atomic(val_or_ref) => { - output::Expression::Atomic(finalize_val_or_ref(val_or_ref, type_renames)) + output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions)) } input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast( loc, - finalize_type(target, type_renames), - finalize_val_or_ref(val_or_ref, type_renames), + finalize_type(target, resolutions), + finalize_val_or_ref(val_or_ref, resolutions), ), input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive( loc, - finalize_type(ty, type_renames), + finalize_type(ty, resolutions), prim, args.drain(..) 
- .map(|x| finalize_val_or_ref(x, type_renames)) + .map(|x| finalize_val_or_ref(x, resolutions)) .collect(), ), } } -fn finalize_type( - ty: input::Type, - type_renames: &HashMap, input::Type>, -) -> output::Type { +fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type { match ty { input::Type::Primitive(x) => output::Type::Primitive(x), - input::Type::Variable(loc, name) => match type_renames.get(&name) { - Some(input::Type::Primitive(x)) => output::Type::Primitive(*x), - res => panic!( - "ACK! Internal error cleaning up temporary type name at {:?}: got {:?}", - loc, res - ), + input::Type::Variable(_, tvar) => match resolutions.get(&tvar) { + None => panic!("Did not resolve type for type variable {}", tvar), + Some(pt) => output::Type::Primitive(*pt), }, } } fn finalize_val_or_ref( valref: input::ValueOrRef, - type_renames: &HashMap, input::Type>, + resolutions: &TypeResolutions, ) -> output::ValueOrRef { match valref { input::ValueOrRef::Ref(loc, ty, var) => { - output::ValueOrRef::Ref(loc, finalize_type(ty, type_renames), var) + output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var) } input::ValueOrRef::Value(loc, ty, val) => { - let new_type = finalize_type(ty, type_renames); + let new_type = finalize_type(ty, resolutions); match val { input::Value::Unknown(base, value) => match new_type { @@ -101,11 +93,9 @@ fn finalize_val_or_ref( new_type, output::Value::U32(base, value as u32), ), - output::Type::Primitive(PrimitiveType::U64) => output::ValueOrRef::Value( - loc, - new_type, - output::Value::U64(base, value), - ), + output::Type::Primitive(PrimitiveType::U64) => { + output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value)) + } output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value( loc, new_type, diff --git a/src/ir/type_infer/solve.rs b/src/ir/type_infer/solve.rs index 66cbedc..e8c4121 100644 --- a/src/ir/type_infer/solve.rs +++ b/src/ir/type_infer/solve.rs @@ -3,8 +3,9 @@ use 
super::ast::Type; use crate::{eval::PrimitiveType, syntax::Location}; use codespan_reporting::diagnostic::Diagnostic; use internment::ArcIntern; -use std::collections::HashMap; +use std::{collections::HashMap, fmt}; +#[derive(Debug)] pub enum Constraint { /// The given type must be printable using the `print` built-in Printable(Location, Type), @@ -14,12 +15,34 @@ pub enum Constraint { ProperPrimitiveArgs(Location, ir::Primitive, Vec, Type), /// The given type can be casted to the target type safely CanCastTo(Location, Type, Type), - /// The given type must be some numeric type + /// The given type must be some numeric type, but this is not a constant + /// value, so don't try to default it if we can't figure it out NumericType(Location, Type), /// The two types should be equivalent Equivalent(Location, Type, Type), } +impl fmt::Display for Constraint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty), + Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty), + Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => { + write!(f, "PRIM {} {} -> {}", op, args[0], ret) + } + Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => { + write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret) + } + Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret), + Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2), + Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty), + Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2), + } + } +} + +pub type TypeResolutions = HashMap, PrimitiveType>; + pub enum TypeInferenceResult { Success { result: Result, @@ -71,7 +94,66 @@ pub enum TypeInferenceError { impl From for Diagnostic { fn from(value: TypeInferenceError) -> Self { - unimplemented!() + match value { + TypeInferenceError::ConstantTooLarge(loc, 
primty, value) => loc + .labelled_error("constant too large for type") + .with_message(format!( + "Type {} has a max value of {}, which is smaller than {}", + primty, + primty.max_value(), + value + )), + TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc + .labelled_error("type inference error") + .with_message(format!("Expected type {}, received type {}", ty1, ty2)), + TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc + .labelled_error("unsafe type cast") + .with_message(format!("Cannot safely cast {} to {}", ty1, ty2)), + TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc + .labelled_error("wrong number of arguments") + .with_message(format!( + "expected {} for {}, received {}", + if lower == upper && lower > 1 { + format!("{} arguments", lower) + } else if lower == upper { + format!("{} argument", lower) + } else { + format!("{}-{} arguments", lower, upper) + }, + prim, + observed + )), + TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => { + loc.labelled_error("internal error").with_message(format!( + "could not determine if it was safe to cast from {} to {:#?}", + a, b + )) + } + TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => { + loc.labelled_error("internal error").with_message(format!( + "could not determine if {} and {:#?} were equivalent", + a, b + )) + } + TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => { + loc.labelled_error("internal error").with_message(format!( + "Could not determine if {} could fit in {}", + val, ty + )) + } + TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc + .labelled_error("internal error") + .with_message(format!("Could not determine if {} was a numeric type", ty)), + TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc + .labelled_error("internal error") + .with_message(format!("Could not determine if type {} was printable", ty)), + 
TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => { + loc.labelled_error("internal error").with_message(format!( + "Could not tell if primitive {} received the proper argument types", + prim + )) + } + } } } @@ -81,22 +163,31 @@ pub enum TypeInferenceWarning { impl From for Diagnostic { fn from(value: TypeInferenceWarning) -> Self { - unimplemented!() + match value { + TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning() + .with_labels(vec![loc.primary_label().with_message("unknown type")]) + .with_message(format!("Defaulted unknown type to {}", ty)), + } } } pub fn solve_constraints( mut constraint_db: Vec, -) -> TypeInferenceResult, Type>> { - let mut type_renames = HashMap::new(); +) -> TypeInferenceResult { let mut errors = vec![]; let mut warnings = vec![]; + let mut resolutions = HashMap::new(); let mut changed_something = true; // We want to run this inference endlessly, until either we have solved all of our // constraints. Internal to the loop, we have a check that will make sure that we // do (eventually) stop. while changed_something && !constraint_db.is_empty() { + println!("-------CONSTRAINTS---------"); + for constraint in constraint_db.iter() { + println!("{}", constraint); + } + println!("---------------------------"); // Set this to false at the top of the loop. We'll set this to true if we make // progress in any way further down, but having this here prevents us from going // into an infinite look when we can't figure stuff out. @@ -115,43 +206,64 @@ pub fn solve_constraints( // Currently, all of our types are printable Constraint::Printable(_loc, _ty) => changed_something = true, - // If the first type is a variable, then we rename it to the second. 
- Constraint::Equivalent(_, Type::Variable(_, n), second) => { - type_renames.insert(n, second); - changed_something = true; - } - - // If the second type is a variable (I guess the first wasn't, then rename it to the first) - Constraint::Equivalent(_, first, Type::Variable(_, n)) => { - type_renames.insert(n, first); - changed_something = true; - } - - // Otherwise, we ar testing if two concrete types are equivalent, and just need to see - // if they're the same. - Constraint::Equivalent(loc, Type::Primitive(a), Type::Primitive(b)) => { - if a != b { - errors.push(TypeInferenceError::NotEquivalent(loc, a, b)); + // Case #1: We have two primitive types. If they're equal, we've discharged this + // constraint! We can just continue. If they're not equal, add an error and then + // see what else we come up with. + Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => { + if t1 != t2 { + errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2)); } changed_something = true; } + // Case #2: One of the two types is a primitive, and the other is a variable. + // In this case, we'll check to see if we've resolved the variable, and check for + // equivalence if we have. If we haven't, we'll set that variable to be that primitive + // type. + Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name)) + | Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => { + match resolutions.get(&name) { + None => { + resolutions.insert(name, t); + } + Some(t2) if &t == t2 => {} + Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)), + } + changed_something = true; + } + + // Case #3: They're both variables. In which case, we'll have to do much the same + // check, but now on their resolutions. 
+ Constraint::Equivalent( + ref loc, + Type::Variable(_, ref name1), + Type::Variable(_, ref name2), + ) => match (resolutions.get(name1), resolutions.get(name2)) { + (None, None) => { + constraint_db.push(constraint); + } + (Some(pt), None) => { + resolutions.insert(name2.clone(), *pt); + changed_something = true; + } + (None, Some(pt)) => { + resolutions.insert(name1.clone(), *pt); + changed_something = true; + } + (Some(pt1), Some(pt2)) if pt1 == pt2 => { + changed_something = true; + } + (Some(pt1), Some(pt2)) => { + errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2)); + changed_something = true; + } + }, + // Make sure that the provided number fits within the provided constant type. For the // moment, we're going to call an error here a failure, although this could be a // warning in the future. Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => { - let is_too_big = match ctype { - PrimitiveType::U8 => (u8::MAX as u64) < val, - PrimitiveType::U16 => (u16::MAX as u64) < val, - PrimitiveType::U32 => (u32::MAX as u64) < val, - PrimitiveType::U64 => false, - PrimitiveType::I8 => (i8::MAX as u64) < val, - PrimitiveType::I16 => (i16::MAX as u64) < val, - PrimitiveType::I32 => (i32::MAX as u64) < val, - PrimitiveType::I64 => (i64::MAX as u64) < val, - }; - - if is_too_big { + if ctype.max_value() < val { errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val)); } @@ -161,10 +273,18 @@ pub fn solve_constraints( // If we have a non-constant type, then let's see if we can advance this to a constant // type Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => { - match type_renames.get(&var) { - None => constraint_db.push(Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val)), + match resolutions.get(&var) { + None => constraint_db.push(Constraint::FitsInNumType( + loc, + Type::Variable(vloc, var), + val, + )), Some(nt) => { - constraint_db.push(Constraint::FitsInNumType(loc, nt.clone(), val)); + 
constraint_db.push(Constraint::FitsInNumType( + loc, + Type::Primitive(*nt), + val, + )); changed_something = true; } } @@ -173,10 +293,18 @@ pub fn solve_constraints( // If the left type in a "can cast to" check is a variable, let's see if we can advance // it into something more tangible Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => { - match type_renames.get(&var) { - None => constraint_db.push(Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type)), + match resolutions.get(&var) { + None => constraint_db.push(Constraint::CanCastTo( + loc, + Type::Variable(vloc, var), + to_type, + )), Some(nt) => { - constraint_db.push(Constraint::CanCastTo(loc, nt.clone(), to_type)); + constraint_db.push(Constraint::CanCastTo( + loc, + Type::Primitive(*nt), + to_type, + )); changed_something = true; } } @@ -184,10 +312,18 @@ pub fn solve_constraints( // If the right type in a "can cast to" check is a variable, same deal Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => { - match type_renames.get(&var) { - None => constraint_db.push(Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var))), + match resolutions.get(&var) { + None => constraint_db.push(Constraint::CanCastTo( + loc, + from_type, + Type::Variable(vloc, var), + )), Some(nt) => { - constraint_db.push(Constraint::CanCastTo(loc, from_type, nt.clone())); + constraint_db.push(Constraint::CanCastTo( + loc, + from_type, + Type::Primitive(*nt), + )); changed_something = true; } } @@ -210,10 +346,11 @@ pub fn solve_constraints( // As per usual, if we're trying to test if a type variable is numeric, first // we try to advance it to a primitive Constraint::NumericType(loc, Type::Variable(vloc, var)) => { - match type_renames.get(&var) { - None => constraint_db.push(Constraint::NumericType(loc, Type::Variable(vloc, var))), + match resolutions.get(&var) { + None => constraint_db + .push(Constraint::NumericType(loc, Type::Variable(vloc, var))), Some(nt) => { - 
constraint_db.push(Constraint::NumericType(loc, nt.clone())); + constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt))); changed_something = true; } } @@ -254,7 +391,11 @@ pub fn solve_constraints( constraint_db.push(Constraint::NumericType(loc.clone(), left.clone())); constraint_db.push(Constraint::NumericType(loc.clone(), right.clone())); constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); - constraint_db.push(Constraint::Equivalent(loc.clone(), left.clone(), right)); + constraint_db.push(Constraint::Equivalent( + loc.clone(), + left.clone(), + right, + )); constraint_db.push(Constraint::Equivalent(loc, left, ret)); changed_something = true; } @@ -287,7 +428,11 @@ pub fn solve_constraints( constraint_db.push(Constraint::NumericType(loc.clone(), left.clone())); constraint_db.push(Constraint::NumericType(loc.clone(), right.clone())); constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone())); - constraint_db.push(Constraint::Equivalent(loc.clone(), left.clone(), right)); + constraint_db.push(Constraint::Equivalent( + loc.clone(), + left.clone(), + right, + )); constraint_db.push(Constraint::Equivalent(loc.clone(), left, ret)); changed_something = true; } @@ -307,9 +452,13 @@ pub fn solve_constraints( for constraint in local_constraints.drain(..) { match constraint { - Constraint::NumericType(loc, Type::Variable(_, name)) => { + Constraint::NumericType(loc, t @ Type::Variable(_, _)) => { let resty = Type::Primitive(PrimitiveType::U64); - type_renames.insert(name, resty.clone()); + constraint_db.push(Constraint::Equivalent( + loc.clone(), + t, + Type::Primitive(PrimitiveType::U64), + )); warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty)); changed_something = true; } @@ -332,7 +481,7 @@ pub fn solve_constraints( // How'd we do? 
if errors.is_empty() { TypeInferenceResult::Success { - result: type_renames, + result: resolutions, warnings, } } else { diff --git a/src/syntax/arbitrary.rs b/src/syntax/arbitrary.rs index 1fd9b8d..1fb2397 100644 --- a/src/syntax/arbitrary.rs +++ b/src/syntax/arbitrary.rs @@ -10,7 +10,7 @@ use std::collections::HashMap; const VALID_VARIABLE_NAMES: &str = r"[a-z][a-zA-Z0-9_]*"; const OPERATORS: &[(&str, usize)] = &[("+", 2), ("-", 1), ("-", 2), ("*", 2), ("/", 2)]; -#[derive(Debug)] +#[derive(Clone, Debug)] struct Name(String); impl Arbitrary for Name { @@ -22,153 +22,133 @@ impl Arbitrary for Name { } } +#[derive(Debug)] +struct ProgramStatementInfo { + should_be_binding: bool, + name: Name, + binding_type: ConstantType, +} + +impl Arbitrary for ProgramStatementInfo { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + ( + Union::new(vec![Just(true), Just(true), Just(false)]), + Name::arbitrary(), + ConstantType::arbitrary(), + ) + .prop_map( + |(should_be_binding, name, binding_type)| ProgramStatementInfo { + should_be_binding, + name, + binding_type, + }, + ) + .boxed() + } +} + impl Arbitrary for Program { type Parameters = (); type Strategy = BoxedStrategy; fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - let optionals = Vec::<(Name, ConstantType, u8)>::arbitrary(); - - optionals - .prop_flat_map(|mut possible_names| { + proptest::collection::vec(ProgramStatementInfo::arbitrary(), 1..100) + .prop_flat_map(|mut items| { let mut statements = Vec::new(); - let mut defined_variables: HashMap = HashMap::new(); + let mut defined_variables = HashMap::new(); - for (possible_name, possible_type, dice_roll) in possible_names.drain(..) { - if !defined_variables.is_empty() && dice_roll < 100 { + for psi in items.drain(..) 
{ + if defined_variables.is_empty() || psi.should_be_binding { + let expr = Expression::arbitrary_with(ExpressionGeneratorSettings { + bound_variables: defined_variables.clone(), + output_type: Some(psi.binding_type), + }); + + defined_variables.insert(psi.name.0.clone(), psi.binding_type); statements.push( - Union::new(defined_variables.keys().map(|name| { - Just(Statement::Print(Location::manufactured(), name.to_string())) - })) + expr.prop_map(move |expr| { + Statement::Binding( + Location::manufactured(), + psi.name.0.clone(), + expr, + ) + }) .boxed(), ); } else { - let closures_name = possible_name.0.clone(); - let retval = Expression::arbitrary_with(( - Some(defined_variables.clone()), - Some(possible_type), - )) - .prop_map(move |exp| { - Statement::Binding(Location::manufactured(), closures_name.clone(), exp) - }) - .boxed(); - - defined_variables.insert(possible_name.0, possible_type); - statements.push(retval); + let printers = defined_variables + .keys() + .map(|n| Just(Statement::Print(Location::manufactured(), n.clone()))); + statements.push(Union::new(printers).boxed()); } } statements + .prop_map(|statements| Program { statements }) + .boxed() }) - .prop_map(|statements| Program { statements }) .boxed() } } -impl Arbitrary for Statement { - type Parameters = Option>; - type Strategy = BoxedStrategy; - - fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { - let duplicated_args = args.clone(); - let defined_variables = args.unwrap_or_default(); - - let binding_strategy = ( - VALID_VARIABLE_NAMES, - Expression::arbitrary_with((duplicated_args, None)), - ) - .prop_map(|(name, exp)| Statement::Binding(Location::manufactured(), name, exp)) - .boxed(); - - if defined_variables.is_empty() { - binding_strategy - } else { - let print_strategy = Union::new( - defined_variables - .keys() - .map(|x| Just(Statement::Print(Location::manufactured(), x.to_string()))), - ) - .boxed(); - - Union::new([binding_strategy, print_strategy]).boxed() - } - } 
+#[derive(Default)] +pub struct ExpressionGeneratorSettings { + bound_variables: HashMap, + output_type: Option, } impl Arbitrary for Expression { - type Parameters = (Option>, Option); + type Parameters = ExpressionGeneratorSettings; type Strategy = BoxedStrategy; - fn arbitrary_with((env, target_type): Self::Parameters) -> Self::Strategy { - let defined_variables = env.unwrap_or_default(); - let mut acceptable_variables = defined_variables - .iter() - .filter(|(_, ctype)| Some(**ctype) == target_type) - .map(|(x, _)| x) - .peekable(); - - let value_strategy = Value::arbitrary_with(target_type) - .prop_map(move |x| Expression::Value(Location::manufactured(), x)) + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + // Value(Location, Value). These are the easiest variations to create, because we can always + // create one. + let value_strategy = Value::arbitrary_with(params.output_type) + .prop_map(|x| Expression::Value(Location::manufactured(), x)) .boxed(); - let leaf_strategy = if acceptable_variables.peek().is_none() { + // Reference(Location, String). These are slightly trickier, because we can end up in a situation + // where either no variables are defined, or where none of the defined variables have a type we + // can work with. So what we're going to do is combine this one with the previous one as a "leaf + // strategy" -- our non-recursive items -- if we can, or just set that to be the value strategy + // if we can't actually create any references.
+ let mut bound_variables_of_type = params + .bound_variables + .iter() + .filter(|(_, v)| { + params + .output_type + .as_ref() + .map(|ot| ot == *v) + .unwrap_or(true) + }) + .map(|(n, _)| n) + .collect::>(); + let leaf_strategy = if bound_variables_of_type.is_empty() { value_strategy } else { - let reference_strategy = Union::new(acceptable_variables.map(|x| { - Just(Expression::Reference( - Location::manufactured(), - x.to_owned(), - )) - })) - .boxed(); - Union::new([value_strategy, reference_strategy]).boxed() + let mut strats = bound_variables_of_type + .drain(..) + .map(|x| Just(Expression::Reference(Location::manufactured(), x.clone())).boxed()) + .collect::>(); + strats.push(value_strategy); + Union::new(strats).boxed() }; - let cast_strategy = if let Some(bigger_type) = target_type { - let mut smaller_types = bigger_type.safe_casts_to(); - - if smaller_types.is_empty() { - leaf_strategy - } else { - let duplicated_env = defined_variables.clone(); - let cast_exp = |t, e| Expression::Cast(Location::manufactured(), t, Box::new(e)); - - let smaller_strats: Vec> = smaller_types - .drain(..) 
- .map(|t| { - Expression::arbitrary_with((Some(duplicated_env.clone()), Some(t))) - .prop_map(move |e| cast_exp(t.name(), e)) - .boxed() - }) - .collect(); - Union::new(smaller_strats).boxed() - } - } else { - leaf_strategy - }; - - cast_strategy - .prop_recursive(3, 64, 2, move |inner| { - (select(OPERATORS), proptest::collection::vec(inner, 2)).prop_map( - move |((operator, arg_count), mut exprs)| { - if arg_count == 1 && operator == "-" { - if target_type.map(|x| x.is_signed()).unwrap_or(false) { - Expression::Primitive( - Location::manufactured(), - operator.to_string(), - exprs, - ) - } else { - exprs.pop().unwrap() - } - } else { - exprs.truncate(arg_count); - Expression::Primitive( - Location::manufactured(), - operator.to_string(), - exprs, - ) + // now we generate our recursive types, given our leaf strategy + leaf_strategy + .prop_recursive(3, 10, 2, move |strat| { + (select(OPERATORS), strat.clone(), strat).prop_map( + |((oper, count), left, right)| { + let mut args = vec![left, right]; + while args.len() > count { + args.pop(); } + Expression::Primitive(Location::manufactured(), oper.to_string(), args) }, ) }) @@ -181,7 +161,7 @@ impl Arbitrary for Value { type Strategy = BoxedStrategy; fn arbitrary_with(target_type: Self::Parameters) -> Self::Strategy { - let base_strategy = Union::new([ + let printed_base_strategy = Union::new([ Just(None::), Just(Some(2)), Just(Some(8)), @@ -189,14 +169,13 @@ impl Arbitrary for Value { Just(Some(16)), ]); - let type_strategy = if target_type.is_some() { - Just(target_type).boxed() - } else { - proptest::option::of(ConstantType::arbitrary()).boxed() + let type_strategy = match target_type { + None => proptest::option::of(ConstantType::arbitrary()).boxed(), + Some(target) => proptest::option::of(Just(target)).boxed(), }; let value_strategy = u64::arbitrary(); - (base_strategy, type_strategy, value_strategy) + (printed_base_strategy, type_strategy, value_strategy) .prop_map(move |(base, ty, value)| { let 
converted_value = match ty { Some(ConstantType::I8) => value % (i8::MAX as u64), diff --git a/src/syntax/location.rs b/src/syntax/location.rs index 3c97d3d..d193d26 100644 --- a/src/syntax/location.rs +++ b/src/syntax/location.rs @@ -4,7 +4,7 @@ use codespan_reporting::diagnostic::{Diagnostic, Label}; /// /// Internally, locations are very tied to the `codespan_reporting` library, /// and the primary use of them is to serve as anchors within that library. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct Location { file_idx: usize, offset: usize, diff --git a/src/syntax/tokens.rs b/src/syntax/tokens.rs index a545435..e521987 100644 --- a/src/syntax/tokens.rs +++ b/src/syntax/tokens.rs @@ -179,17 +179,50 @@ impl ConstantType { /// Return the set of types that can be safely casted into this type. pub fn safe_casts_to(self) -> Vec { match self { - ConstantType::I8 => vec![], - ConstantType::I16 => vec![ConstantType::I8], - ConstantType::I32 => vec![ConstantType::I16, ConstantType::I8], - ConstantType::I64 => vec![ConstantType::I32, ConstantType::I16, ConstantType::I8], - ConstantType::U8 => vec![], - ConstantType::U16 => vec![ConstantType::U8], - ConstantType::U32 => vec![ConstantType::U16, ConstantType::U8], - ConstantType::U64 => vec![ConstantType::U32, ConstantType::U16, ConstantType::U8], + ConstantType::I8 => vec![ConstantType::I8], + ConstantType::I16 => vec![ConstantType::I16, ConstantType::I8, ConstantType::U8], + ConstantType::I32 => vec![ + ConstantType::I32, + ConstantType::I16, + ConstantType::I8, + ConstantType::U16, + ConstantType::U8, + ], + ConstantType::I64 => vec![ + ConstantType::I64, + ConstantType::I32, + ConstantType::I16, + ConstantType::I8, + ConstantType::U32, + ConstantType::U16, + ConstantType::U8, + ], + ConstantType::U8 => vec![ConstantType::U8], + ConstantType::U16 => vec![ConstantType::U16, ConstantType::U8], + ConstantType::U32 => vec![ConstantType::U32, ConstantType::U16, 
ConstantType::U8], + ConstantType::U64 => vec![ + ConstantType::U64, + ConstantType::U32, + ConstantType::U16, + ConstantType::U8, + ], } } + /// Return the set of all currently-available constant types + pub fn all_types() -> Vec { + vec![ + ConstantType::U8, + ConstantType::U16, + ConstantType::U32, + ConstantType::U64, + ConstantType::I8, + ConstantType::I16, + ConstantType::I32, + ConstantType::I64, + ] + } + /// Return the name of the given type, as a string pub fn name(&self) -> String { match self {