diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 3db2c59..1b1d3af 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -112,7 +112,7 @@ fn convert_expression( let newtype = ir::gentype(); let newval = ir::Value::Unknown(base, value); - constraint_db.push(Constraint::NumericType(loc.clone(), newtype.clone())); + constraint_db.push(Constraint::ConstantNumericType(loc.clone(), newtype.clone())); (newval, newtype) } Some(ConstantType::U8) => ( @@ -288,7 +288,7 @@ mod tests { ))); assert!(vec_contains( &constraints, - |x| matches!(x, Constraint::NumericType(_, t) if t == &ty) + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty) )); } @@ -308,7 +308,7 @@ mod tests { ))); assert!(vec_contains( &constraints, - |x| matches!(x, Constraint::NumericType(_, t) if t != &ty) + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty) )); assert!(vec_contains( &constraints, @@ -339,15 +339,15 @@ mod tests { assert_eq!(name1, left2name); assert!(vec_contains( &constraints, - |x| matches!(x, Constraint::NumericType(_, t) if t == left1ty) + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty) )); assert!(vec_contains( &constraints, - |x| matches!(x, Constraint::NumericType(_, t) if t == right1ty) + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty) )); assert!(vec_contains( &constraints, - |x| matches!(x, Constraint::NumericType(_, t) if t == right2ty) + |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty) )); for (i, s) in stmts.iter().enumerate() { println!("{}: {:?}", i, s); diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index e8c4121..afecfb3 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -18,6 +18,10 @@ pub enum Constraint { /// The given type must be some numeric type, but this is not a constant /// value, so don't try to default it if we can't figure it out NumericType(Location, Type), + /// The given type is attached to a constant and must be some numeric type. + /// If we can't figure it out, we should warn the user and then just use a + /// default. + ConstantNumericType(Location, Type), /// The two types should be equivalent Equivalent(Location, Type, Type), } @@ -36,6 +40,7 @@ impl fmt::Display for Constraint { Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret), Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2), Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty), + Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty), Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2), } } @@ -144,6 +149,8 @@ impl From for Diagnostic { TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc .labelled_error("internal error") .with_message(format!("Could not determine if {} was a numeric type", ty)), + TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) => + panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty), TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc .labelled_error("internal error") .with_message(format!("Could not determine if type {} was printable", ty)), @@ -362,6 +369,25 @@ pub fn solve_constraints( changed_something = true; } + // As per usual, if we're trying to test if a type variable is numeric, first + // we try to advance it to a primitive + Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => { + match resolutions.get(&var) { + None => constraint_db + .push(Constraint::ConstantNumericType(loc, Type::Variable(vloc, var))), + Some(nt) => { + constraint_db.push(Constraint::ConstantNumericType(loc, Type::Primitive(*nt))); + changed_something = true; + } + } + } + + // Of course, if we get to a primitive type, then it's true, because all of our + // primitive types are numbers + Constraint::ConstantNumericType(_, Type::Primitive(_)) => { + changed_something = true; + } + // OK, this one could be a little tricky if we tried to do it all at once, but // instead what we're going to do here is just use this constraint to generate // a bunch more constraints, and then go have the engine solve those. The only @@ -452,7 +478,7 @@ pub fn solve_constraints( for constraint in local_constraints.drain(..) { match constraint { - Constraint::NumericType(loc, t @ Type::Variable(_, _)) => { + Constraint::ConstantNumericType(loc, t @ Type::Variable(_, _)) => { let resty = Type::Primitive(PrimitiveType::U64); constraint_db.push(Constraint::Equivalent( loc.clone(),