λ Support functions! #5

Open
acw wants to merge 59 commits from awick/functions into develop
12 changed files with 572 additions and 94 deletions
Showing only changes of commit a7b85d37da - Show all commits

View File

@@ -33,6 +33,9 @@ impl<M: Module> Backend<M> {
types::Type::triple_pointer_type(&self.platform),
ir::ArgumentExtension::None,
),
Type::Structure(_) => {
unimplemented!()
}
Type::Primitive(PrimitiveType::Void) => (types::I8, ir::ArgumentExtension::None), // FIXME?
Type::Primitive(PrimitiveType::I8) => (types::I8, ir::ArgumentExtension::Sext),
Type::Primitive(PrimitiveType::I16) => (types::I16, ir::ArgumentExtension::Sext),
@@ -88,6 +91,10 @@ impl<M: Module> Backend<M> {
self.defined_symbols
.insert(top_level_name, (data_id, pt.into()));
}
Type::Structure(_) => {
unimplemented!()
}
}
}
@@ -392,6 +399,9 @@ impl<M: Module> Backend<M> {
}
}
Expression::Construct(_, _, _, _) => unimplemented!(),
Expression::FieldRef(_, _, _, _) => unimplemented!(),
Expression::Block(_, _, mut exprs) => match exprs.pop() {
None => Ok((builder.ins().iconst(types::I64, 0), ConstantType::Void)),
Some(last) => {
@@ -511,8 +521,10 @@ impl<M: Module> Backend<M> {
Type::Function(_, _) => {
panic!("function returns a function?")
}
Type::Structure(_) => unimplemented!(),
Type::Primitive(ct) => Ok((*result, ct.into())),
},
Type::Structure(_) => unimplemented!(),
},
_ => panic!("don't support multi-value returns yet"),
}

View File

@@ -9,6 +9,7 @@ use proptest::test_runner::{TestRng, TestRunner};
use rand::distributions::{Distribution, WeightedIndex};
use rand::seq::SliceRandom;
use rand::Rng;
use std::collections::HashMap;
use std::str::FromStr;
lazy_static::lazy_static! {
@@ -214,7 +215,10 @@ impl ProgramTree {
}
}
let current = Program { items };
let current = Program {
items,
type_definitions: HashMap::new(),
};
ProgramTree { _rng: rng, current }
}
@@ -328,6 +332,7 @@ fn generate_random_expression(
.expect("actually chose type");
Expression::Cast(Location::manufactured(), Type::Primitive(*to_type), inner)
}
Type::Structure(_) => unimplemented!(),
}
}
@@ -350,6 +355,7 @@ fn generate_random_expression(
Expression::Primitive(Location::manufactured(), out_type, primop, args)
}
},
Type::Structure(_) => unimplemented!(),
}
}

View File

@@ -2,6 +2,7 @@ use crate::eval::PrimitiveType;
use crate::syntax::{ConstantType, Location};
use internment::ArcIntern;
use proptest::arbitrary::Arbitrary;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::str::FromStr;
use std::sync::atomic::AtomicUsize;
@@ -50,7 +51,9 @@ pub fn gensym(base: &str) -> Variable {
pub struct Program<Type> {
// For now, a program is just a vector of statements. In the future, we'll probably
// extend this to include a bunch of other information, but for now: just a list.
pub(crate) items: Vec<TopLevel<Type>>,
pub items: Vec<TopLevel<Type>>,
// The set of types declared in this program.
pub type_definitions: HashMap<ArcIntern<String>, Type>,
}
impl Arbitrary for Program<Type> {
@@ -103,6 +106,13 @@ pub enum Expression<Type> {
Atomic(ValueOrRef<Type>),
Cast(Location, Type, ValueOrRef<Type>),
Primitive(Location, Type, Primitive, Vec<ValueOrRef<Type>>),
Construct(
Location,
Type,
ArcIntern<String>,
HashMap<ArcIntern<String>, ValueOrRef<Type>>,
),
FieldRef(Location, Type, ValueOrRef<Type>, ArcIntern<String>),
Block(Location, Type, Vec<Expression<Type>>),
Print(Location, ValueOrRef<Type>),
Call(Location, Type, Box<ValueOrRef<Type>>, Vec<ValueOrRef<Type>>),
@@ -117,6 +127,8 @@ impl<Type: Clone + TypeWithVoid> Expression<Type> {
Expression::Atomic(x) => x.type_of(),
Expression::Cast(_, t, _) => t.clone(),
Expression::Primitive(_, t, _, _) => t.clone(),
Expression::Construct(_, t, _, _) => t.clone(),
Expression::FieldRef(_, t, _, _) => t.clone(),
Expression::Block(_, t, _) => t.clone(),
Expression::Print(_, _) => Type::void(),
Expression::Call(_, t, _, _) => t.clone(),
@@ -131,6 +143,8 @@ impl<Type: Clone + TypeWithVoid> Expression<Type> {
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
Expression::Construct(l, _, _, _) => l,
Expression::FieldRef(l, _, _, _) => l,
Expression::Block(l, _, _) => l,
Expression::Print(l, _) => l,
Expression::Call(l, _, _, _) => l,
@@ -233,6 +247,7 @@ impl Value {
pub enum Type {
Primitive(PrimitiveType),
Function(Vec<Type>, Box<Type>),
Structure(HashMap<ArcIntern<String>, Type>),
}
impl Type {
@@ -256,6 +271,7 @@ impl<'a> TryInto<ConstantType> for &'a Type {
match self {
Type::Primitive(pt) => Ok((*pt).into()),
Type::Function(_, _) => Err(()),
Type::Structure(_) => Err(()),
}
}
}
@@ -265,6 +281,7 @@ pub enum TypeOrVar {
Primitive(PrimitiveType),
Variable(Location, ArcIntern<String>),
Function(Vec<TypeOrVar>, Box<TypeOrVar>),
Structure(HashMap<ArcIntern<String>, TypeOrVar>),
}
impl Default for TypeOrVar {
@@ -311,6 +328,10 @@ impl TypeOrVar {
}
TypeOrVar::Primitive(_) => false,
TypeOrVar::Structure(fields) => {
fields.values_mut().any(|x| x.replace(name, replace_with))
}
}
}
@@ -323,6 +344,7 @@ impl TypeOrVar {
TypeOrVar::Function(args, ret) => {
args.iter().all(TypeOrVar::is_resolved) && ret.is_resolved()
}
TypeOrVar::Structure(fields) => fields.values().all(TypeOrVar::is_resolved),
}
}
}
@@ -339,6 +361,16 @@ impl PartialEq<Type> for TypeOrVar {
TypeOrVar::Primitive(x) => a == x,
_ => false,
},
Type::Structure(fields1) => match self {
TypeOrVar::Structure(fields2) => {
fields1.len() == fields2.len()
&& fields1.iter().all(|(name, subtype)| {
fields2.get(name).map(|x| x == subtype).unwrap_or(false)
})
}
_ => false,
},
}
}
}
@@ -386,6 +418,9 @@ impl<T: Into<Type>> From<T> for TypeOrVar {
args.into_iter().map(Into::into).collect(),
Box::new((*ret).into()),
),
Type::Structure(fields) => {
TypeOrVar::Structure(fields.into_iter().map(|(n, t)| (n, t.into())).collect())
}
}
}
}
@@ -413,7 +448,22 @@ impl TryFrom<TypeOrVar> for Type {
}
TypeOrVar::Primitive(t) => Ok(Type::Primitive(t)),
_ => Err(value),
TypeOrVar::Structure(fields) => {
let mut new_fields = HashMap::with_capacity(fields.len());
for (name, field) in fields.iter() {
if let Ok(new_field) = field.clone().try_into() {
new_fields.insert(name.clone(), new_field);
} else {
return Err(TypeOrVar::Structure(fields));
}
}
Ok(Type::Structure(new_fields))
}
TypeOrVar::Variable(_, _) => Err(value),
}
}
}

View File

@@ -2,6 +2,7 @@ use super::{Primitive, Type, ValueOrRef};
use crate::eval::{EvalError, Value};
use crate::ir::{Expression, Program, TopLevel, Variable};
use crate::util::scoped_map::ScopedMap;
use std::collections::HashMap;
type IRValue<T> = Value<Expression<T>>;
type IREvalError<T> = EvalError<Expression<T>>;
@@ -60,6 +61,7 @@ where
match ty {
Type::Primitive(pt) => Ok(pt.safe_cast(&value)?),
Type::Function(_, _) => Err(EvalError::CastToFunction(ty.to_string())),
Type::Structure(_) => unimplemented!(),
}
}
@@ -79,6 +81,29 @@ where
}
}
Expression::Construct(_, _, name, fields) => {
let mut result_fields = HashMap::with_capacity(fields.len());
for (name, subexpr) in fields.iter() {
result_fields.insert(name.clone(), subexpr.eval(env)?);
}
Ok(Value::Structure(Some(name.clone()), result_fields))
}
Expression::FieldRef(loc, _, valref, field) => match valref.eval(env)? {
Value::Structure(oname, mut fields) => match fields.remove(field) {
None => Err(EvalError::NoFieldForValue(
loc.clone(),
Value::Structure(oname, fields),
field.clone(),
)),
Some(value) => Ok(value),
},
x => Err(EvalError::NoFieldForValue(loc.clone(), x, field.clone())),
},
Expression::Block(_, _, stmts) => {
let mut result = Value::Void;

View File

@@ -56,6 +56,28 @@ impl Expression<Type> {
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Construct(_, _, name, fields) => {
let inner = allocator
.intersperse(
fields.iter().map(|(k, v)| {
allocator
.text(k.to_string())
.append(":")
.append(allocator.space())
.append(v.pretty(allocator))
.append(allocator.text(";"))
}),
allocator.line(),
)
.indent(2)
.braces();
allocator.text(name.to_string()).append(inner)
}
Expression::FieldRef(_, _, val, field) => val.pretty(allocator).append(
allocator
.text(".")
.append(allocator.text(field.to_string())),
),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
@@ -180,6 +202,18 @@ impl Type {
match self {
Type::Function(args, rettype) => pretty_function_type!(allocator, args, rettype),
Type::Primitive(prim) => prim.pretty(allocator),
Type::Structure(fields) => allocator.text("struct").append(
allocator
.concat(fields.iter().map(|(n, t)| {
allocator
.text(n.to_string())
.append(allocator.text(":"))
.append(allocator.space())
.append(t.pretty(allocator))
.append(allocator.text(";"))
}))
.braces(),
),
}
}
}
@@ -190,6 +224,18 @@ impl TypeOrVar {
TypeOrVar::Function(args, rettype) => pretty_function_type!(allocator, args, rettype),
TypeOrVar::Primitive(prim) => prim.pretty(allocator),
TypeOrVar::Variable(_, name) => allocator.text(name.to_string()),
TypeOrVar::Structure(fields) => allocator.text("struct").append(
allocator
.concat(fields.iter().map(|(n, t)| {
allocator
.text(n.to_string())
.append(allocator.text(":"))
.append(allocator.space())
.append(t.pretty(allocator))
.append(allocator.text(";"))
}))
.braces(),
),
}
}
}

View File

@@ -33,7 +33,7 @@ pub enum TopLevel {
Option<Type>,
Expression,
),
Structure(Location, Option<Name>, Vec<(Name, Type)>),
Structure(Location, Name, Vec<(Name, Type)>),
}
/// A Name.
@@ -212,5 +212,5 @@ pub enum Value {
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Type {
Named(Name),
Struct(Option<Name>, Vec<(Option<Name>, Option<Type>)>),
Struct(Vec<(Name, Option<Type>)>),
}

View File

@@ -100,8 +100,8 @@ OptionalComma: () = {
}
Structure: TopLevel = {
<s:@L> "struct" <on: TypeName?> "{" <fields: Field*> "}" <e:@L> => {
TopLevel::Structure(Location::new(file_idx, s..e), on, fields)
<s:@L> "struct" <n: TypeName> "{" <fields: Field*> "}" <e:@L> => {
TopLevel::Structure(Location::new(file_idx, s..e), n, fields)
}
}
@@ -113,14 +113,13 @@ Field: (Name, Type) = {
Type: Type = {
<name:Name> => Type::Named(name),
<t:TypeName> => Type::Named(t),
"struct" <on: TypeName?> "{" <fields: TypeField*> "}" =>
Type::Struct(on, fields),
"struct" "{" <fields: TypeField*> "}" =>
Type::Struct(fields),
}
TypeField: (Option<Name>, Option<Type>) = {
<name: Name> ":" <ty: Type> ";" => (Some(name), Some(ty)),
<name: Name> (":" "_")? ";" => (Some(name), None),
"_" ":" <ty: Type> ";" => (None, Some(ty)),
TypeField: (Name, Option<Type>) = {
<name: Name> ":" <ty: Type> ";" => (name, Some(ty)),
<name: Name> (":" "_")? ";" => (name, None),
}
Name: Name = {

View File

@@ -65,11 +65,7 @@ impl TopLevel {
TopLevel::Structure(_, name, fields) => allocator
.text("struct")
.append(allocator.space())
.append(
name.as_ref()
.map(|x| allocator.text(x.to_string()))
.unwrap_or_else(|| allocator.nil()),
)
.append(allocator.text(name.to_string()))
.append(allocator.space())
.append(allocator.text("{"))
.append(allocator.hardline())
@@ -224,22 +220,13 @@ impl Type {
pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
match self {
Type::Named(x) => allocator.text(x.to_string()),
Type::Struct(name, fields) => allocator
Type::Struct(fields) => allocator
.text("struct")
.append(allocator.space())
.append(
name.as_ref()
.map(|x| allocator.text(x.to_string()))
.unwrap_or_else(|| allocator.nil()),
)
.append(allocator.intersperse(
fields.iter().map(|(name, ty)| {
allocator
.text(
name.as_ref()
.map(|x| x.to_string())
.unwrap_or_else(|| "_".to_string()),
)
.text(name.to_string())
.append(allocator.text(":"))
.append(allocator.space())
.append(

View File

@@ -33,7 +33,7 @@ impl syntax::Program {
/// this method, otherwise you may experience panics during operation.
pub fn type_infer(self) -> TypeInferenceResult<ir::Program<ir::Type>> {
let (program, constraint_db) = convert_program(self);
let inference_result = solve_constraints(constraint_db);
let inference_result = solve_constraints(&program.type_definitions, constraint_db);
inference_result.map(|resolutions| finalize_program(program, &resolutions))
}

View File

@@ -7,6 +7,11 @@ use internment::ArcIntern;
use std::collections::HashMap;
use std::str::FromStr;
enum TopLevelItem {
Type(ArcIntern<String>, ir::TypeOrVar),
Expression(ir::TopLevel<ir::TypeOrVar>),
}
/// This function takes a syntactic program and converts it into the IR version of the
/// program, with appropriate type variables introduced and their constraints added to
/// the given database.
@@ -21,28 +26,37 @@ pub fn convert_program(
let mut items = Vec::new();
let mut renames = ScopedMap::new();
let mut bindings = HashMap::new();
let mut type_definitions = HashMap::new();
for item in program.items.drain(..) {
items.push(convert_top_level(
item,
&mut constraint_db,
&mut renames,
&mut bindings,
));
let tli = convert_top_level(item, &mut constraint_db, &mut renames, &mut bindings);
match tli {
TopLevelItem::Expression(item) => items.push(item),
TopLevelItem::Type(name, decl) => {
let _ = type_definitions.insert(name, decl);
}
}
}
(ir::Program { items }, constraint_db)
(
ir::Program {
items,
type_definitions,
},
constraint_db,
)
}
/// This function takes a top-level item and converts it into the IR version of the
/// program, with all the appropriate type variables introduced and their constraints
/// added to the given database.
pub fn convert_top_level(
fn convert_top_level(
top_level: syntax::TopLevel,
constraint_db: &mut Vec<Constraint>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, ir::TypeOrVar>,
) -> ir::TopLevel<ir::TypeOrVar> {
) -> TopLevelItem {
match top_level {
syntax::TopLevel::Function(name, args, _, expr) => {
// First, at some point we're going to want to know a location for this function,
@@ -75,8 +89,8 @@ pub fn convert_top_level(
// appropriately.
renames.new_scope();
let arginfo = args
.iter()
.map(|(name, _)| {
.into_iter()
.map(|(name, mut declared_type)| {
let new_type = ir::TypeOrVar::new();
constraint_db.push(Constraint::IsSomething(
name.location.clone(),
@@ -84,7 +98,16 @@ pub fn convert_top_level(
));
let new_name = finalize_name(bindings, renames, name.clone());
bindings.insert(new_name.clone(), new_type.clone());
unimplemented!();
if let Some(declared_type) = declared_type.take() {
let declared_type = convert_type(declared_type, constraint_db);
constraint_db.push(Constraint::Equivalent(
name.location.clone(),
new_type.clone(),
declared_type,
));
}
(new_name, new_type)
})
.collect::<Vec<_>>();
@@ -114,16 +137,27 @@ pub fn convert_top_level(
// Remember to exit this scoping level!
renames.release_scope();
ir::TopLevel::Function(function_name, arginfo, rettype, expr)
TopLevelItem::Expression(ir::TopLevel::Function(
function_name,
arginfo,
rettype,
expr,
))
}
syntax::TopLevel::Statement(stmt) => {
ir::TopLevel::Statement(convert_statement(stmt, constraint_db, renames, bindings))
}
syntax::TopLevel::Statement(stmt) => TopLevelItem::Expression(ir::TopLevel::Statement(
convert_statement(stmt, constraint_db, renames, bindings),
)),
syntax::TopLevel::Structure(loc, oname, fields) => {
unimplemented!()
}
syntax::TopLevel::Structure(_loc, name, fields) => TopLevelItem::Type(
name.intern(),
ir::TypeOrVar::Structure(
fields
.into_iter()
.map(|(name, t)| (name.intern(), convert_type(t, constraint_db)))
.collect(),
),
),
}
}
@@ -258,7 +292,31 @@ fn convert_expression(
}
},
syntax::Expression::Constructor(_, _, _) => unimplemented!(),
syntax::Expression::Constructor(loc, name, fields) => {
let mut result_fields = HashMap::new();
let mut type_fields = HashMap::new();
let mut prereqs = vec![];
let result_type = ir::TypeOrVar::new();
for (name, syntax_expr) in fields.into_iter() {
let (ir_expr, expr_type) =
convert_expression(syntax_expr, constraint_db, renames, bindings);
type_fields.insert(name.clone().intern(), expr_type);
let (prereq, value) = simplify_expr(ir_expr);
result_fields.insert(name.clone().intern(), value);
merge_prereq(&mut prereqs, prereq);
}
constraint_db.push(Constraint::NamedTypeIs(
loc.clone(),
name.clone().intern(),
ir::TypeOrVar::Structure(type_fields),
));
let result =
ir::Expression::Construct(loc, result_type.clone(), name.intern(), result_fields);
(finalize_expressions(prereqs, result), result_type)
}
syntax::Expression::Reference(loc, name) => {
let iname = ArcIntern::new(name);
@@ -273,7 +331,26 @@ fn convert_expression(
(refexp, rtype)
}
syntax::Expression::FieldRef(_, _, _) => unimplemented!(),
syntax::Expression::FieldRef(loc, expr, field) => {
let (nexpr, etype) = convert_expression(*expr, constraint_db, renames, bindings);
let (prereqs, val_or_ref) = simplify_expr(nexpr);
let result_type = ir::TypeOrVar::new();
let result = ir::Expression::FieldRef(
loc.clone(),
result_type.clone(),
val_or_ref,
field.clone().intern(),
);
constraint_db.push(Constraint::TypeHasField(
loc,
etype,
field.intern(),
result_type.clone(),
));
(finalize_expression(prereqs, result), result_type)
}
syntax::Expression::Cast(loc, target, expr) => {
let (nexpr, etype) = convert_expression(*expr, constraint_db, renames, bindings);
@@ -366,15 +443,7 @@ fn convert_expression(
let last_call =
ir::Expression::Call(loc.clone(), return_type.clone(), Box::new(fun), new_args);
if prereqs.is_empty() {
(last_call, return_type)
} else {
prereqs.push(last_call);
(
ir::Expression::Block(loc, return_type.clone(), prereqs),
return_type,
)
}
(finalize_expressions(prereqs, last_call), return_type)
}
syntax::Expression::Block(loc, stmts) => {
@@ -396,6 +465,35 @@ fn convert_expression(
}
}
fn convert_type(ty: syntax::Type, constraint_db: &mut Vec<Constraint>) -> ir::TypeOrVar {
match ty {
syntax::Type::Named(x) => match PrimitiveType::from_str(x.name.as_str()) {
Err(_) => {
let retval = ir::TypeOrVar::new_located(x.location.clone());
constraint_db.push(Constraint::NamedTypeIs(
x.location.clone(),
x.intern(),
retval.clone(),
));
retval
}
Ok(v) => ir::TypeOrVar::Primitive(v),
},
syntax::Type::Struct(fields) => ir::TypeOrVar::Structure(
fields
.into_iter()
.map(|(n, t)| {
(
n.intern(),
t.map(|x| convert_type(x, constraint_db))
.unwrap_or_else(ir::TypeOrVar::new),
)
})
.collect(),
),
}
}
fn simplify_expr(
expr: ir::Expression<ir::TypeOrVar>,
) -> (
@@ -431,6 +529,20 @@ fn finalize_expression(
}
}
fn finalize_expressions(
mut prereqs: Vec<ir::Expression<ir::TypeOrVar>>,
actual: ir::Expression<ir::TypeOrVar>,
) -> ir::Expression<ir::TypeOrVar> {
if prereqs.is_empty() {
actual
} else {
let return_type = actual.type_of();
let loc = actual.location().clone();
prereqs.push(actual);
ir::Expression::Block(loc, return_type, prereqs)
}
}
fn finalize_name(
bindings: &HashMap<ArcIntern<String>, ir::TypeOrVar>,
renames: &mut ScopedMap<ArcIntern<String>, ArcIntern<String>>,

View File

@@ -3,7 +3,7 @@ use crate::eval::PrimitiveType;
use crate::ir::{Expression, Program, TopLevel, Type, TypeOrVar, Value, ValueOrRef};
pub fn finalize_program(
mut program: Program<TypeOrVar>,
program: Program<TypeOrVar>,
resolutions: &TypeResolutions,
) -> Program<Type> {
for (name, ty) in resolutions.iter() {
@@ -13,9 +13,15 @@ pub fn finalize_program(
Program {
items: program
.items
.drain(..)
.into_iter()
.map(|x| finalize_top_level(x, resolutions))
.collect(),
type_definitions: program
.type_definitions
.into_iter()
.map(|(n, t)| (n, finalize_type(t, resolutions)))
.collect(),
}
}
@@ -57,6 +63,23 @@ fn finalize_expression(
.collect(),
),
Expression::Construct(loc, ty, name, fields) => Expression::Construct(
loc,
finalize_type(ty, resolutions),
name,
fields
.into_iter()
.map(|(k, v)| (k, finalize_val_or_ref(v, resolutions)))
.collect(),
),
Expression::FieldRef(loc, ty, valref, field) => Expression::FieldRef(
loc,
finalize_type(ty, resolutions),
finalize_val_or_ref(valref, resolutions),
field,
),
Expression::Block(loc, ty, exprs) => {
let mut final_exprs = Vec::with_capacity(exprs.len());
@@ -111,6 +134,12 @@ fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type {
.collect(),
Box::new(finalize_type(*ret, resolutions)),
),
TypeOrVar::Structure(fields) => Type::Structure(
fields
.into_iter()
.map(|(name, subtype)| (name, finalize_type(subtype, resolutions)))
.collect(),
),
}
}
@@ -129,6 +158,9 @@ fn finalize_val_or_ref(
Type::Function(_, _) => {
panic!("Somehow inferred that a constant was a function")
}
Type::Structure(_) => {
panic!("Somehow inferred that a constant was a structure")
}
Type::Primitive(PrimitiveType::Void) => {
panic!("Somehow inferred that a constant was void")
}

View File

@@ -16,6 +16,9 @@ pub enum Constraint {
ProperPrimitiveArgs(Location, Primitive, Vec<TypeOrVar>, TypeOrVar),
/// The given type can be casted to the target type safely
CanCastTo(Location, TypeOrVar, TypeOrVar),
/// The given type has the given field in it, and the type of that field
/// is as given.
TypeHasField(Location, TypeOrVar, ArcIntern<String>, TypeOrVar),
/// The given type must be some numeric type, but this is not a constant
/// value, so don't try to default it if we can't figure it out
NumericType(Location, TypeOrVar),
@@ -29,6 +32,8 @@ pub enum Constraint {
IsSomething(Location, TypeOrVar),
/// The given type can be negated
IsSigned(Location, TypeOrVar),
/// Checks to see if the given named type is equivalent to the provided one.
NamedTypeIs(Location, ArcIntern<String>, TypeOrVar),
}
impl fmt::Display for Constraint {
@@ -44,11 +49,15 @@ impl fmt::Display for Constraint {
}
Constraint::ProperPrimitiveArgs(_, op, _, ret) => write!(f, "PRIM {} -> {}", op, ret),
Constraint::CanCastTo(_, ty, ty2) => write!(f, "CAST {} -> {}", ty, ty2),
Constraint::TypeHasField(_, ty1, field, ty2) => {
write!(f, "FIELD {}.{} -> {}", ty1, field, ty2)
}
Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty),
Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty),
Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2),
Constraint::IsSomething(_, ty) => write!(f, "SOMETHING {}", ty),
Constraint::IsSigned(_, ty) => write!(f, "SIGNED {}", ty),
Constraint::NamedTypeIs(_, name, ty) => write!(f, "TYPE_EQUIV {} == {}", name, ty),
}
}
}
@@ -65,6 +74,9 @@ impl Constraint {
Constraint::CanCastTo(_, ty1, ty2) => {
ty1.replace(name, replace_with) || ty2.replace(name, replace_with)
}
Constraint::TypeHasField(_, ty1, _, ty2) => {
ty1.replace(name, replace_with) || ty2.replace(name, replace_with)
}
Constraint::ConstantNumericType(_, ty) => ty.replace(name, replace_with),
Constraint::Equivalent(_, ty1, ty2) => {
ty1.replace(name, replace_with) || ty2.replace(name, replace_with)
@@ -76,6 +88,7 @@ impl Constraint {
ret.replace(name, replace_with)
| args.iter_mut().any(|x| x.replace(name, replace_with))
}
Constraint::NamedTypeIs(_, name, ty) => ty.replace(name, replace_with),
}
}
}
@@ -142,21 +155,22 @@ pub enum TypeInferenceError {
CannotSafelyCast(Location, PrimitiveType, PrimitiveType),
/// The primitive invocation provided the wrong number of arguments.
WrongPrimitiveArity(Location, Primitive, usize, usize, usize),
/// We cannot cast between the given function types, usually because they
/// have different argument lengths
CannotCastBetweenFunctinoTypes(Location, TypeOrVar, TypeOrVar),
/// We cannot cast from a function type to something else.
CannotCastFromFunctionType(Location, TypeOrVar),
/// We cannot cast to a function type from something else.
CannotCastToFunctionType(Location, TypeOrVar),
/// We cannot cast between the type types, for any number of reasons
CannotCast(Location, TypeOrVar, TypeOrVar),
/// We cannot turn a number into a function.
CannotMakeNumberAFunction(Location, TypeOrVar, Option<u64>),
/// We cannot turn a number into a Structure.
CannotMakeNumberAStructure(Location, TypeOrVar, Option<u64>),
/// We had a constraint we just couldn't solve.
CouldNotSolve(Constraint),
/// Functions are not printable.
FunctionsAreNotPrintable(Location),
/// The given type isn't signed, and can't be negated
IsNotSigned(Location, PrimitiveType),
IsNotSigned(Location, TypeOrVar),
/// The given type doesn't have the given field.
NoFieldForType(Location, ArcIntern<String>, TypeOrVar),
/// There is no type with the given name.
UnknownTypeName(Location, ArcIntern<String>),
}
impl From<TypeInferenceError> for Diagnostic<usize> {
@@ -196,22 +210,12 @@ impl From<TypeInferenceError> for Diagnostic<usize> {
prim,
observed
)),
TypeInferenceError::CannotCastBetweenFunctinoTypes(loc, t1, t2) => loc
.labelled_error("cannot cast between function types")
TypeInferenceError::CannotCast(loc, t1, t2) => loc
.labelled_error("cannot cast between types")
.with_message(format!(
"tried to cast from {} to {}",
t1, t2,
)),
TypeInferenceError::CannotCastFromFunctionType(loc, t) => loc
.labelled_error("cannot cast from a function type to anything else")
.with_message(format!(
"function type was {}", t,
)),
TypeInferenceError::CannotCastToFunctionType(loc, t) => loc
.labelled_error("cannot cast to a function type")
.with_message(format!(
"function type was {}", t,
)),
TypeInferenceError::CannotMakeNumberAFunction(loc, t, val) => loc
.labelled_error(if let Some(val) = val {
format!("cannot turn {} into a function", val)
@@ -219,17 +223,32 @@ impl From<TypeInferenceError> for Diagnostic<usize> {
"cannot use a constant as a function type".to_string()
})
.with_message(format!("function type was {}", t)),
TypeInferenceError::CannotMakeNumberAStructure(loc, t, val) => loc
.labelled_error(if let Some(val) = val {
format!("cannot turn {} into a function", val)
} else {
"cannot use a constant as a function type".to_string()
})
.with_message(format!("function type was {}", t)),
TypeInferenceError::FunctionsAreNotPrintable(loc) => loc
.labelled_error("cannot print function values"),
TypeInferenceError::IsNotSigned(loc, pt) => loc
.labelled_error(format!("type {} is not signed", pt))
.with_message("and so it cannot be negated"),
TypeInferenceError::NoFieldForType(loc, field, t) => loc
.labelled_error(format!("no field {} available for type {}", field, t)),
TypeInferenceError::UnknownTypeName(loc , name) => loc
.labelled_error(format!("unknown type named {}", name)),
TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if it was safe to cast from {} to {:#?}",
"could not determine if it was safe to cast from {} to {}",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::TypeHasField(loc, a, field, _)) => {
loc.labelled_error("internal error")
.with_message(format!("fould not determine if type {} has field {}", a, field))
}
TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if {} and {} were equivalent",
@@ -263,6 +282,9 @@ impl From<TypeInferenceError> for Diagnostic<usize> {
TypeInferenceError::CouldNotSolve(Constraint::IsSigned(loc, t)) => loc
.labelled_error("internal error")
.with_message(format!("could not infer that type {} was signed", t)),
TypeInferenceError::CouldNotSolve(Constraint::NamedTypeIs(loc, name, ty)) => loc
.labelled_error("internal error")
.with_message(format!("could not infer that the name {} refers to {}", name, ty)),
}
}
}
@@ -296,6 +318,7 @@ impl From<TypeInferenceWarning> for Diagnostic<usize> {
/// successful set of type resolutions (mappings from type variables to their values), or
/// a series of inference errors.
pub fn solve_constraints(
known_types: &HashMap<ArcIntern<String>, TypeOrVar>,
mut constraint_db: Vec<Constraint>,
) -> TypeInferenceResult<TypeResolutions> {
let mut errors = vec![];
@@ -346,7 +369,7 @@ pub fn solve_constraints(
}
all_constraints_solved = false;
} else {
errors.push(TypeInferenceError::CannotCastBetweenFunctinoTypes(
errors.push(TypeInferenceError::CannotCast(
loc,
TypeOrVar::Function(args1, ret1),
TypeOrVar::Function(args2, ret2),
@@ -360,21 +383,100 @@ pub fn solve_constraints(
Constraint::CanCastTo(
loc,
ft @ TypeOrVar::Function(_, _),
pt @ TypeOrVar::Primitive(_),
st1 @ TypeOrVar::Structure(_),
st2 @ TypeOrVar::Structure(_),
) => {
tracing::trace!(function_type = %ft, primitive_type = %pt, "we can't cast a function type to a primitive type");
errors.push(TypeInferenceError::CannotCastFromFunctionType(loc, pt));
tracing::trace!(
"structures can be equivalent, if their fields and types are exactly the same"
);
new_constraints.push(Constraint::Equivalent(loc, st1, st2));
changed_something = true;
}
Constraint::CanCastTo(
loc,
ft @ TypeOrVar::Function(_, _),
ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Structure(_),
) => {
tracing::trace!(function_type = %ft, other_type = %ot, "we can't cast a function type to a primitive or structure type");
errors.push(TypeInferenceError::CannotCast(loc, ft, ot));
changed_something = true;
}
Constraint::CanCastTo(
loc,
pt @ TypeOrVar::Primitive(_),
ft @ TypeOrVar::Function(_, _),
ot @ TypeOrVar::Function(_, _) | ot @ TypeOrVar::Structure(_),
) => {
tracing::trace!(function_type = %ft, primitive_type = %pt, "we can't cast a primitive type to a function type");
errors.push(TypeInferenceError::CannotCastToFunctionType(loc, pt));
tracing::trace!(other_type = %ot, primitive_type = %pt, "we can't cast a primitive type to a function or structure type");
errors.push(TypeInferenceError::CannotCast(loc, pt, ot));
changed_something = true;
}
Constraint::CanCastTo(
loc,
st @ TypeOrVar::Structure(_),
ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _),
) => {
tracing::trace!(structure_type = %st, other_type = %ot, "we can't cast a structure type to a function or primitive type");
errors.push(TypeInferenceError::CannotCast(loc, st, ot));
changed_something = true;
}
Constraint::NamedTypeIs(loc, name, ty) => match known_types.get(&name) {
None => {
tracing::trace!(type_name = %name, "we don't know a type named name");
errors.push(TypeInferenceError::UnknownTypeName(loc, name));
changed_something = true;
}
Some(declared_type) => {
tracing::trace!(type_name = %name, declared = %declared_type, provided = %ty, "validating that named type is equivalent to provided");
new_constraints.push(Constraint::Equivalent(
loc,
declared_type.clone(),
ty,
));
changed_something = true;
}
},
Constraint::TypeHasField(
loc,
TypeOrVar::Structure(mut fields),
field,
result_type,
) => match fields.remove(&field) {
None => {
let reconstituted = TypeOrVar::Structure(fields);
tracing::trace!(structure_type = %reconstituted, %field, "no field found in type");
errors.push(TypeInferenceError::NoFieldForType(
loc,
field,
reconstituted,
));
changed_something = true;
}
Some(field_subtype) => {
tracing::trace!(%field_subtype, %result_type, %field, "validating that field's subtype matches target result type");
new_constraints.push(Constraint::Equivalent(
loc,
result_type,
field_subtype,
));
changed_something = true;
}
},
Constraint::TypeHasField(
loc,
ot @ TypeOrVar::Primitive(_) | ot @ TypeOrVar::Function(_, _),
field,
_,
) => {
tracing::trace!(other_type = %ot, %field, "can't get field from primitive or function type");
errors.push(TypeInferenceError::NoFieldForType(loc, field, ot));
changed_something = true;
}
@@ -394,6 +496,13 @@ pub fn solve_constraints(
changed_something = true;
}
// if we're testing if a function type is numeric, then throw a useful warning
Constraint::ConstantNumericType(loc, t @ TypeOrVar::Structure(_)) => {
tracing::trace!(structure_type = %t, "structures can't be constant numbers");
errors.push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None));
changed_something = true;
}
// if we're testing if a number can fit into a numeric type, we can just do that!
Constraint::FitsInNumType(loc, TypeOrVar::Primitive(ctype), val) => {
match ctype.max_value() {
@@ -420,9 +529,21 @@ pub fn solve_constraints(
changed_something = true;
}
// if we're testing if a function type can fit into a numeric type, that's a problem
Constraint::FitsInNumType(loc, t @ TypeOrVar::Structure(_), val) => {
tracing::trace!(function_type = %t, "values don't fit in structure types");
errors.push(TypeInferenceError::CannotMakeNumberAStructure(
loc,
t,
Some(val),
));
changed_something = true;
}
// if we want to know if a type is something, and it is something, then we're done
Constraint::IsSomething(_, t @ TypeOrVar::Function(_, _))
| Constraint::IsSomething(_, t @ TypeOrVar::Primitive(_)) => {
| Constraint::IsSomething(_, t @ TypeOrVar::Primitive(_))
| Constraint::IsSomething(_, t @ TypeOrVar::Structure(_)) => {
tracing::trace!(tested_type = %t, "type is definitely something");
changed_something = true;
}
@@ -431,7 +552,10 @@ pub fn solve_constraints(
Constraint::IsSigned(loc, TypeOrVar::Primitive(pt)) => {
tracing::trace!(primitive_type = %pt, "we can check if a primitive is signed");
if !pt.valid_operators().contains(&("-", 1)) {
errors.push(TypeInferenceError::IsNotSigned(loc, pt));
errors.push(TypeInferenceError::IsNotSigned(
loc,
TypeOrVar::Primitive(pt),
));
}
changed_something = true;
}
@@ -439,7 +563,14 @@ pub fn solve_constraints(
// again with the functions and the numbers
Constraint::IsSigned(loc, t @ TypeOrVar::Function(_, _)) => {
tracing::trace!(function_type = %t, "functions are not signed");
errors.push(TypeInferenceError::CannotCastFromFunctionType(loc, t));
errors.push(TypeInferenceError::IsNotSigned(loc, t));
changed_something = true;
}
// again with the functions and the numbers
Constraint::IsSigned(loc, t @ TypeOrVar::Structure(_)) => {
tracing::trace!(structure_type = %t, "structures are not signed");
errors.push(TypeInferenceError::IsNotSigned(loc, t));
changed_something = true;
}
@@ -459,6 +590,13 @@ pub fn solve_constraints(
changed_something = true;
}
// if we're testing if a structure type is numeric, then throw a useful warning
Constraint::NumericType(loc, t @ TypeOrVar::Structure(_)) => {
tracing::trace!(structure_type = %t, "structure types aren't numeric");
errors.push(TypeInferenceError::CannotMakeNumberAStructure(loc, t, None));
changed_something = true;
}
// all of our primitive types are printable
Constraint::Printable(_, TypeOrVar::Primitive(pt)) => {
tracing::trace!(primitive_type = %pt, "primitive types are printable");
@@ -472,6 +610,17 @@ pub fn solve_constraints(
changed_something = true;
}
// structure types are printable if all the types inside them are printable
Constraint::Printable(loc, TypeOrVar::Structure(fields)) => {
tracing::trace!(
"structure types are printable if all their subtypes are printable"
);
for (_, subtype) in fields.into_iter() {
new_constraints.push(Constraint::Printable(loc.clone(), subtype));
}
changed_something = true;
}
Constraint::ProperPrimitiveArgs(loc, prim, mut args, ret) => match prim {
Primitive::Plus | Primitive::Minus | Primitive::Times | Primitive::Divide
if args.len() == 2 =>
@@ -558,6 +707,36 @@ pub fn solve_constraints(
changed_something = true;
}
Constraint::Equivalent(
loc,
pt @ TypeOrVar::Primitive(_),
st @ TypeOrVar::Structure(_),
)
| Constraint::Equivalent(
loc,
st @ TypeOrVar::Structure(_),
pt @ TypeOrVar::Primitive(_),
) => {
tracing::trace!(primitive_type = %pt, structure_type = %st, "structure and primitive types cannot be equivalent");
errors.push(TypeInferenceError::NotEquivalent(loc, pt, st));
changed_something = true;
}
Constraint::Equivalent(
loc,
st @ TypeOrVar::Structure(_),
ft @ TypeOrVar::Function(_, _),
)
| Constraint::Equivalent(
loc,
ft @ TypeOrVar::Function(_, _),
st @ TypeOrVar::Structure(_),
) => {
tracing::trace!(structure_type = %st, function_type = %ft, "structure and primitive types cannot be equivalent");
errors.push(TypeInferenceError::NotEquivalent(loc, st, ft));
changed_something = true;
}
Constraint::Equivalent(
_,
TypeOrVar::Variable(_, name1),
@@ -588,6 +767,35 @@ pub fn solve_constraints(
tracing::trace!("we checked/rewrote if function types are equivalent");
}
Constraint::Equivalent(
loc,
TypeOrVar::Structure(fields1),
TypeOrVar::Structure(mut fields2),
) => {
if fields1.len() == fields2.len()
&& fields1.keys().all(|x| fields2.contains_key(x))
{
for (name, subtype1) in fields1.into_iter() {
let subtype2 = fields2
.remove(&name)
.expect("can find matching field after equivalence check");
new_constraints.push(Constraint::Equivalent(
loc.clone(),
subtype1,
subtype2,
));
}
} else {
errors.push(TypeInferenceError::NotEquivalent(
loc,
TypeOrVar::Structure(fields1),
TypeOrVar::Structure(fields2),
))
}
changed_something = true;
tracing::trace!("we checked/rewrote if structures are equivalent");
}
Constraint::Equivalent(_, TypeOrVar::Variable(_, ref name), ref rhs) => {
changed_something |= replace_variable(&mut constraint_db, name, rhs);
changed_something |= replace_variable(&mut new_constraints, name, rhs);
@@ -607,6 +815,7 @@ pub fn solve_constraints(
Constraint::CanCastTo(_, TypeOrVar::Variable(_, _), _)
| Constraint::CanCastTo(_, _, TypeOrVar::Variable(_, _))
| Constraint::TypeHasField(_, TypeOrVar::Variable(_, _), _, _)
| Constraint::ConstantNumericType(_, TypeOrVar::Variable(_, _))
| Constraint::FitsInNumType(_, TypeOrVar::Variable(_, _), _)
| Constraint::IsSomething(_, TypeOrVar::Variable(_, _))