lift the inference engine up a level

This commit is contained in:
2023-07-21 21:40:41 -07:00
parent 9fb6bf3b86
commit a8d32a917f
9 changed files with 5 additions and 6 deletions

336
src/type_infer/ast.rs Normal file
View File

@@ -0,0 +1,336 @@
pub use crate::ir::ast::Primitive;
/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going
/// to want to use while we're doing type inference, but don't want to keep around
/// afterwards. These are:
///
/// * A notion of a type variable
/// * An unknown numeric constant form
///
use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use std::fmt;
use std::sync::atomic::AtomicUsize;
/// We're going to represent variables as interned strings.
///
/// These should be fast enough for comparison that it's OK, since it's going to end up
/// being pretty much the pointer to the string.
type Variable = ArcIntern<String>;
/// The representation of a program within our IR. For now, this is exactly one file.
///
/// In addition, for the moment there's not really much of interest to hold here besides
/// the list of statements read from the file. Order is important. In the future, you
/// could imagine caching analysis information in this structure.
///
/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used
/// to print the structure whenever possible, especially if you value your or your
/// user's time. The latter is useful for testing that conversions of `Program` retain
/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be
/// syntactically valid, although they may contain runtime issue like over- or underflow.
#[derive(Debug)]
pub struct Program {
// For now, a program is just a vector of statements. In the future, we'll probably
// extend this to include a bunch of other information, but for now: just a list.
pub(crate) statements: Vec<Statement>,
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let mut result = allocator.nil();
for stmt in self.statements.iter() {
// there's probably a better way to do this, rather than constantly
// adding to the end, but this works.
result = result
.append(stmt.pretty(allocator))
.append(allocator.text(";"))
.append(allocator.hardline());
}
result
}
}
/// The representation of a statement in the language.
///
/// For now, this is either a binding site (`x = 4`) or a print statement
/// (`print x`). Someday, though, more!
///
/// As with `Program`, this type implements [`Pretty`], which should
/// be used to display the structure whenever possible. It does not
/// implement [`Arbitrary`], though, mostly because it's slightly
/// complicated to do so.
///
#[derive(Debug)]
pub enum Statement {
Binding(Location, Variable, Type, Expression),
Print(Location, Type, Variable),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string())
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
Statement::Print(_, _, var) => allocator
.text("print")
.append(allocator.space())
.append(allocator.text(var.as_ref().to_string())),
}
}
}
/// The representation of an expression.
///
/// Note that expressions, like everything else in this syntax tree,
/// supports [`Pretty`], and it's strongly encouraged that you use
/// that trait/module when printing these structures.
///
/// Also, Expressions at this point in the compiler are explicitly
/// defined so that they are *not* recursive. By this point, if an
/// expression requires some other data (like, for example, invoking
/// a primitive), any subexpressions have been bound to variables so
/// that the referenced data will always either be a constant or a
/// variable reference.
#[derive(Debug, PartialEq)]
pub enum Expression {
Atomic(ValueOrRef),
Cast(Location, Type, ValueOrRef),
Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Expression::Atomic(x) => x.pretty(allocator),
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
left.append(allocator.space())
.append(op.pretty(allocator))
.append(allocator.space())
.append(right)
.parens()
}
Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
}
}
}
}
/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrRef {
Value(Location, Type, Value),
Ref(Location, Type, ArcIntern<String>),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
}
}
}
impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self {
Expression::Atomic(value)
}
}
/// A constant in the IR.
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Unknown(Option<u8>, u64),
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let pretty_internal = |opt_base: &Option<u8>, x, t| {
syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
};
let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::Unknown(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::U64)
}
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Variable(Location, ArcIntern<String>),
Primitive(PrimitiveType),
}
impl Type {
pub fn is_concrete(&self) -> bool {
!matches!(self, Type::Variable(_, _))
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Variable(_, x) => allocator.text(x.to_string()),
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Variable(_, x) => write!(f, "{}", x),
Type::Primitive(pt) => pt.fmt(f),
}
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
/// Generate a fresh new type; this will be a unique new type variable.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
pub fn gentype() -> Type {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let name = ArcIntern::new(format!(
"t<{}>",
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
));
Type::Variable(Location::manufactured(), name)
}

359
src/type_infer/convert.rs Normal file
View File

@@ -0,0 +1,359 @@
use super::ast as ir;
use super::ast::Type;
use crate::eval::PrimitiveType;
use crate::type_infer::solve::Constraint;
use crate::syntax::{self, ConstantType};
use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;
/// This function takes a syntactic program and converts it into the IR version of the
/// program, with appropriate type variables introduced and their constraints added to
/// the given database.
///
/// If the input function has been validated (which it should be), then this should run
/// into no error conditions. However, if you failed to validate the input, then this
/// function can panic.
pub fn convert_program(
mut program: syntax::Program,
constraint_db: &mut Vec<Constraint>,
) -> ir::Program {
let mut statements = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
for stmt in program.statements.drain(..) {
statements.append(&mut convert_statement(
stmt,
constraint_db,
&mut renames,
&mut bindings,
));
}
ir::Program { statements }
}
/// This function takes a syntactic statements and converts it into a series of
/// IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_statement(
statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>,
renames: &mut HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> Vec<ir::Statement> {
match statement {
syntax::Statement::Print(loc, name) => {
let iname = ArcIntern::new(name);
let final_name = renames
.get(&iname)
.map(Clone::clone)
.unwrap_or_else(|| iname.clone());
let varty = bindings
.get(&final_name)
.expect("print variable defined before use")
.clone();
constraint_db.push(Constraint::Printable(loc.clone(), varty.clone()));
vec![ir::Statement::Print(loc, varty, iname)]
}
syntax::Statement::Binding(loc, name, expr) => {
let (mut prereqs, expr, ty) =
convert_expression(expr, constraint_db, renames, bindings);
let iname = ArcIntern::new(name);
let final_name = if bindings.contains_key(&iname) {
let new_name = ir::gensym(iname.as_str());
renames.insert(iname, new_name.clone());
new_name
} else {
iname
};
bindings.insert(final_name.clone(), ty.clone());
prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr));
prereqs
}
}
}
/// This function takes a syntactic expression and converts it into a series
/// of IR statements, adding type variables and constraints as necessary.
///
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_expression(
expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>,
renames: &HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> (Vec<ir::Statement>, ir::Expression, Type) {
match expression {
syntax::Expression::Value(loc, val) => match val {
syntax::Value::Number(base, mctype, value) => {
let (newval, newtype) = match mctype {
None => {
let newtype = ir::gentype();
let newval = ir::Value::Unknown(base, value);
constraint_db.push(Constraint::NumericType(loc.clone(), newtype.clone()));
(newval, newtype)
}
Some(ConstantType::U8) => (
ir::Value::U8(base, value as u8),
ir::Type::Primitive(PrimitiveType::U8),
),
Some(ConstantType::U16) => (
ir::Value::U16(base, value as u16),
ir::Type::Primitive(PrimitiveType::U16),
),
Some(ConstantType::U32) => (
ir::Value::U32(base, value as u32),
ir::Type::Primitive(PrimitiveType::U32),
),
Some(ConstantType::U64) => (
ir::Value::U64(base, value),
ir::Type::Primitive(PrimitiveType::U64),
),
Some(ConstantType::I8) => (
ir::Value::I8(base, value as i8),
ir::Type::Primitive(PrimitiveType::I8),
),
Some(ConstantType::I16) => (
ir::Value::I16(base, value as i16),
ir::Type::Primitive(PrimitiveType::I16),
),
Some(ConstantType::I32) => (
ir::Value::I32(base, value as i32),
ir::Type::Primitive(PrimitiveType::I32),
),
Some(ConstantType::I64) => (
ir::Value::I64(base, value as i64),
ir::Type::Primitive(PrimitiveType::I64),
),
};
constraint_db.push(Constraint::FitsInNumType(
loc.clone(),
newtype.clone(),
value,
));
(
vec![],
ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)),
newtype,
)
}
},
syntax::Expression::Reference(loc, name) => {
let iname = ArcIntern::new(name);
let final_name = renames.get(&iname).cloned().unwrap_or(iname);
let rtype = bindings
.get(&final_name)
.cloned()
.expect("variable bound before use");
let refexp =
ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name));
(vec![], refexp, rtype)
}
syntax::Expression::Cast(loc, target, expr) => {
let (mut stmts, nexpr, etype) =
convert_expression(*expr, constraint_db, renames, bindings);
let val_or_ref = simplify_expr(nexpr, &mut stmts);
let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast");
let target_type = Type::Primitive(target_prim_type);
let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref);
constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone()));
(stmts, res, target_type)
}
syntax::Expression::Primitive(loc, fun, mut args) => {
let primop = ir::Primitive::from_str(&fun).expect("valid primitive");
let mut stmts = vec![];
let mut nargs = vec![];
let mut atypes = vec![];
let ret_type = ir::gentype();
for arg in args.drain(..) {
let (mut astmts, aexp, atype) =
convert_expression(arg, constraint_db, renames, bindings);
stmts.append(&mut astmts);
nargs.push(simplify_expr(aexp, &mut stmts));
atypes.push(atype);
}
constraint_db.push(Constraint::ProperPrimitiveArgs(
loc.clone(),
primop,
atypes.clone(),
ret_type.clone(),
));
(
stmts,
ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs),
ret_type,
)
}
}
}
fn simplify_expr(expr: ir::Expression, stmts: &mut Vec<ir::Statement>) -> ir::ValueOrRef {
match expr {
ir::Expression::Atomic(v_or_ref) => v_or_ref,
expr => {
let etype = expr.type_of().clone();
let loc = expr.location().clone();
let nname = ir::gensym("g");
let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr);
stmts.push(nbinding);
ir::ValueOrRef::Ref(loc, etype, nname)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::syntax::Location;
fn one() -> syntax::Expression {
syntax::Expression::Value(
Location::manufactured(),
syntax::Value::Number(None, None, 1),
)
}
fn vec_contains<T, F: Fn(&T) -> bool>(x: &[T], f: F) -> bool {
for x in x.iter() {
if f(x) {
return true;
}
}
false
}
fn infer_expression(
x: syntax::Expression,
) -> (ir::Expression, Vec<ir::Statement>, Vec<Constraint>, Type) {
let mut constraints = Vec::new();
let renames = HashMap::new();
let mut bindings = HashMap::new();
let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings);
(expr, stmts, constraints, ty)
}
fn infer_statement(x: syntax::Statement) -> (Vec<ir::Statement>, Vec<Constraint>) {
let mut constraints = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings);
(res, constraints)
}
#[test]
fn constant_one() {
let (expr, stmts, constraints, ty) = infer_expression(one());
assert!(stmts.is_empty());
assert!(matches!(
expr,
ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1)))
));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::NumericType(_, t) if t == &ty)
));
}
#[test]
fn one_plus_one() {
let opo = syntax::Expression::Primitive(
Location::manufactured(),
"+".to_string(),
vec![one(), one()],
);
let (expr, stmts, constraints, ty) = infer_expression(opo);
assert!(stmts.is_empty());
assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::NumericType(_, t) if t != &ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty)
));
}
#[test]
fn one_plus_one_plus_one() {
let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse");
let (stmts, constraints) = infer_statement(stmt);
assert_eq!(stmts.len(), 2);
let ir::Statement::Binding(_args, name1, temp_ty1, ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1)) = stmts.get(0).expect("item two") else {
panic!("Failed to match first statement");
};
let ir::Statement::Binding(_args, name2, temp_ty2, ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2)) = stmts.get(1).expect("item two") else {
panic!("Failed to match second statement");
};
let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] = &primargs1[..] else {
panic!("Failed to match first arguments");
};
let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] = &primargs2[..] else {
panic!("Failed to match first arguments");
};
assert_ne!(name1, name2);
assert_ne!(temp_ty1, temp_ty2);
assert_ne!(primty1, primty2);
assert_eq!(name1, left2name);
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::NumericType(_, t) if t == left1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::NumericType(_, t) if t == right1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::NumericType(_, t) if t == right2ty)
));
for (i, s) in stmts.iter().enumerate() {
println!("{}: {:?}", i, s);
}
for (i, c) in constraints.iter().enumerate() {
println!("{}: {:?}", i, c);
}
}
}

187
src/type_infer/finalize.rs Normal file
View File

@@ -0,0 +1,187 @@
use super::{ast as input, solve::TypeResolutions};
use crate::{eval::PrimitiveType, ir as output};
pub fn finalize_program(
mut program: input::Program,
resolutions: &TypeResolutions,
) -> output::Program {
output::Program {
statements: program
.statements
.drain(..)
.map(|x| finalize_statement(x, resolutions))
.collect(),
}
}
fn finalize_statement(
statement: input::Statement,
resolutions: &TypeResolutions,
) -> output::Statement {
match statement {
input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding(
loc,
var,
finalize_type(ty, resolutions),
finalize_expression(expr, resolutions),
),
input::Statement::Print(loc, ty, var) => {
output::Statement::Print(loc, finalize_type(ty, resolutions), var)
}
}
}
fn finalize_expression(
expression: input::Expression,
resolutions: &TypeResolutions,
) -> output::Expression {
match expression {
input::Expression::Atomic(val_or_ref) => {
output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions))
}
input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast(
loc,
finalize_type(target, resolutions),
finalize_val_or_ref(val_or_ref, resolutions),
),
input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive(
loc,
finalize_type(ty, resolutions),
prim,
args.drain(..)
.map(|x| finalize_val_or_ref(x, resolutions))
.collect(),
),
}
}
fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type {
match ty {
input::Type::Primitive(x) => output::Type::Primitive(x),
input::Type::Variable(_, tvar) => match resolutions.get(&tvar) {
None => panic!("Did not resolve type for type variable {}", tvar),
Some(pt) => output::Type::Primitive(*pt),
},
}
}
fn finalize_val_or_ref(
valref: input::ValueOrRef,
resolutions: &TypeResolutions,
) -> output::ValueOrRef {
match valref {
input::ValueOrRef::Ref(loc, ty, var) => {
output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var)
}
input::ValueOrRef::Value(loc, ty, val) => {
let new_type = finalize_type(ty, resolutions);
match val {
input::Value::Unknown(base, value) => match new_type {
output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U8(base, value as u8),
),
output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U16(base, value as u16),
),
output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::U32(base, value as u32),
),
output::Type::Primitive(PrimitiveType::U64) => {
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I8(base, value as i8),
),
output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I16(base, value as i16),
),
output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I32(base, value as i32),
),
output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I64(base, value as i64),
),
},
input::Value::U8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value))
}
input::Value::U16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value))
}
input::Value::U32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value))
}
input::Value::U64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::U64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
}
input::Value::I8(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value))
}
input::Value::I16(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value))
}
input::Value::I32(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value))
}
input::Value::I64(base, value) => {
assert!(matches!(
new_type,
output::Type::Primitive(PrimitiveType::I64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value))
}
}
}
}
}

490
src/type_infer/solve.rs Normal file
View File

@@ -0,0 +1,490 @@
use super::ast as ir;
use super::ast::Type;
use crate::{eval::PrimitiveType, syntax::Location};
use codespan_reporting::diagnostic::Diagnostic;
use internment::ArcIntern;
use std::{collections::HashMap, fmt};
#[derive(Debug)]
pub enum Constraint {
/// The given type must be printable using the `print` built-in
Printable(Location, Type),
/// The provided numeric value fits in the given constant type
FitsInNumType(Location, Type, u64),
/// The given primitive has the proper arguments types associated with it
ProperPrimitiveArgs(Location, ir::Primitive, Vec<Type>, Type),
/// The given type can be casted to the target type safely
CanCastTo(Location, Type, Type),
/// The given type must be some numeric type, but this is not a constant
/// value, so don't try to default it if we can't figure it out
NumericType(Location, Type),
/// The two types should be equivalent
Equivalent(Location, Type, Type),
}
impl fmt::Display for Constraint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Constraint::Printable(_, ty) => write!(f, "PRINTABLE {}", ty),
Constraint::FitsInNumType(_, ty, num) => write!(f, "FITS_IN {} {}", num, ty),
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 1 => {
write!(f, "PRIM {} {} -> {}", op, args[0], ret)
}
Constraint::ProperPrimitiveArgs(_, op, args, ret) if args.len() == 2 => {
write!(f, "PRIM {} ({}, {}) -> {}", op, args[0], args[1], ret)
}
Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret),
Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2),
Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty),
Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2),
}
}
}
pub type TypeResolutions = HashMap<ArcIntern<String>, PrimitiveType>;
pub enum TypeInferenceResult<Result> {
Success {
result: Result,
warnings: Vec<TypeInferenceWarning>,
},
Failure {
errors: Vec<TypeInferenceError>,
warnings: Vec<TypeInferenceWarning>,
},
}
impl<R> TypeInferenceResult<R> {
// If this was a successful type inference, run the function over the result to
// create a new result.
pub fn map<U, F>(self, f: F) -> TypeInferenceResult<U>
where
F: FnOnce(R) -> U,
{
match self {
TypeInferenceResult::Success { result, warnings } => TypeInferenceResult::Success {
result: f(result),
warnings,
},
TypeInferenceResult::Failure { errors, warnings } => {
TypeInferenceResult::Failure { errors, warnings }
}
}
}
// Return the final result, or panic if it's not a success
pub fn expect(self, msg: &str) -> R {
match self {
TypeInferenceResult::Success { result, .. } => result,
TypeInferenceResult::Failure { .. } => {
panic!("tried to get value from failed type inference: {}", msg)
}
}
}
}
pub enum TypeInferenceError {
ConstantTooLarge(Location, PrimitiveType, u64),
NotEquivalent(Location, PrimitiveType, PrimitiveType),
CannotSafelyCast(Location, PrimitiveType, PrimitiveType),
WrongPrimitiveArity(Location, ir::Primitive, usize, usize, usize),
CouldNotSolve(Constraint),
}
impl From<TypeInferenceError> for Diagnostic<usize> {
fn from(value: TypeInferenceError) -> Self {
match value {
TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc
.labelled_error("constant too large for type")
.with_message(format!(
"Type {} has a max value of {}, which is smaller than {}",
primty,
primty.max_value(),
value
)),
TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc
.labelled_error("type inference error")
.with_message(format!("Expected type {}, received type {}", ty1, ty2)),
TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc
.labelled_error("unsafe type cast")
.with_message(format!("Cannot safely cast {} to {}", ty1, ty2)),
TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc
.labelled_error("wrong number of arguments")
.with_message(format!(
"expected {} for {}, received {}",
if lower == upper && lower > 1 {
format!("{} arguments", lower)
} else if lower == upper {
format!("{} argument", lower)
} else {
format!("{}-{} arguments", lower, upper)
},
prim,
observed
)),
TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if it was safe to cast from {} to {:#?}",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if {} and {:#?} were equivalent",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not determine if {} could fit in {}",
val, ty
))
}
TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if {} was a numeric type", ty)),
TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if type {} was printable", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ProperPrimitiveArgs(loc, prim, _, _)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not tell if primitive {} received the proper argument types",
prim
))
}
}
}
}
pub enum TypeInferenceWarning {
DefaultedTo(Location, Type),
}
impl From<TypeInferenceWarning> for Diagnostic<usize> {
fn from(value: TypeInferenceWarning) -> Self {
match value {
TypeInferenceWarning::DefaultedTo(loc, ty) => Diagnostic::warning()
.with_labels(vec![loc.primary_label().with_message("unknown type")])
.with_message(format!("Defaulted unknown type to {}", ty)),
}
}
}
pub fn solve_constraints(
mut constraint_db: Vec<Constraint>,
) -> TypeInferenceResult<TypeResolutions> {
let mut errors = vec![];
let mut warnings = vec![];
let mut resolutions = HashMap::new();
let mut changed_something = true;
// We want to run this inference endlessly, until either we have solved all of our
// constraints. Internal to the loop, we have a check that will make sure that we
// do (eventually) stop.
while changed_something && !constraint_db.is_empty() {
println!("-------CONSTRAINTS---------");
for constraint in constraint_db.iter() {
println!("{}", constraint);
}
println!("---------------------------");
// Set this to false at the top of the loop. We'll set this to true if we make
// progress in any way further down, but having this here prevents us from going
// into an infinite look when we can't figure stuff out.
changed_something = false;
// This is sort of a double-buffering thing; we're going to rename constraint_db
// and then set it to a new empty vector, which we'll add to as we find new
// constraints or find ourselves unable to solve existing ones.
let mut local_constraints = constraint_db;
constraint_db = vec![];
// OK. First thing we're going to do is run through all of our constraints,
// and see if we can solve any, or reduce them to theoretically more simple
// constraints.
for constraint in local_constraints.drain(..) {
match constraint {
// Currently, all of our types are printable
Constraint::Printable(_loc, _ty) => changed_something = true,
// Case #1: We have two primitive types. If they're equal, we've discharged this
// constraint! We can just continue. If they're not equal, add an error and then
// see what else we come up with.
Constraint::Equivalent(loc, Type::Primitive(t1), Type::Primitive(t2)) => {
if t1 != t2 {
errors.push(TypeInferenceError::NotEquivalent(loc, t1, t2));
}
changed_something = true;
}
// Case #2: One of the two constraints is a primitive, and the other is a variable.
// In this case, we'll check to see if we've resolved the variable, and check for
// equivalence if we have. If we haven't, we'll set that variable to be primitive
// type.
Constraint::Equivalent(loc, Type::Primitive(t), Type::Variable(_, name))
| Constraint::Equivalent(loc, Type::Variable(_, name), Type::Primitive(t)) => {
match resolutions.get(&name) {
None => {
resolutions.insert(name, t);
}
Some(t2) if &t == t2 => {}
Some(t2) => errors.push(TypeInferenceError::NotEquivalent(loc, t, *t2)),
}
changed_something = true;
}
// Case #3: They're both variables. In which case, we'll have to do much the same
// check, but now on their resolutions.
Constraint::Equivalent(
ref loc,
Type::Variable(_, ref name1),
Type::Variable(_, ref name2),
) => match (resolutions.get(name1), resolutions.get(name2)) {
(None, None) => {
constraint_db.push(constraint);
}
(Some(pt), None) => {
resolutions.insert(name2.clone(), *pt);
changed_something = true;
}
(None, Some(pt)) => {
resolutions.insert(name1.clone(), *pt);
changed_something = true;
}
(Some(pt1), Some(pt2)) if pt1 == pt2 => {
changed_something = true;
}
(Some(pt1), Some(pt2)) => {
errors.push(TypeInferenceError::NotEquivalent(loc.clone(), *pt1, *pt2));
changed_something = true;
}
},
// Make sure that the provided number fits within the provided constant type. For the
// moment, we're going to call an error here a failure, although this could be a
// warning in the future.
Constraint::FitsInNumType(loc, Type::Primitive(ctype), val) => {
if ctype.max_value() < val {
errors.push(TypeInferenceError::ConstantTooLarge(loc, ctype, val));
}
changed_something = true;
}
// If we have a non-constant type, then let's see if we can advance this to a constant
// type
Constraint::FitsInNumType(loc, Type::Variable(vloc, var), val) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Variable(vloc, var),
val,
)),
Some(nt) => {
constraint_db.push(Constraint::FitsInNumType(
loc,
Type::Primitive(*nt),
val,
));
changed_something = true;
}
}
}
// If the left type in a "can cast to" check is a variable, let's see if we can advance
// it into something more tangible
Constraint::CanCastTo(loc, Type::Variable(vloc, var), to_type) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
Type::Variable(vloc, var),
to_type,
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
Type::Primitive(*nt),
to_type,
));
changed_something = true;
}
}
}
// If the right type in a "can cast to" check is a variable, same deal
Constraint::CanCastTo(loc, from_type, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Variable(vloc, var),
)),
Some(nt) => {
constraint_db.push(Constraint::CanCastTo(
loc,
from_type,
Type::Primitive(*nt),
));
changed_something = true;
}
}
}
// If both of them are types, then we can actually do the test. yay!
Constraint::CanCastTo(
loc,
Type::Primitive(from_type),
Type::Primitive(to_type),
) => {
if !from_type.can_cast_to(&to_type) {
errors.push(TypeInferenceError::CannotSafelyCast(
loc, from_type, to_type,
));
}
changed_something = true;
}
// As per usual, if we're trying to test if a type variable is numeric, first
// we try to advance it to a primitive
Constraint::NumericType(loc, Type::Variable(vloc, var)) => {
match resolutions.get(&var) {
None => constraint_db
.push(Constraint::NumericType(loc, Type::Variable(vloc, var))),
Some(nt) => {
constraint_db.push(Constraint::NumericType(loc, Type::Primitive(*nt)));
changed_something = true;
}
}
}
// Of course, if we get to a primitive type, then it's true, because all of our
// primitive types are numbers
Constraint::NumericType(_, Type::Primitive(_)) => {
changed_something = true;
}
// OK, this one could be a little tricky if we tried to do it all at once, but
// instead what we're going to do here is just use this constraint to generate
// a bunch more constraints, and then go have the engine solve those. The only
// real errors we're going to come up with here are "arity errors"; errors we
// find by discovering that the number of arguments provided doesn't make sense
// given the primitive being used.
Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim {
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide
if args.len() != 2 =>
{
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
2,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc, left, ret));
changed_something = true;
}
ir::Primitive::Minus if args.is_empty() || args.len() > 2 => {
errors.push(TypeInferenceError::WrongPrimitiveArity(
loc,
prim,
1,
2,
args.len(),
));
changed_something = true;
}
ir::Primitive::Minus if args.len() == 1 => {
let arg = args.pop().expect("1 > 0");
constraint_db.push(Constraint::NumericType(loc.clone(), arg.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(loc, arg, ret));
changed_something = true;
}
ir::Primitive::Minus => {
let right = args.pop().expect("2 > 0");
let left = args.pop().expect("2 > 1");
// technically testing that both are numeric is redundant, but it might give
// a slightly helpful type error if we do both.
constraint_db.push(Constraint::NumericType(loc.clone(), left.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), right.clone()));
constraint_db.push(Constraint::NumericType(loc.clone(), ret.clone()));
constraint_db.push(Constraint::Equivalent(
loc.clone(),
left.clone(),
right,
));
constraint_db.push(Constraint::Equivalent(loc.clone(), left, ret));
changed_something = true;
}
},
}
}
// If that didn't actually come up with anything, and we just recycled all the constraints
// back into the database unchanged, then let's take a look for cases in which we just
// wanted something we didn't know to be a number. Basically, those are cases where the
// user just wrote a number, but didn't tell us what type it was, and there isn't enough
// information in the context to tell us. If that happens, we'll just set that type to
// be u64, and warn the user that we did so.
if !changed_something && !constraint_db.is_empty() {
local_constraints = constraint_db;
constraint_db = vec![];
for constraint in local_constraints.drain(..) {
match constraint {
Constraint::NumericType(loc, t @ Type::Variable(_, _)) => {
let resty = Type::Primitive(PrimitiveType::U64);
constraint_db.push(Constraint::Equivalent(
loc.clone(),
t,
Type::Primitive(PrimitiveType::U64),
));
warnings.push(TypeInferenceWarning::DefaultedTo(loc, resty));
changed_something = true;
}
_ => constraint_db.push(constraint),
}
}
}
}
// OK, we left our loop. Which means that either we solved everything, or we didn't.
// If we didn't, turn the unsolved constraints into type inference errors, and add
// them to our error list.
let mut unsolved_constraint_errors = constraint_db
.drain(..)
.map(TypeInferenceError::CouldNotSolve)
.collect();
errors.append(&mut unsolved_constraint_errors);
// How'd we do?
if errors.is_empty() {
TypeInferenceResult::Success {
result: resolutions,
warnings,
}
} else {
TypeInferenceResult::Failure { errors, warnings }
}
}