🤔 Add a type inference engine, along with typed literals. (#4)

The typed literal formatting mirrors that of Rust. If no type can be
inferred for an untagged literal, the type inference engine will warn
the user and then assume that they meant an unsigned 64-bit number.
(This is slightly inconvenient, because there can be cases in which our
Arbitrary instance may generate a unary negation, in which we should
assume that it's a signed 64-bit number; we may want to revisit this
later.)

The type inference engine is a standard two phase one, in which we first
generate a series of type constraints, and then we solve those
constraints. In this particular implementation, we actually use a third
phase to generate a final AST.

Finally, to increase the amount of testing performed, I've removed the
overflow checking in the evaluator. The only thing we now check for is
division by zero. This does make things a trace slower in testing, but
hopefully we get more coverage this way.
This commit was merged in pull request #4.
This commit is contained in:
2023-09-19 20:40:05 -07:00
committed by GitHub
parent 1fbfd0c2d2
commit bd3b9af469
44 changed files with 3258 additions and 702 deletions

View File

@@ -1,10 +1,14 @@
use crate::syntax::Location;
use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use pretty::{BoxAllocator, DocAllocator, Pretty};
use proptest::{
prelude::Arbitrary,
strategy::{BoxedStrategy, Strategy},
};
use std::{fmt, str::FromStr};
/// We're going to represent variables as interned strings.
///
@@ -52,12 +56,15 @@ where
}
impl Arbitrary for Program {
type Parameters = ();
type Parameters = crate::syntax::arbitrary::GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
crate::syntax::Program::arbitrary_with(args)
.prop_map(Program::from)
.prop_map(|x| {
x.type_infer()
.expect("arbitrary_with should generate type-correct programs")
})
.boxed()
}
}
@@ -74,8 +81,8 @@ impl Arbitrary for Program {
///
#[derive(Debug)]
pub enum Statement {
Binding(Location, Variable, Expression),
Print(Location, Variable),
Binding(Location, Variable, Type, Expression),
Print(Location, Type, Variable),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
@@ -85,13 +92,13 @@ where
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Statement::Binding(_, var, expr) => allocator
Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string())
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
Statement::Print(_, var) => allocator
Statement::Print(_, _, var) => allocator
.text("print")
.append(allocator.space())
.append(allocator.text(var.as_ref().to_string())),
@@ -113,9 +120,32 @@ where
/// variable reference.
#[derive(Debug)]
pub enum Expression {
Value(Location, Value),
Reference(Location, Variable),
Primitive(Location, Primitive, Vec<ValueOrRef>),
Atomic(ValueOrRef),
Cast(Location, Type, ValueOrRef),
Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
/// Return a reference to the type of the expression, as inferred or recently
/// computed.
pub fn type_of(&self) -> &Type {
match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t,
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t,
Expression::Cast(_, t, _) => t,
Expression::Primitive(_, t, _, _) => t,
}
}
/// Return a reference to the location associated with the expression.
pub fn location(&self) -> &Location {
match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l,
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
@@ -125,12 +155,16 @@ where
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Expression::Value(_, val) => val.pretty(allocator),
Expression::Reference(_, var) => allocator.text(var.as_ref().to_string()),
Expression::Primitive(_, op, exprs) if exprs.len() == 1 => {
Expression::Atomic(x) => x.pretty(allocator),
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => {
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
@@ -140,7 +174,7 @@ where
.append(right)
.parens()
}
Expression::Primitive(_, op, exprs) => {
Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
}
}
@@ -161,10 +195,10 @@ pub enum Primitive {
Divide,
}
impl<'a> TryFrom<&'a str> for Primitive {
type Error = String;
impl FromStr for Primitive {
type Err = String;
fn try_from(value: &str) -> Result<Self, Self::Error> {
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value {
"+" => Ok(Primitive::Plus),
"-" => Ok(Primitive::Minus),
@@ -190,15 +224,21 @@ where
}
}
impl fmt::Display for Primitive {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<&Primitive as Pretty<'_, BoxAllocator, ()>>::pretty(self, &BoxAllocator).render_fmt(72, f)
}
}
/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub enum ValueOrRef {
Value(Location, Value),
Ref(Location, ArcIntern<String>),
Value(Location, Type, Value),
Ref(Location, Type, ArcIntern<String>),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
@@ -208,30 +248,50 @@ where
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
ValueOrRef::Value(_, v) => v.pretty(allocator),
ValueOrRef::Ref(_, v) => allocator.text(v.as_ref().to_string()),
ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()),
}
}
}
impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self {
match value {
ValueOrRef::Value(loc, val) => Expression::Value(loc, val),
ValueOrRef::Ref(loc, var) => Expression::Reference(loc, var),
}
Expression::Atomic(value)
}
}
/// A constant in the IR.
#[derive(Debug)]
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug)]
pub enum Value {
/// A numerical constant.
///
/// The optional argument is the base that was used by the user to input
/// the number. By retaining it, we can ensure that if we need to print the
/// number back out, we can do so in the form that the user entered it.
Number(Option<u8>, i64),
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl Value {
/// Return the type described by this value
pub fn type_of(&self) -> Type {
match self {
Value::I8(_, _) => Type::Primitive(PrimitiveType::I8),
Value::I16(_, _) => Type::Primitive(PrimitiveType::I16),
Value::I32(_, _) => Type::Primitive(PrimitiveType::I32),
Value::I64(_, _) => Type::Primitive(PrimitiveType::I64),
Value::U8(_, _) => Type::Primitive(PrimitiveType::U8),
Value::U16(_, _) => Type::Primitive(PrimitiveType::U16),
Value::U32(_, _) => Type::Primitive(PrimitiveType::U32),
Value::U64(_, _) => Type::Primitive(PrimitiveType::U64),
}
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
@@ -240,19 +300,64 @@ where
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Value::Number(opt_base, value) => {
let value_str = match opt_base {
None => format!("{}", value),
Some(2) => format!("0b{:b}", value),
Some(8) => format!("0o{:o}", value),
Some(10) => format!("0d{}", value),
Some(16) => format!("0x{:x}", value),
Some(_) => format!("!!{:x}!!", value),
};
let pretty_internal = |opt_base: &Option<u8>, x, t| {
syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
};
allocator.text(value_str)
let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
Primitive(PrimitiveType),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Type::Primitive(pt) => allocator.text(format!("{}", pt)),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Primitive(pt) => pt.fmt(f),
}
}
}

View File

@@ -1,7 +1,7 @@
use crate::eval::{EvalEnvironment, EvalError, Value};
use crate::ir::{Expression, Program, Statement};
use super::{Primitive, ValueOrRef};
use super::{Primitive, Type, ValueOrRef};
impl Program {
/// Evaluate the program, returning either an error or a string containing everything
@@ -14,12 +14,12 @@ impl Program {
for stmt in self.statements.iter() {
match stmt {
Statement::Binding(_, name, value) => {
Statement::Binding(_, name, _, value) => {
let actual_value = value.eval(&env)?;
env = env.extend(name.clone(), actual_value);
}
Statement::Print(_, name) => {
Statement::Print(_, _, name) => {
let value = env.lookup(name.clone())?;
let line = format!("{} = {}\n", name, value);
stdout.push_str(&line);
@@ -34,26 +34,21 @@ impl Program {
impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
Expression::Value(_, v) => match v {
super::Value::Number(_, v) => Ok(Value::I64(*v)),
},
Expression::Atomic(x) => x.eval(env),
Expression::Reference(_, n) => Ok(env.lookup(n.clone())?),
Expression::Cast(_, t, valref) => {
let value = valref.eval(env)?;
Expression::Primitive(_, op, args) => {
let mut arg_values = Vec::with_capacity(args.len());
// we implement primitive operations by first evaluating each of the
// arguments to the function, and then gathering up all the values
// produced.
for arg in args.iter() {
match arg {
ValueOrRef::Ref(_, n) => arg_values.push(env.lookup(n.clone())?),
ValueOrRef::Value(_, super::Value::Number(_, v)) => {
arg_values.push(Value::I64(*v))
}
}
match t {
Type::Primitive(pt) => Ok(pt.safe_cast(&value)?),
}
}
Expression::Primitive(_, _, op, args) => {
let arg_values = args
.iter()
.map(|x| x.eval(env))
.collect::<Result<Vec<Value>, EvalError>>()?;
// and then finally we call `calculate` to run them. trust me, it's nice
// to not have to deal with all the nonsense hidden under `calculate`.
@@ -68,19 +63,38 @@ impl Expression {
}
}
impl ValueOrRef {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
ValueOrRef::Value(_, _, v) => match v {
super::Value::I8(_, v) => Ok(Value::I8(*v)),
super::Value::I16(_, v) => Ok(Value::I16(*v)),
super::Value::I32(_, v) => Ok(Value::I32(*v)),
super::Value::I64(_, v) => Ok(Value::I64(*v)),
super::Value::U8(_, v) => Ok(Value::U8(*v)),
super::Value::U16(_, v) => Ok(Value::U16(*v)),
super::Value::U32(_, v) => Ok(Value::U32(*v)),
super::Value::U64(_, v) => Ok(Value::U64(*v)),
},
ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?),
}
}
}
#[test]
fn two_plus_three() {
let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works");
let ir = Program::from(input);
let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully");
assert_eq!("x = 5i64\n", &output);
assert_eq!("x = 5u64\n", &output);
}
#[test]
fn lotsa_math() {
let input =
crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let ir = Program::from(input);
let ir = input.type_infer().expect("test should be type-valid");
let output = ir.eval().expect("runs successfully");
assert_eq!("x = 7i64\n", &output);
assert_eq!("x = 7u64\n", &output);
}

View File

@@ -1,186 +0,0 @@
use internment::ArcIntern;
use std::sync::atomic::AtomicUsize;
use crate::ir::ast as ir;
use crate::syntax;
use super::ValueOrRef;
impl From<syntax::Program> for ir::Program {
/// We implement the top-level conversion of a syntax::Program into an
/// ir::Program using just the standard `From::from`, because we don't
/// need to return any arguments and we shouldn't produce any errors.
/// Technically there's an `unwrap` deep under the hood that we could
/// float out, but the validator really should've made sure that never
/// happens, so we're just going to assume.
fn from(mut value: syntax::Program) -> Self {
let mut statements = Vec::new();
for stmt in value.statements.drain(..) {
statements.append(&mut stmt.simplify());
}
ir::Program { statements }
}
}
impl From<syntax::Statement> for ir::Program {
/// One interesting thing about this conversion is that there isn't
/// a natural translation from syntax::Statement to ir::Statement,
/// because the syntax version can have nested expressions and the
/// IR version can't.
///
/// As a result, we can naturally convert a syntax::Statement into
/// an ir::Program, because we can allow the additional binding
/// sites to be generated, instead. And, bonus, it turns out that
/// this is what we wanted anyways.
fn from(value: syntax::Statement) -> Self {
ir::Program {
statements: value.simplify(),
}
}
}
impl syntax::Statement {
/// Simplify a syntax::Statement into a series of ir::Statements.
///
/// The reason this function is one-to-many is because we may have to
/// introduce new binding sites in order to avoid having nested
/// expressions. Nested expressions, like `(1 + 2) * 3`, are allowed
/// in syntax::Expression but are expressly *not* allowed in
/// ir::Expression. So this pass converts them into bindings, like
/// this:
///
/// x = (1 + 2) * 3;
///
/// ==>
///
/// x:1 = 1 + 2;
/// x:2 = x:1 * 3;
/// x = x:2
///
/// Thus ensuring that things are nice and simple. Note that the
/// binding of `x:2` is not, strictly speaking, necessary, but it
/// makes the code below much easier to read.
fn simplify(self) -> Vec<ir::Statement> {
let mut new_statements = vec![];
match self {
// Print statements we don't have to do much with
syntax::Statement::Print(loc, name) => {
new_statements.push(ir::Statement::Print(loc, ArcIntern::new(name)))
}
// Bindings, however, may involve a single expression turning into
// a series of statements and then an expression.
syntax::Statement::Binding(loc, name, value) => {
let (mut prereqs, new_value) = value.rebind(&name);
new_statements.append(&mut prereqs);
new_statements.push(ir::Statement::Binding(
loc,
ArcIntern::new(name),
new_value.into(),
))
}
}
new_statements
}
}
impl syntax::Expression {
/// This actually does the meat of the simplification work, here, by rebinding
/// any nested expressions into their own variables. We have this return
/// `ValueOrRef` in all cases because it makes for slighly less code; in the
/// case when we actually want an `Expression`, we can just use `into()`.
fn rebind(self, base_name: &str) -> (Vec<ir::Statement>, ir::ValueOrRef) {
match self {
// Values just convert in the obvious way, and require no prereqs
syntax::Expression::Value(loc, val) => (vec![], ValueOrRef::Value(loc, val.into())),
// Similarly, references just convert in the obvious way, and require
// no prereqs
syntax::Expression::Reference(loc, name) => {
(vec![], ValueOrRef::Ref(loc, ArcIntern::new(name)))
}
// Primitive expressions are where we do the real work.
syntax::Expression::Primitive(loc, prim, mut expressions) => {
// generate a fresh new name for the binding site we're going to
// introduce, basing the name on wherever we came from; so if this
// expression was bound to `x` originally, it might become `x:23`.
//
// gensym is guaranteed to give us a name that is unused anywhere
// else in the program.
let new_name = gensym(base_name);
let mut prereqs = Vec::new();
let mut new_exprs = Vec::new();
// here we loop through every argument, and recurse on the expressions
// we find. that will give us any new binding sites that *they* introduce,
// and a simple value or reference that we can use in our result.
for expr in expressions.drain(..) {
let (mut cur_prereqs, arg) = expr.rebind(new_name.as_str());
prereqs.append(&mut cur_prereqs);
new_exprs.push(arg);
}
// now we're going to use those new arguments to run the primitive, binding
// the results to the new variable we introduced.
let prim =
ir::Primitive::try_from(prim.as_str()).expect("is valid primitive function");
prereqs.push(ir::Statement::Binding(
loc.clone(),
new_name.clone(),
ir::Expression::Primitive(loc.clone(), prim, new_exprs),
));
// and finally, we can return all the new bindings, and a reference to
// the variable we just introduced to hold the value of the primitive
// invocation.
(prereqs, ValueOrRef::Ref(loc, new_name))
}
}
}
}
impl From<syntax::Value> for ir::Value {
fn from(value: syntax::Value) -> Self {
match value {
syntax::Value::Number(base, val) => ir::Value::Number(base, val),
}
}
}
impl From<String> for ir::Primitive {
fn from(value: String) -> Self {
value.try_into().unwrap()
}
}
/// Generate a fresh new name based on the given name.
///
/// The new name is guaranteed to be unique across the entirety of the
/// execution. This is achieved by using characters in the variable name
/// that would not be valid input, and by including a counter that is
/// incremented on every invocation.
fn gensym(name: &str) -> ArcIntern<String> {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let new_name = format!(
"<{}:{}>",
name,
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
);
ArcIntern::new(new_name)
}
proptest::proptest! {
#[test]
fn translation_maintains_semantics(input: syntax::Program) {
let syntax_result = input.eval();
let ir = ir::Program::from(input);
let ir_result = ir.eval();
assert_eq!(syntax_result, ir_result);
}
}

View File

@@ -21,12 +21,12 @@ impl Program {
impl Statement {
fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) {
match self {
Statement::Binding(_, name, expr) => {
Statement::Binding(_, name, _, expr) => {
string_set.insert(name.clone());
expr.register_strings(string_set);
}
Statement::Print(_, name) => {
Statement::Print(_, _, name) => {
string_set.insert(name.clone());
}
}