🤔 Add a type inference engine, along with typed literals. (#4)
The typed literal formatting mirrors that of Rust. If no type can be inferred for an untagged literal, the type inference engine will warn the user and then assume that they meant an unsigned 64-bit number. (This is slightly inconvenient, because our Arbitrary instance may generate a unary negation, in which case we should assume a signed 64-bit number instead; we may want to revisit this later.)

The type inference engine is a standard two-phase one: we first generate a series of type constraints, and then we solve those constraints. In this particular implementation, we actually use a third phase to generate a final AST.

Finally, to increase the amount of testing performed, I've removed the overflow checking in the evaluator. The only thing we now check for is division by zero. This does make things a trace slower in testing, but hopefully we get more coverage this way.
This commit was merged in pull request #4.
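The commit message describes generating constraints, solving them, and falling back to an unsigned 64-bit type (with a warning) for untagged literals that stay unresolved. The following is a minimal, std-only sketch of the solving and defaulting steps. The names `Ty`, `Constraint`, `resolve`, and `solve` are hypothetical and far simpler than the types added in this commit.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified types; not the ones defined in this commit.
#[derive(Clone, Debug, PartialEq)]
enum Ty {
    Var(String),
    U64,
    I32,
}

#[derive(Debug)]
enum Constraint {
    /// Two types must be the same.
    Equal(Ty, Ty),
    /// The type belongs to an untagged numeric literal.
    ConstantNumeric(Ty),
}

/// Follow the current substitution for a type until it stops changing.
fn resolve(ty: &Ty, subst: &HashMap<String, Ty>) -> Ty {
    match ty {
        Ty::Var(name) => match subst.get(name) {
            Some(next) => resolve(next, subst),
            None => ty.clone(),
        },
        other => other.clone(),
    }
}

/// Solve the constraints; returns the substitution plus any warnings.
fn solve(constraints: &[Constraint]) -> (HashMap<String, Ty>, Vec<String>) {
    let mut subst = HashMap::new();
    let mut warnings = Vec::new();

    // First pass: use equalities to pin down type variables.
    for c in constraints {
        if let Constraint::Equal(a, b) = c {
            match (resolve(a, &subst), resolve(b, &subst)) {
                (Ty::Var(v), other) | (other, Ty::Var(v)) => {
                    // Skip trivial `t = t` facts so we never create a cycle.
                    if other != Ty::Var(v.clone()) {
                        subst.insert(v, other);
                    }
                }
                _ => {} // a real solver would report mismatches here
            }
        }
    }

    // Second pass: default literals that are still unresolved to u64.
    for c in constraints {
        if let Constraint::ConstantNumeric(t) = c {
            if let Ty::Var(v) = resolve(t, &subst) {
                warnings.push(format!("could not infer a type for {v}; assuming u64"));
                subst.insert(v, Ty::U64);
            }
        }
    }

    (subst, warnings)
}
```

In this sketch a literal whose type variable is tied to `i32` by an equality keeps that type, while a literal with no other constraints is defaulted and produces exactly one warning.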
336
src/type_infer/ast.rs
Normal file
@@ -0,0 +1,336 @@
pub use crate::ir::ast::Primitive;
/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going
/// to want to use while we're doing type inference, but don't want to keep around
/// afterwards. These are:
///
/// * A notion of a type variable
/// * An unknown numeric constant form
///
use crate::{
    eval::PrimitiveType,
    syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use std::fmt;
use std::sync::atomic::AtomicUsize;

/// We're going to represent variables as interned strings.
///
/// These should be fast enough for comparison that it's OK, since it's going to end up
/// being pretty much the pointer to the string.
type Variable = ArcIntern<String>;

/// The representation of a program within our IR. For now, this is exactly one file.
///
/// In addition, for the moment there's not really much of interest to hold here besides
/// the list of statements read from the file. Order is important. In the future, you
/// could imagine caching analysis information in this structure.
///
/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used
/// to print the structure whenever possible, especially if you value your or your
/// user's time. The latter is useful for testing that conversions of `Program` retain
/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be
/// syntactically valid, although they may contain runtime issues like over- or underflow.
#[derive(Debug)]
pub struct Program {
    // For now, a program is just a vector of statements. In the future, we'll probably
    // extend this to include a bunch of other information, but for now: just a list.
    pub(crate) statements: Vec<Statement>,
}

impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        let mut result = allocator.nil();

        for stmt in self.statements.iter() {
            // there's probably a better way to do this, rather than constantly
            // adding to the end, but this works.
            result = result
                .append(stmt.pretty(allocator))
                .append(allocator.text(";"))
                .append(allocator.hardline());
        }

        result
    }
}

/// The representation of a statement in the language.
///
/// For now, this is either a binding site (`x = 4`) or a print statement
/// (`print x`). Someday, though, more!
///
/// As with `Program`, this type implements [`Pretty`], which should
/// be used to display the structure whenever possible. It does not
/// implement [`Arbitrary`], though, mostly because it's slightly
/// complicated to do so.
///
#[derive(Debug)]
pub enum Statement {
    Binding(Location, Variable, Type, Expression),
    Print(Location, Type, Variable),
}

impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        match self {
            Statement::Binding(_, var, _, expr) => allocator
                .text(var.as_ref().to_string())
                .append(allocator.space())
                .append(allocator.text("="))
                .append(allocator.space())
                .append(expr.pretty(allocator)),
            Statement::Print(_, _, var) => allocator
                .text("print")
                .append(allocator.space())
                .append(allocator.text(var.as_ref().to_string())),
        }
    }
}

/// The representation of an expression.
///
/// Note that expressions, like everything else in this syntax tree,
/// support [`Pretty`], and it's strongly encouraged that you use
/// that trait/module when printing these structures.
///
/// Also, Expressions at this point in the compiler are explicitly
/// defined so that they are *not* recursive. By this point, if an
/// expression requires some other data (like, for example, invoking
/// a primitive), any subexpressions have been bound to variables so
/// that the referenced data will always either be a constant or a
/// variable reference.
#[derive(Debug, PartialEq)]
pub enum Expression {
    Atomic(ValueOrRef),
    Cast(Location, Type, ValueOrRef),
    Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}

impl Expression {
    /// Return a reference to the type of the expression, as inferred or recently
    /// computed.
    pub fn type_of(&self) -> &Type {
        match self {
            Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
            Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
            Expression::Cast(_, t, _) => t,
            Expression::Primitive(_, t, _, _) => t,
        }
    }

    /// Return a reference to the location associated with the expression.
    pub fn location(&self) -> &Location {
        match self {
            Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
            Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
            Expression::Cast(l, _, _) => l,
            Expression::Primitive(l, _, _, _) => l,
        }
    }
}

impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        match self {
            Expression::Atomic(x) => x.pretty(allocator),
            Expression::Cast(_, t, e) => allocator
                .text("<")
                .append(t.pretty(allocator))
                .append(allocator.text(">"))
                .append(e.pretty(allocator)),
            Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
                op.pretty(allocator).append(exprs[0].pretty(allocator))
            }
            Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
                let left = exprs[0].pretty(allocator);
                let right = exprs[1].pretty(allocator);

                left.append(allocator.space())
                    .append(op.pretty(allocator))
                    .append(allocator.space())
                    .append(right)
                    .parens()
            }
            Expression::Primitive(_, _, op, exprs) => {
                allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
            }
        }
    }
}

/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrRef {
    Value(Location, Type, Value),
    Ref(Location, Type, ArcIntern<String>),
}

impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        match self {
            ValueOrRef::Value(_, _, v) => v.pretty(allocator),
            ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
        }
    }
}

impl From<ValueOrRef> for Expression {
    fn from(value: ValueOrRef) -> Self {
        Expression::Atomic(value)
    }
}

/// A constant in the IR.
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    Unknown(Option<u8>, u64),
    I8(Option<u8>, i8),
    I16(Option<u8>, i16),
    I32(Option<u8>, i32),
    I64(Option<u8>, i64),
    U8(Option<u8>, u8),
    U16(Option<u8>, u16),
    U32(Option<u8>, u32),
    U64(Option<u8>, u64),
}

impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        let pretty_internal = |opt_base: &Option<u8>, x, t| {
            syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
        };

        let pretty_internal_signed = |opt_base, x: i64, t| {
            let base = pretty_internal(opt_base, x.unsigned_abs(), t);

            // Only negative values get a leading minus sign.
            if x < 0 {
                allocator.text("-").append(base)
            } else {
                base
            }
        };

        match self {
            // Unknown constants are stored as unsigned, so print them unsigned.
            Value::Unknown(opt_base, value) => {
                pretty_internal(opt_base, *value, ConstantType::U64)
            }
            Value::I8(opt_base, value) => {
                pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
            }
            Value::I16(opt_base, value) => {
                pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
            }
            Value::I32(opt_base, value) => {
                pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
            }
            Value::I64(opt_base, value) => {
                pretty_internal_signed(opt_base, *value, ConstantType::I64)
            }
            Value::U8(opt_base, value) => {
                pretty_internal(opt_base, *value as u64, ConstantType::U8)
            }
            Value::U16(opt_base, value) => {
                pretty_internal(opt_base, *value as u64, ConstantType::U16)
            }
            Value::U32(opt_base, value) => {
                pretty_internal(opt_base, *value as u64, ConstantType::U32)
            }
            Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
        }
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
    Variable(Location, ArcIntern<String>),
    Primitive(PrimitiveType),
}

impl Type {
    pub fn is_concrete(&self) -> bool {
        !matches!(self, Type::Variable(_, _))
    }
}

impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        match self {
            Type::Variable(_, x) => allocator.text(x.to_string()),
            Type::Primitive(pt) => allocator.text(format!("{}", pt)),
        }
    }
}

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Type::Variable(_, x) => write!(f, "{}", x),
            Type::Primitive(pt) => pt.fmt(f),
        }
    }
}

/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gensym(name: &str) -> ArcIntern<String> {
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let new_name = format!(
        "<{}:{}>",
        name,
        COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
    );
    ArcIntern::new(new_name)
}

/// Generate a fresh new type; this will be a unique new type variable.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gentype() -> Type {
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let name = ArcIntern::new(format!(
        "t<{}>",
        COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
    ));

    Type::Variable(Location::manufactured(), name)
}
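The `gensym` and `gentype` helpers above rely on a process-wide atomic counter for uniqueness. As a rough, std-only illustration of the same idea, the sketch below returns a plain `String` rather than an `ArcIntern<String>`; that substitution is an assumption made only to avoid the external `internment` crate, and the function is not the one from this commit.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Std-only variant of `gensym` (hypothetical): returns a `String`
/// instead of an interned string.
fn gensym(name: &str) -> String {
    // One counter shared by every call in the process.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    // `fetch_add` returns the previous value, so every call observes a
    // distinct number even when called from several threads.
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);

    // `<`, `:` and `>` are assumed not to be valid in user-written names,
    // so the result cannot collide with a source-level variable.
    format!("<{}:{}>", name, n)
}
```

`Ordering::Relaxed` should be enough here because only the uniqueness of the counter value matters; the `SeqCst` ordering used in the commit is stricter and equally correct.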
378
src/type_infer/convert.rs
Normal file
@@ -0,0 +1,378 @@
use super::ast as ir;
use super::ast::Type;
use crate::eval::PrimitiveType;
use crate::syntax::{self, ConstantType};
use crate::type_infer::solve::Constraint;
use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;

/// This function takes a syntactic program and converts it into the IR version of the
/// program, with appropriate type variables introduced and their constraints added to
/// the given database.
///
/// If the input program has been validated (which it should be), then this should run
/// into no error conditions. However, if you failed to validate the input, then this
/// function can panic.
pub fn convert_program(
    mut program: syntax::Program,
    constraint_db: &mut Vec<Constraint>,
) -> ir::Program {
    let mut statements = Vec::new();
    let mut renames = HashMap::new();
    let mut bindings = HashMap::new();

    for stmt in program.statements.drain(..) {
        statements.append(&mut convert_statement(
            stmt,
            constraint_db,
            &mut renames,
            &mut bindings,
        ));
    }

    ir::Program { statements }
}

/// This function takes a syntactic statement and converts it into a series of
/// IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have not run that and had it come back clean.
fn convert_statement(
    statement: syntax::Statement,
    constraint_db: &mut Vec<Constraint>,
    renames: &mut HashMap<ArcIntern<String>, ArcIntern<String>>,
    bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> Vec<ir::Statement> {
    match statement {
        syntax::Statement::Print(loc, name) => {
            let iname = ArcIntern::new(name.to_string());
            let final_name = renames
                .get(&iname)
                .map(Clone::clone)
                .unwrap_or_else(|| iname.clone());
            let varty = bindings
                .get(&final_name)
                .expect("print variable defined before use")
                .clone();

            constraint_db.push(Constraint::Printable(loc.clone(), varty.clone()));

            vec![ir::Statement::Print(loc, varty, iname)]
        }

        syntax::Statement::Binding(loc, name, expr) => {
            let (mut prereqs, expr, ty) =
                convert_expression(expr, constraint_db, renames, bindings);
            let iname = ArcIntern::new(name.to_string());
            let final_name = if bindings.contains_key(&iname) {
                let new_name = ir::gensym(iname.as_str());
                renames.insert(iname, new_name.clone());
                new_name
            } else {
                iname
            };

            bindings.insert(final_name.clone(), ty.clone());
            prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr));
            prereqs
        }
    }
}

/// This function takes a syntactic expression and converts it into a series
/// of IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have not run that and had it come back clean.
fn convert_expression(
    expression: syntax::Expression,
    constraint_db: &mut Vec<Constraint>,
    renames: &HashMap<ArcIntern<String>, ArcIntern<String>>,
    bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> (Vec<ir::Statement>, ir::Expression, Type) {
    match expression {
        syntax::Expression::Value(loc, val) => match val {
            syntax::Value::Number(base, mctype, value) => {
                let (newval, newtype) = match mctype {
                    None => {
                        let newtype = ir::gentype();
                        let newval = ir::Value::Unknown(base, value);

                        constraint_db.push(Constraint::ConstantNumericType(
                            loc.clone(),
                            newtype.clone(),
                        ));
                        (newval, newtype)
                    }
                    Some(ConstantType::U8) => (
                        ir::Value::U8(base, value as u8),
                        ir::Type::Primitive(PrimitiveType::U8),
                    ),
                    Some(ConstantType::U16) => (
                        ir::Value::U16(base, value as u16),
                        ir::Type::Primitive(PrimitiveType::U16),
                    ),
                    Some(ConstantType::U32) => (
                        ir::Value::U32(base, value as u32),
                        ir::Type::Primitive(PrimitiveType::U32),
                    ),
                    Some(ConstantType::U64) => (
                        ir::Value::U64(base, value),
                        ir::Type::Primitive(PrimitiveType::U64),
                    ),
                    Some(ConstantType::I8) => (
                        ir::Value::I8(base, value as i8),
                        ir::Type::Primitive(PrimitiveType::I8),
                    ),
                    Some(ConstantType::I16) => (
                        ir::Value::I16(base, value as i16),
                        ir::Type::Primitive(PrimitiveType::I16),
                    ),
                    Some(ConstantType::I32) => (
                        ir::Value::I32(base, value as i32),
                        ir::Type::Primitive(PrimitiveType::I32),
                    ),
                    Some(ConstantType::I64) => (
                        ir::Value::I64(base, value as i64),
                        ir::Type::Primitive(PrimitiveType::I64),
                    ),
                };

                constraint_db.push(Constraint::FitsInNumType(
                    loc.clone(),
                    newtype.clone(),
                    value,
                ));
                (
                    vec![],
                    ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)),
                    newtype,
                )
            }
        },

        syntax::Expression::Reference(loc, name) => {
            let iname = ArcIntern::new(name);
            let final_name = renames.get(&iname).cloned().unwrap_or(iname);
            let rtype = bindings
                .get(&final_name)
                .cloned()
                .expect("variable bound before use");
            let refexp =
                ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name));

            (vec![], refexp, rtype)
        }

        syntax::Expression::Cast(loc, target, expr) => {
            let (mut stmts, nexpr, etype) =
                convert_expression(*expr, constraint_db, renames, bindings);
            let val_or_ref = simplify_expr(nexpr, &mut stmts);
            let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast");
            let target_type = Type::Primitive(target_prim_type);
            let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref);

            constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone()));

            (stmts, res, target_type)
        }

        syntax::Expression::Primitive(loc, fun, mut args) => {
            let primop = ir::Primitive::from_str(&fun).expect("valid primitive");
            let mut stmts = vec![];
            let mut nargs = vec![];
            let mut atypes = vec![];
            let ret_type = ir::gentype();

            for arg in args.drain(..) {
                let (mut astmts, aexp, atype) =
                    convert_expression(arg, constraint_db, renames, bindings);

                stmts.append(&mut astmts);
                nargs.push(simplify_expr(aexp, &mut stmts));
                atypes.push(atype);
            }

            constraint_db.push(Constraint::ProperPrimitiveArgs(
                loc.clone(),
                primop,
                atypes.clone(),
                ret_type.clone(),
            ));

            (
                stmts,
                ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs),
                ret_type,
            )
        }
    }
}

fn simplify_expr(expr: ir::Expression, stmts: &mut Vec<ir::Statement>) -> ir::ValueOrRef {
    match expr {
        ir::Expression::Atomic(v_or_ref) => v_or_ref,
        expr => {
            let etype = expr.type_of().clone();
            let loc = expr.location().clone();
            let nname = ir::gensym("g");
            let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr);

            stmts.push(nbinding);
            ir::ValueOrRef::Ref(loc, etype, nname)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::syntax::Location;

    fn one() -> syntax::Expression {
        syntax::Expression::Value(
            Location::manufactured(),
            syntax::Value::Number(None, None, 1),
        )
    }

    fn vec_contains<T, F: Fn(&T) -> bool>(x: &[T], f: F) -> bool {
        for x in x.iter() {
            if f(x) {
                return true;
            }
        }
        false
    }

    fn infer_expression(
        x: syntax::Expression,
    ) -> (ir::Expression, Vec<ir::Statement>, Vec<Constraint>, Type) {
        let mut constraints = Vec::new();
        let renames = HashMap::new();
        let mut bindings = HashMap::new();
        let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings);
        (expr, stmts, constraints, ty)
    }

    fn infer_statement(x: syntax::Statement) -> (Vec<ir::Statement>, Vec<Constraint>) {
        let mut constraints = Vec::new();
        let mut renames = HashMap::new();
        let mut bindings = HashMap::new();
        let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings);
        (res, constraints)
    }

    #[test]
    fn constant_one() {
        let (expr, stmts, constraints, ty) = infer_expression(one());
        assert!(stmts.is_empty());
        assert!(matches!(
            expr,
            ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1)))
        ));
        assert!(vec_contains(&constraints, |x| matches!(
            x,
            Constraint::FitsInNumType(_, _, 1)
        )));
        assert!(vec_contains(
            &constraints,
            |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty)
        ));
    }

    #[test]
    fn one_plus_one() {
        let opo = syntax::Expression::Primitive(
            Location::manufactured(),
            "+".to_string(),
            vec![one(), one()],
        );
        let (expr, stmts, constraints, ty) = infer_expression(opo);
        assert!(stmts.is_empty());
        assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty));
        assert!(vec_contains(&constraints, |x| matches!(
            x,
            Constraint::FitsInNumType(_, _, 1)
        )));
        assert!(vec_contains(
            &constraints,
            |x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty)
        ));
        assert!(vec_contains(
            &constraints,
            |x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty)
        ));
    }

    #[test]
    fn one_plus_one_plus_one() {
        let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse");
        let (stmts, constraints) = infer_statement(stmt);
        assert_eq!(stmts.len(), 2);
        let ir::Statement::Binding(
            _args,
            name1,
            temp_ty1,
            ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1),
        ) = stmts.get(0).expect("item one")
        else {
            panic!("Failed to match first statement");
        };
        let ir::Statement::Binding(
            _args,
            name2,
            temp_ty2,
            ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2),
        ) = stmts.get(1).expect("item two")
        else {
            panic!("Failed to match second statement");
        };
        let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] =
            &primargs1[..]
        else {
            panic!("Failed to match first arguments");
        };
        let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] =
            &primargs2[..]
        else {
            panic!("Failed to match second arguments");
        };
        assert_ne!(name1, name2);
        assert_ne!(temp_ty1, temp_ty2);
        assert_ne!(primty1, primty2);
        assert_eq!(name1, left2name);
        assert!(vec_contains(
            &constraints,
            |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty)
        ));
        assert!(vec_contains(
            &constraints,
            |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty)
        ));
        assert!(vec_contains(
            &constraints,
            |x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty)
        ));
        for (i, s) in stmts.iter().enumerate() {
            println!("{}: {:?}", i, s);
        }
        for (i, c) in constraints.iter().enumerate() {
            println!("{}: {:?}", i, c);
        }
    }
}
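The conversion code above flattens nested expressions so that every operand is either a constant or a variable reference, binding intermediate results to generated names. The sketch below shows that idea in isolation. `Expr`, `Atom`, `Binding`, and `flatten` are hypothetical names, and unlike the commit's `simplify_expr` (which leaves the outermost expression for the enclosing statement to bind) this version gives every addition its own temporary.

```rust
/// Hypothetical nested source expression.
enum Expr {
    Num(u64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
}

/// Flattened operand: only a constant or a variable reference.
#[derive(Clone, Debug, PartialEq)]
enum Atom {
    Num(u64),
    Var(String),
}

/// A flattened binding of the form `name = left + right`.
#[derive(Debug, PartialEq)]
struct Binding {
    name: String,
    left: Atom,
    right: Atom,
}

/// Flatten `expr`, pushing one binding per addition onto `out`, and
/// return the atom that holds the value of the whole expression.
fn flatten(expr: Expr, out: &mut Vec<Binding>, counter: &mut usize) -> Atom {
    match expr {
        Expr::Num(n) => Atom::Num(n),
        Expr::Var(v) => Atom::Var(v),
        Expr::Add(l, r) => {
            // Operands are flattened first, so their bindings come earlier.
            let left = flatten(*l, out, counter);
            let right = flatten(*r, out, counter);

            // Fresh temporary name; `<` and `>` keep it distinct from user names.
            let name = format!("<g:{}>", *counter);
            *counter += 1;

            out.push(Binding { name: name.clone(), left, right });
            Atom::Var(name)
        }
    }
}
```

Flattening `(1 + 1) + 1` with this sketch yields two bindings, and the second one refers to the first by its generated name, which mirrors what the `one_plus_one_plus_one` test checks.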
187
src/type_infer/finalize.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use super::{ast as input, solve::TypeResolutions};
|
||||
use crate::{eval::PrimitiveType, ir as output};
|
||||
|
||||
pub fn finalize_program(
|
||||
mut program: input::Program,
|
||||
resolutions: &TypeResolutions,
|
||||
) -> output::Program {
|
||||
output::Program {
|
||||
statements: program
|
||||
.statements
|
||||
.drain(..)
|
||||
.map(|x| finalize_statement(x, resolutions))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_statement(
|
||||
statement: input::Statement,
|
||||
resolutions: &TypeResolutions,
|
||||
) -> output::Statement {
|
||||
match statement {
|
||||
input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding(
|
||||
loc,
|
||||
var,
|
||||
finalize_type(ty, resolutions),
|
||||
finalize_expression(expr, resolutions),
|
||||
),
|
||||
input::Statement::Print(loc, ty, var) => {
|
||||
output::Statement::Print(loc, finalize_type(ty, resolutions), var)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_expression(
|
||||
expression: input::Expression,
|
||||
resolutions: &TypeResolutions,
|
||||
) -> output::Expression {
|
||||
match expression {
|
||||
input::Expression::Atomic(val_or_ref) => {
|
||||
output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions))
|
||||
}
|
||||
input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast(
|
||||
loc,
|
||||
finalize_type(target, resolutions),
|
||||
finalize_val_or_ref(val_or_ref, resolutions),
|
||||
),
|
||||
input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive(
|
||||
loc,
|
||||
finalize_type(ty, resolutions),
|
||||
prim,
|
||||
args.drain(..)
|
||||
.map(|x| finalize_val_or_ref(x, resolutions))
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type {
|
||||
match ty {
|
||||
input::Type::Primitive(x) => output::Type::Primitive(x),
|
||||
input::Type::Variable(_, tvar) => match resolutions.get(&tvar) {
|
||||
None => panic!("Did not resolve type for type variable {}", tvar),
|
||||
Some(pt) => output::Type::Primitive(*pt),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_val_or_ref(
|
||||
valref: input::ValueOrRef,
|
||||
resolutions: &TypeResolutions,
|
||||
) -> output::ValueOrRef {
|
||||
match valref {
|
||||
input::ValueOrRef::Ref(loc, ty, var) => {
|
||||
output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var)
|
||||
}
|
||||
input::ValueOrRef::Value(loc, ty, val) => {
|
||||
let new_type = finalize_type(ty, resolutions);
|
||||
|
||||
match val {
|
||||
input::Value::Unknown(base, value) => match new_type {
|
||||
output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value(
|
||||
loc,
|
||||
new_type,
|
||||
output::Value::U8(base, value as u8),
|
||||
),
|
||||
output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value(
|
||||
loc,
|
||||
new_type,
|
||||
output::Value::U16(base, value as u16),
|
||||
),
|
||||
output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value(
|
||||
loc,
|
||||
new_type,
|
||||
output::Value::U32(base, value as u32),
|
||||
),
|
||||
output::Type::Primitive(PrimitiveType::U64) => {
|
||||
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
|
||||
}
|
||||
output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value(
|
||||
loc,
|
||||
new_type,
|
||||
output::Value::I8(base, value as i8),
|
||||
),
|
||||
output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value(
|
||||
loc,
|
||||
new_type,
|
||||
output::Value::I16(base, value as i16),
|
||||
),
|
||||
output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value(
|
||||
loc,
|
||||
new_type,
|
||||
output::Value::I32(base, value as i32),
|
||||
),
|
||||
output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value(
|
||||
loc,
|
||||
new_type,
|
||||
output::Value::I64(base, value as i64),
|
||||
),
|
||||
                },

                input::Value::U8(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::U8)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value))
                }

                input::Value::U16(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::U16)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value))
                }

                input::Value::U32(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::U32)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value))
                }

                input::Value::U64(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::U64)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
                }

                input::Value::I8(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::I8)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value))
                }

                input::Value::I16(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::I16)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value))
                }

                input::Value::I32(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::I32)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value))
                }

                input::Value::I64(base, value) => {
                    assert!(matches!(
                        new_type,
                        output::Type::Primitive(PrimitiveType::I64)
                    ));
                    output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value))
                }
            }
        }
    }
}
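The signed arms above narrow an untyped literal with `value as i8`, `value as i16`, and so on. In Rust an `as` cast between integer widths never fails; it keeps the low bits, so the result is only meaningful if the literal was range-checked beforehand (presumably the job of the `FitsInNumType` constraint). A minimal sketch of the difference, using two hypothetical helpers (`narrow_to_i8` and `wrap_to_i8` are illustrative names, not part of this commit):

```rust
use std::convert::TryFrom;

/// Hypothetical helper, not part of the commit: narrow an untyped literal
/// to `i8` only when it fits, instead of letting `as` wrap silently.
fn narrow_to_i8(value: u64) -> Option<i8> {
    i8::try_from(value).ok()
}

/// What a bare `as` cast does with the same input: keep the low 8 bits
/// and reinterpret them as a signed value.
fn wrap_to_i8(value: u64) -> i8 {
    value as i8
}
```

With these, `narrow_to_i8(200)` is `None`, while `wrap_to_i8(200)` quietly yields a negative number. That is why the cast is only safe after the solver has accepted the literal for that type.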
542
src/type_infer/solve.rs
Normal file
@@ -0,0 +1,542 @@
use super::ast as ir;
use super::ast::Type;
use crate::{eval::PrimitiveType, syntax::Location};
use codespan_reporting::diagnostic::Diagnostic;
use internment::ArcIntern;
use std::{collections::HashMap, fmt};

/// A type inference constraint that we're going to need to solve.
#[derive(Debug)]
pub enum Constraint {
    /// The given type must be printable using the `print` built-in
    Printable(Location, Type),
    /// The provided numeric value fits in the given constant type
    FitsInNumType(Location, Type, u64),
    /// The given primitive has the proper argument types associated with it
    ProperPrimitiveArgs(Location, ir::Primitive, Vec<Type>, Type),
    /// The given type can be cast to the target type safely
    CanCastTo(Location, Type, Type),
    /// The given type must be some numeric type, but this is not a constant
    /// value, so don't try to default it if we can't figure it out
    NumericType(Location, Type),
    /// The given type is attached to a constant and must be some numeric type.
    /// If we can't figure it out, we should warn the user and then just use a
    /// default.
    ConstantNumericType(Location, Type),
    /// The two types should be equivalent
    Equivalent(Location, Type, Type),
}

impl fmt::Display for Constraint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty),
            Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty),
            Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => {
                write!(f, "PRIM {} {} -> {}", op, args[0], ret)
            }
            Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => {
                write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret)
            }
            Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret),
            Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2),
            Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty),
            Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty),
            Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2),
        }
    }
}

pub type TypeResolutions = HashMap<ArcIntern<String>, PrimitiveType>;

/// The results of type inference; like [`Result`], but with a bit more information.
///
/// This result is parameterized, because sometimes it's handy to return slightly
/// different things; there's a [`TypeInferenceResult::map`] function for performing
/// those sorts of conversions.
pub enum TypeInferenceResult<Result> {
    Success {
        result: Result,
        warnings: Vec<TypeInferenceWarning>,
    },
    Failure {
        errors: Vec<TypeInferenceError>,
        warnings: Vec<TypeInferenceWarning>,
    },
}

impl<R> TypeInferenceResult<R> {
    /// If this was a successful type inference, run the function over the result to
    /// create a new result.
    ///
    /// This is the moral equivalent of [`Result::map`], but for type inference results.
    pub fn map<U, F>(self, f: F) -> TypeInferenceResult<U>
    where
        F: FnOnce(R) -> U,
    {
        match self {
            TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success {
                result: f(result),
                warnings,
            },

            TypeInferenceResult::Failure { errors, warnings } => {
                TypeInferenceResult::Failure { errors, warnings }
            }
        }
    }

    /// Return the final result, or panic if it's not a success.
    pub fn expect(self, msg: &str) -> R {
        match self {
            TypeInferenceResult::Success { result, .. } => result,
            TypeInferenceResult::Failure { .. } => {
                panic!("tried to get value from failed type inference: {}", msg)
            }
        }
    }
}

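To illustrate how `map` behaves, here is a minimal standalone sketch. `InferenceResult` is a simplified stand-in for `TypeInferenceResult`, with plain strings in place of the real warning and error types; both simplifications are assumptions made only to keep the example self-contained. On success the function is applied and the warnings travel along unchanged; on failure everything passes through untouched.

```rust
/// Simplified stand-in for `TypeInferenceResult`; warnings and errors are
/// plain strings here purely for illustration.
#[derive(Debug, PartialEq)]
enum InferenceResult<T> {
    Success {
        result: T,
        warnings: Vec<String>,
    },
    Failure {
        errors: Vec<String>,
        warnings: Vec<String>,
    },
}

impl<T> InferenceResult<T> {
    /// Transform the successful payload, carrying warnings (and, on failure,
    /// errors) through unchanged.
    fn map<U, F>(self, f: F) -> InferenceResult<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            InferenceResult::Success { result, warnings } => InferenceResult::Success {
                result: f(result),
                warnings,
            },
            InferenceResult::Failure { errors, warnings } => {
                InferenceResult::Failure { errors, warnings }
            }
        }
    }
}
```

In the real engine, the same shape would presumably let the solver's `TypeResolutions` be mapped into a finalized AST without losing any warnings collected along the way.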
/// The various kinds of errors that can occur while doing type inference.
pub enum TypeInferenceError {
    /// The user provided a constant that is too large for its inferred type.
    ConstantTooLarge(Location, PrimitiveType, u64),
    /// The two types needed to be equivalent, but weren't.
    NotEquivalent(Location, PrimitiveType, PrimitiveType),
    /// We cannot safely cast the first type to the second type.
    CannotSafelyCast(Location, PrimitiveType, PrimitiveType),
    /// The primitive invocation provided the wrong number of arguments. The three
    /// numbers are the minimum accepted, the maximum accepted, and the number observed.
    WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize),
    /// We had a constraint we just couldn't solve.
    CouldNotSolve(Constraint),
}

impl From<TypeInferenceError> for Diagnostic<usize> {
    fn from(value: TypeInferenceError) -> Self {
        match value {
            TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc
                .labelled_error("constant too large for type")
                .with_message(format!(
                    "Type {} has a max value of {}, which is smaller than {}",
                    primty,
                    primty.max_value(),
                    value
                )),
            TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc
                .labelled_error("type inference error")
                .with_message(format!("Expected type {}, received type {}", ty1, ty2)),
            TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc
                .labelled_error("unsafe type cast")
                .with_message(format!("Cannot safely cast {} to {}", ty1, ty2)),
            TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc
                .labelled_error("wrong number of arguments")
                .with_message(format!(
                    "expected {} for {}, received {}",
                    if lower == upper && lower > 1 {
                        format!("{} arguments", lower)
                    } else if lower == upper {
                        format!("{} argument", lower)
                    } else {
                        format!("{}-{} arguments", lower, upper)
                    },
                    prim,
                    observed
                )),
            TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => {
                loc.labelled_error("internal error").with_message(format!(
                    "could not determine if it was safe to cast from {} to {:#?}",
                    a, b
                ))
            }
            TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => {
                loc.labelled_error("internal error").with_message(format!(
                    "could not determine if {} and {:#?} were equivalent",
                    a, b
                ))
            }
            TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => {
                loc.labelled_error("internal error").with_message(format!(
                    "Could not determine if {} could fit in {}",
                    val, ty
                ))
            }
            TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc
                .labelled_error("internal error")
                .with_message(format!("Could not determine if {} was a numeric type", ty)),
            TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) =>
                panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty),
            TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc
                .labelled_error("internal error")
                .with_message(format!("Could not determine if type {} was printable", ty)),
            TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => {
                loc.labelled_error("internal error").with_message(format!(
                    "Could not tell if primitive {} received the proper argument types",
                    prim
                ))
            }
        }
    }
}

/// Warnings that we might want to tell the user about.
///
/// These are fine, probably, but could indicate some behavior the user might not
/// expect, and so they might want to do something about them.
pub enum TypeInferenceWarning {
    /// We couldn't infer a type for something at this location, so we defaulted
    /// it to the given type.
    DefaultedTo(Location, Type),
}

impl From<TypeInferenceWarning> for Diagnostic<usize> {
    fn from(value: TypeInferenceWarning) -> Self {
        match value {
            TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning()
                .with_labels(vec![loc.primary_label().with_message("unknown type")])
                .with_message(format!("Defaulted unknown type to {}", ty)),
        }
    }
}

/// Solve all the constraints in the provided database.
///
/// This process can take a bit, so you might not want to do it multiple times. Basically,
/// it's going to grind on these constraints until either it figures them out, or it stops
/// making progress. I haven't done the math on the constraints to even figure out if this
/// is guaranteed to halt, though, let alone terminate in some reasonable amount of time.
///
/// The return value is a type inference result, which pairs some warnings with either a
/// successful set of type resolutions (mappings from type variables to their values), or
/// a series of inference errors.
pub fn solve_constraints(
    mut constraint_db: Vec<Constraint>,
) -> TypeInferenceResult<TypeResolutions> {
    let mut errors = vec![];
    let mut warnings = vec![];
    let mut resolutions = HashMap::new();
    let mut changed_something = true;

    // We want to run this inference until either we have solved all of our constraints
    // or we have stopped making progress. Internal to the loop, we have a check that
    // will make sure that we do (eventually) stop.
    while changed_something && !constraint_db.is_empty() {
        // Set this to false at the top of the loop. We'll set this to true if we make
        // progress in any way further down, but having this here prevents us from going
        // into an infinite loop when we can't figure stuff out.
        changed_something = false;
        // This is sort of a double-buffering thing; we're going to rename constraint_db
        // and then set it to a new empty vector, which we'll add to as we find new
        // constraints or find ourselves unable to solve existing ones.
        let mut local_constraints = constraint_db;
        constraint_db = vec![];

        // OK. First thing we're going to do is run through all of our constraints,
        // and see if we can solve any, or reduce them to theoretically simpler
        // constraints.
        for constraint in local_constraints.drain(..) {
            match constraint {
                // Currently, all of our types are printable
                Constraint::Printable(_loc, _ty) => changed_something = true,

                // Case #1: We have two primitive types. If they're equal, we've discharged this
                // constraint! We can just continue. If they're not equal, add an error and then
                // see what else we come up with.
                Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => {
                    if t1 != t2 {
                        errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2));
                    }
                    changed_something = true;
                }

                // Case #2: One of the two types is a primitive, and the other is a variable.
                // In this case, we'll check to see if we've resolved the variable, and check for
                // equivalence if we have. If we haven't, we'll set that variable to be that
                // primitive type.
                Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name))
                | Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => {
                    match resolutions.get(&name) {
                        None => {
                            resolutions.insert(name, t);
                        }
                        Some(t2) if &t == t2 => {}
                        Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)),
                    }
                    changed_something = true;
                }

                // Case #3: They're both variables. In which case, we'll have to do much the same
                // check, but now on their resolutions.
                Constraint::Equivalent(
                    ref loc,
                    Type::Variable(_, ref name1),
                    Type::Variable(_, ref name2),
                ) => match (resolutions.get(name1), resolutions.get(name2)) {
                    (None, None) => {
                        constraint_db.push(constraint);
                    }
                    (Some(pt), None) => {
                        resolutions.insert(name2.clone(), *pt);
                        changed_something = true;
                    }
                    (None, Some(pt)) => {
                        resolutions.insert(name1.clone(), *pt);
                        changed_something = true;
                    }
                    (Some(pt1), Some(pt2)) if pt1 == pt2 => {
                        changed_something = true;
                    }
                    (Some(pt1), Some(pt2)) => {
                        errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2));
                        changed_something = true;
                    }
                },

                // Make sure that the provided number fits within the provided constant type. For the
                // moment, we're going to call an error here a failure, although this could be a
                // warning in the future.
                Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => {
                    if ctype.max_value() < val {
                        errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val));
                    }

                    changed_something = true;
                }

                // If the type is still a variable, then let's see if we can advance it to a
                // primitive type
                Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => {
                    match resolutions.get(&var) {
                        None => constraint_db.push(Constraint::FitsInNumType(
                            loc,
                            Type::Variable(vloc, var),
                            val,
                        )),
                        Some(nt) => {
                            constraint_db.push(Constraint::FitsInNumType(
                                loc,
                                Type::Primitive(*nt),
                                val,
                            ));
                            changed_something = true;
                        }
                    }
                }

                // If the left type in a "can cast to" check is a variable, let's see if we can advance
                // it into something more tangible
                Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => {
                    match resolutions.get(&var) {
                        None => constraint_db.push(Constraint::CanCastTo(
                            loc,
                            Type::Variable(vloc, var),
                            to_type,
                        )),
                        Some(nt) => {
                            constraint_db.push(Constraint::CanCastTo(
                                loc,
                                Type::Primitive(*nt),
                                to_type,
                            ));
                            changed_something = true;
                        }
                    }
                }

                // If the right type in a "can cast to" check is a variable, same deal
                Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => {
                    match resolutions.get(&var) {
                        None => constraint_db.push(Constraint::CanCastTo(
                            loc,
                            from_type,
                            Type::Variable(vloc, var),
                        )),
                        Some(nt) => {
                            constraint_db.push(Constraint::CanCastTo(
                                loc,
                                from_type,
                                Type::Primitive(*nt),
                            ));
                            changed_something = true;
                        }
                    }
                }

                // If both of them are primitive types, then we can actually do the test. Yay!
                Constraint::CanCastTo(
                    loc,
                    Type::Primitive(from_type),
                    Type::Primitive(to_type),
                ) => {
                    if !from_type.can_cast_to(&to_type) {
                        errors.push(TypeInferenceError::CannotSafelyCast(
                            loc, from_type, to_type,
                        ));
                    }
                    changed_something = true;
                }

                // As per usual, if we're trying to test if a type variable is numeric, first
                // we try to advance it to a primitive
                Constraint::NumericType(loc, Type::Variable(vloc, var)) => {
                    match resolutions.get(&var) {
                        None => constraint_db
                            .push(Constraint::NumericType(loc, Type::Variable(vloc, var))),
                        Some(nt) => {
                            constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt)));
                            changed_something = true;
                        }
                    }
                }

                // Of course, if we get to a primitive type, then it's true, because all of our
                // primitive types are numbers
                Constraint::NumericType(_, Type::Primitive(_)) => {
                    changed_something = true;
                }

                // Same deal for the type of a constant: if it's still a type variable, first
                // we try to advance it to a primitive
                Constraint::ConstantNumericType(loc, Type::Variable(vloc, var)) => {
                    match resolutions.get(&var) {
                        None => constraint_db.push(Constraint::ConstantNumericType(
                            loc,
                            Type::Variable(vloc, var),
                        )),
                        Some(nt) => {
                            constraint_db
                                .push(Constraint::ConstantNumericType(loc, Type::Primitive(*nt)));
                            changed_something = true;
                        }
                    }
                }

                // And, as above, once we get to a primitive type we're done, because all of
                // our primitive types are numbers
                Constraint::ConstantNumericType(_, Type::Primitive(_)) => {
                    changed_something = true;
                }

                // OK, this one could be a little tricky if we tried to do it all at once, but
                // instead what we're going to do here is just use this constraint to generate
                // a bunch more constraints, and then go have the engine solve those. The only
                // real errors we're going to come up with here are "arity errors"; errors we
                // find by discovering that the number of arguments provided doesn't make sense
                // given the primitive being used.
                Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim {
                    ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide
                        if args.len() != 2 =>
                    {
                        errors.push(TypeInferenceError::WrongPrimitiveArity(
                            loc,
                            prim,
                            2,
                            2,
                            args.len(),
                        ));
                        changed_something = true;
                    }

                    ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => {
                        let right = args.pop().expect("2 > 0");
                        let left = args.pop().expect("2 > 1");

                        // technically testing that both are numeric is redundant, but it might give
                        // a slightly more helpful type error if we do both.
                        constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
                        constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
                        constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
                        constraint_db.push(Constraint::Equivalent(
                            loc.clone(),
                            left.clone(),
                            right,
                        ));
                        constraint_db.push(Constraint::Equivalent(loc, left, ret));
                        changed_something = true;
                    }

                    ir::Primitive::Minus if args.is_empty() || args.len() > 2 => {
                        errors.push(TypeInferenceError::WrongPrimitiveArity(
                            loc,
                            prim,
                            1,
                            2,
                            args.len(),
                        ));
                        changed_something = true;
                    }

                    ir::Primitive::Minus if args.len() == 1 => {
                        let arg = args.pop().expect("1 > 0");
                        constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone()));
                        constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
                        constraint_db.push(Constraint::Equivalent(loc, arg, ret));
                        changed_something = true;
                    }

                    ir::Primitive::Minus => {
                        let right = args.pop().expect("2 > 0");
                        let left = args.pop().expect("2 > 1");

                        // technically testing that both are numeric is redundant, but it might give
                        // a slightly more helpful type error if we do both.
                        constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
                        constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
                        constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
                        constraint_db.push(Constraint::Equivalent(
                            loc.clone(),
                            left.clone(),
                            right,
                        ));
                        constraint_db.push(Constraint::Equivalent(loc, left, ret));
                        changed_something = true;
                    }
                },
            }
        }

        // If that didn't actually come up with anything, and we just recycled all the constraints
        // back into the database unchanged, then let's take a look for cases in which we just
        // wanted something we didn't know to be a number. Basically, those are cases where the
        // user just wrote a number, but didn't tell us what type it was, and there isn't enough
        // information in the context to tell us. If that happens, we'll just set that type to
        // be u64, and warn the user that we did so.
        if !changed_something && !constraint_db.is_empty() {
            local_constraints = constraint_db;
            constraint_db = vec![];

            for constraint in local_constraints.drain(..) {
                match constraint {
                    Constraint::ConstantNumericType(loc, t @ Type::Variable(_, _)) => {
                        let resty = Type::Primitive(PrimitiveType::U64);
                        constraint_db.push(Constraint::Equivalent(
                            loc.clone(),
                            t,
                            resty.clone(),
                        ));
                        warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty));
                        changed_something = true;
                    }

                    _ => constraint_db.push(constraint),
                }
            }
        }
    }

    // OK, we left our loop. Which means that either we solved everything, or we didn't.
    // If we didn't, turn the unsolved constraints into type inference errors, and add
    // them to our error list.
    let mut unsolved_constraint_errors = constraint_db
        .drain(..)
        .map(TypeInferenceError::CouldNotSolve)
        .collect();
    errors.append(&mut unsolved_constraint_errors);

    // How'd we do?
    if errors.is_empty() {
        TypeInferenceResult::Success {
            result: resolutions,
            warnings,
        }
    } else {
        TypeInferenceResult::Failure { errors, warnings }
    }
}
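The overall shape of `solve_constraints` (grind until nothing changes, then default unresolved literals to `u64` and go around again) can be hard to see through all the match arms. Below is a stripped-down sketch of the same loop with only two constraint kinds. Everything in it (`Prim`, `Ty`, `substitute`, `solve`, string-named variables) is a hypothetical simplification for illustration, not the commit's API. Unlike the real solver, it stops at the first mismatch instead of accumulating errors.

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq)]
enum Prim {
    U64,
    I64,
}

#[derive(Clone, Copy, Debug)]
enum Ty {
    Prim(Prim),
    Var(&'static str),
}

#[derive(Clone, Copy, Debug)]
enum Constraint {
    Equivalent(Ty, Ty),
    ConstantNumeric(Ty),
}

type Resolutions = HashMap<&'static str, Prim>;

/// Replace a variable with its resolution, if one is known.
fn substitute(ty: Ty, resolved: &Resolutions) -> Ty {
    match ty {
        Ty::Var(name) => resolved.get(name).map_or(ty, |p| Ty::Prim(*p)),
        Ty::Prim(_) => ty,
    }
}

/// Iterate to a fixed point; returns the resolutions plus the defaulted variables.
fn solve(mut db: Vec<Constraint>) -> Result<(Resolutions, Vec<&'static str>), String> {
    let mut resolved = Resolutions::new();
    let mut defaulted = Vec::new();
    let mut progress = true;

    while progress && !db.is_empty() {
        progress = false;
        // Double-buffer: drain this round's constraints, re-queue what stays unsolved.
        for constraint in std::mem::take(&mut db) {
            match constraint {
                Constraint::Equivalent(a, b) => {
                    match (substitute(a, &resolved), substitute(b, &resolved)) {
                        (Ty::Prim(x), Ty::Prim(y)) if x == y => progress = true,
                        (Ty::Prim(x), Ty::Prim(y)) => {
                            return Err(format!("expected {:?}, received {:?}", x, y));
                        }
                        (Ty::Prim(p), Ty::Var(v)) | (Ty::Var(v), Ty::Prim(p)) => {
                            resolved.insert(v, p);
                            progress = true;
                        }
                        (Ty::Var(_), Ty::Var(_)) => db.push(constraint),
                    }
                }
                Constraint::ConstantNumeric(ty) => match substitute(ty, &resolved) {
                    Ty::Prim(_) => progress = true,
                    Ty::Var(_) => db.push(constraint),
                },
            }
        }

        // Stuck: default every still-unknown literal type to u64, then loop again.
        if !progress {
            for constraint in std::mem::take(&mut db) {
                match constraint {
                    Constraint::ConstantNumeric(Ty::Var(v)) => {
                        db.push(Constraint::Equivalent(Ty::Var(v), Ty::Prim(Prim::U64)));
                        defaulted.push(v);
                        progress = true;
                    }
                    other => db.push(other),
                }
            }
        }
    }

    if db.is_empty() {
        Ok((resolved, defaulted))
    } else {
        Err(format!("{} constraint(s) left unsolved", db.len()))
    }
}
```

Because each pass either discharges a constraint, resolves a variable, or trips the defaulting step, the `progress` flag is what keeps the loop from spinning forever on constraints it cannot reduce.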