From 7ebb31b42f12102250d61ea516d13dce58cd2dee Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Thu, 28 Dec 2023 20:57:03 -0800 Subject: [PATCH] checkpoint --- src/backend/eval.rs | 2 +- src/backend/into_crane.rs | 5 ++- src/eval/primtype.rs | 3 ++ src/ir/ast.rs | 6 +++- src/syntax/ast.rs | 5 +++ src/syntax/eval.rs | 67 +++++++++++++++++++++++++------------- src/syntax/parser.lalrpop | 6 ++-- src/syntax/pretty.rs | 19 +++++++++++ src/syntax/validate.rs | 18 +++++++++- src/type_infer/convert.rs | 24 ++++++++++++-- src/type_infer/finalize.rs | 20 ++++++++++++ src/type_infer/solve.rs | 22 +++++++++++++ 12 files changed, 166 insertions(+), 31 deletions(-) diff --git a/src/backend/eval.rs b/src/backend/eval.rs index e3a77eb..f169e6f 100644 --- a/src/backend/eval.rs +++ b/src/backend/eval.rs @@ -159,7 +159,7 @@ impl Backend { .arg(executable_path) .output()?; - if !output.stderr.is_empty() { + if !output.status.success() { return Err(EvalError::Linker( std::string::String::from_utf8_lossy(&output.stderr).to_string(), )); diff --git a/src/backend/into_crane.rs b/src/backend/into_crane.rs index 7fb11fe..8db56bc 100644 --- a/src/backend/into_crane.rs +++ b/src/backend/into_crane.rs @@ -413,7 +413,10 @@ impl Backend { // negative number for us. Which sets the high bits, which makes Cranelift unhappy. // So first we cast the i8 as u8, to get rid of the whole concept of sign extension, // and *then* we cast to i64. - Ok((builder.ins().iconst(types::I8, v as u8 as i64), ConstantType::I8)) + Ok(( + builder.ins().iconst(types::I8, v as u8 as i64), + ConstantType::I8, + )) } Value::I16(_, v) => Ok(( // see above note for the "... as ... as" diff --git a/src/eval/primtype.rs b/src/eval/primtype.rs index 0624ba2..fed57f3 100644 --- a/src/eval/primtype.rs +++ b/src/eval/primtype.rs @@ -113,6 +113,7 @@ impl FromStr for PrimitiveType { "u16" => Ok(PrimitiveType::U16), "u32" => Ok(PrimitiveType::U32), "u64" => Ok(PrimitiveType::U64), + "void" => Ok(PrimitiveType::Void), _ => Err(UnknownPrimType::UnknownPrimType(s.to_owned())), } } @@ -190,6 +191,8 @@ impl PrimitiveType { (PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)), (PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)), + (PrimitiveType::Void, Value::Void) => Ok(Value::Void), + _ => Err(PrimOpError::UnsafeCast { from: PrimitiveType::try_from(source)?, to: *self, diff --git a/src/ir/ast.rs b/src/ir/ast.rs index b3d83d3..78c6998 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -228,9 +228,13 @@ where .text("print") .append(allocator.space()) .append(allocator.text(var.as_ref().to_string())), - Expression::Bind(_, var, _, expr) => allocator + Expression::Bind(_, var, ty, expr) => allocator .text(var.as_ref().to_string()) .append(allocator.space()) + .append(allocator.text(":")) + .append(allocator.space()) + .append(ty.pretty(allocator)) + .append(allocator.space()) .append(allocator.text("=")) .append(allocator.space()) .append(expr.pretty(allocator)), diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index b97851f..2ec0894 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -124,6 +124,7 @@ pub enum Expression { Reference(Location, String), Cast(Location, String, Box), Primitive(Location, String, Vec), + Block(Location, Vec), } impl PartialEq for Expression { @@ -145,6 +146,10 @@ impl PartialEq for Expression { Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2, _ => false, }, + Expression::Block(_, stmts1) => match other { + Expression::Block(_, stmts2) => stmts1 == stmts2, + _ => false, + }, } } } diff --git a/src/syntax/eval.rs b/src/syntax/eval.rs index 6f96102..6093633 100644 --- a/src/syntax/eval.rs +++ b/src/syntax/eval.rs @@ -35,25 +35,7 @@ impl Program { } } - TopLevel::Statement(Statement::Binding(_, name, value)) => { - let actual_value = value.eval(&env)?; - env.insert(name.clone().intern(), actual_value); - last_result = Value::Void; - } - - TopLevel::Statement(Statement::Print(loc, name)) => { - let value = env - .get(&name.clone().intern()) - .ok_or_else(|| EvalError::LookupFailed(loc.clone(), name.name.clone()))?; - let value = if let Value::Number(x) = value { - Value::U64(*x) - } else { - value.clone() - }; - let line = format!("{} = {}\n", name, value); - stdout.push_str(&line); - last_result = Value::Void; - } + TopLevel::Statement(stmt) => last_result = stmt.eval(&mut stdout, &mut env)?, } } @@ -61,10 +43,41 @@ impl Program { } } +impl Statement { + fn eval( + &self, + stdout: &mut String, + env: &mut ScopedMap, Value>, + ) -> Result, EvalError> { + match self { + Statement::Binding(_, name, value) => { + let actual_value = value.eval(stdout, env)?; + env.insert(name.clone().intern(), actual_value); + Ok(Value::Void) + } + + Statement::Print(loc, name) => { + let value = env + .get(&name.clone().intern()) + .ok_or_else(|| EvalError::LookupFailed(loc.clone(), name.name.clone()))?; + let value = if let Value::Number(x) = value { + Value::U64(*x) + } else { + value.clone() + }; + let line = format!("{} = {}\n", name, value); + stdout.push_str(&line); + Ok(Value::Void) + } + } + } +} + impl Expression { fn eval( &self, - env: &ScopedMap, Value>, + stdout: &mut String, + env: &mut ScopedMap, Value>, ) -> Result, EvalError> { match self { Expression::Value(_, v) => match v { @@ -90,7 +103,7 @@ impl Expression { Expression::Cast(_, target, expr) => { let target_type = PrimitiveType::from_str(target)?; - let value = expr.eval(env)?; + let value = expr.eval(stdout, env)?; Ok(target_type.safe_cast(&value)?) } @@ -99,11 +112,21 @@ impl Expression { for arg in args.iter() { // yay, recursion! makes this pretty straightforward - arg_values.push(arg.eval(env)?); + arg_values.push(arg.eval(stdout, env)?); } Ok(Value::calculate(op, arg_values)?) } + + Expression::Block(_, stmts) => { + let mut result = Value::Void; + + for stmt in stmts.iter() { + result = stmt.eval(stdout, env)?; + } + + Ok(result) + } } } } diff --git a/src/syntax/parser.lalrpop b/src/syntax/parser.lalrpop index 40987f3..8c66e63 100644 --- a/src/syntax/parser.lalrpop +++ b/src/syntax/parser.lalrpop @@ -87,11 +87,13 @@ OptionalName: Option = { } Arguments: Vec = { - => { + "," => { args.push(arg); args }, + => vec![arg], + => Vec::new(), } @@ -214,7 +216,7 @@ AtomicExpression: Expression = { // just a number "> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)), // this expression could actually be a block! - "{" "}" => unimplemented!(), + "{" "}" => Expression::Block(Location::new(file_idx, s..e), stmts), // finally, let people parenthesize expressions and get back to a // lower precedence "(" ")" => e, diff --git a/src/syntax/pretty.rs b/src/syntax/pretty.rs index 9bb0461..859f51f 100644 --- a/src/syntax/pretty.rs +++ b/src/syntax/pretty.rs @@ -105,6 +105,25 @@ where let comma_sepped_args = allocator.intersperse(args, CommaSep {}); call.append(comma_sepped_args.parens()) } + Expression::Block(_, stmts) => match stmts.split_last() { + None => allocator.text("()"), + Some((last, &[])) => last.pretty(allocator), + Some((last, start)) => { + let mut result = allocator.text("{").append(allocator.hardline()); + + for stmt in start.iter() { + result = result + .append(stmt.pretty(allocator)) + .append(allocator.text(";")) + .append(allocator.hardline()); + } + + result + .append(last.pretty(allocator)) + .append(allocator.hardline()) + .append(allocator.text("}")) + } + }, } } } diff --git a/src/syntax/validate.rs b/src/syntax/validate.rs index 46c471c..1503c65 100644 --- a/src/syntax/validate.rs +++ b/src/syntax/validate.rs @@ -180,7 +180,10 @@ impl Statement { } impl Expression { - fn validate(&self, variable_map: &ScopedMap) -> (Vec, Vec) { + fn validate( + &self, + variable_map: &mut ScopedMap, + ) -> (Vec, Vec) { match self { Expression::Value(_, _) => (vec![], vec![]), Expression::Reference(_, var) if variable_map.contains_key(var) => (vec![], vec![]), @@ -207,6 +210,19 @@ impl Expression { warnings.append(&mut warn); } + (errors, warnings) + } + Expression::Block(_, stmts) => { + let mut errors = vec![]; + let mut warnings = vec![]; + + for stmt in stmts.iter() { + let (mut local_errors, mut local_warnings) = stmt.validate(variable_map); + + errors.append(&mut local_errors); + warnings.append(&mut local_warnings); + } + (errors, warnings) } } diff --git a/src/type_infer/convert.rs b/src/type_infer/convert.rs index 3e968b3..4a1800b 100644 --- a/src/type_infer/convert.rs +++ b/src/type_infer/convert.rs @@ -71,9 +71,10 @@ pub fn convert_top_level( args.iter().map(|x| ArcIntern::new(x.to_string())).collect(); assert_eq!(argtypes.len(), iargs.len()); let mut function_args = vec![]; - for (arg_name, arg_type) in iargs.iter().zip(argtypes) { + for ((arg_name, arg_type), orig_name) in iargs.iter().zip(argtypes).zip(args) { bindings.insert(arg_name.clone(), arg_type.clone()); - function_args.push((arg_name.clone(), arg_type)); + function_args.push((arg_name.clone(), arg_type.clone())); + constraint_db.push(Constraint::IsSomething(orig_name.location, arg_type)); } let (expr, ty) = convert_expression(expr, constraint_db, renames, bindings); @@ -150,7 +151,7 @@ fn convert_statement( fn convert_expression( expression: syntax::Expression, constraint_db: &mut Vec, - renames: &ScopedMap, ArcIntern>, + renames: &mut ScopedMap, ArcIntern>, bindings: &mut ScopedMap, ir::TypeOrVar>, ) -> (ir::Expression, ir::TypeOrVar) { match expression { @@ -282,6 +283,23 @@ fn convert_expression( ) } } + + syntax::Expression::Block(loc, stmts) => { + let mut ret_type = ir::TypeOrVar::Primitive(PrimitiveType::Void); + let mut exprs = vec![]; + + for statement in stmts { + let expr = convert_statement(statement, constraint_db, renames, bindings); + + ret_type = expr.type_of(); + exprs.push(expr); + } + + ( + ir::Expression::Block(loc, ret_type.clone(), exprs), + ret_type, + ) + } } } diff --git a/src/type_infer/finalize.rs b/src/type_infer/finalize.rs index d681c34..e2af988 100644 --- a/src/type_infer/finalize.rs +++ b/src/type_infer/finalize.rs @@ -6,6 +6,26 @@ pub fn finalize_program( mut program: Program, resolutions: &TypeResolutions, ) -> Program { + println!("RESOLUTIONS:"); + for (name, ty) in resolutions.iter() { + println!("{} => {}", name, ty); + } + println!("PROGRAM:"); + { + use pretty::{DocAllocator, Pretty}; + let allocator = pretty::BoxAllocator; + allocator + .text("---------------") + .append(allocator.hardline()) + .append(program.pretty(&allocator)) + .1 + .render_colored( + 70, + pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto), + ) + .expect("rendering works"); + } + Program { items: program .items diff --git a/src/type_infer/solve.rs b/src/type_infer/solve.rs index f3ba8c8..34127de 100644 --- a/src/type_infer/solve.rs +++ b/src/type_infer/solve.rs @@ -25,6 +25,8 @@ pub enum Constraint { ConstantNumericType(Location, TypeOrVar), /// The two types should be equivalent Equivalent(Location, TypeOrVar, TypeOrVar), + /// The given type can be resolved to something + IsSomething(Location, TypeOrVar), } impl fmt::Display for Constraint { @@ -43,6 +45,7 @@ impl fmt::Display for Constraint { Constraint::NumericType(_, ty) => write!(f, "NUMERIC {}", ty), Constraint::ConstantNumericType(_, ty) => write!(f, "CONST_NUMERIC {}", ty), Constraint::Equivalent(_, ty, ty2) => write!(f, "EQUIVALENT {} => {}", ty, ty2), + Constraint::IsSomething(_, ty) => write!(f, "SOMETHING {}", ty), } } } @@ -213,6 +216,10 @@ impl From for Diagnostic { prim )) } + TypeInferenceError::CouldNotSolve(Constraint::IsSomething(loc, _)) => { + loc.labelled_error("could not infer type") + .with_message("Could not find *any* type information; is this an unused function argument?") + } } } } @@ -253,6 +260,11 @@ pub fn solve_constraints( let mut resolutions = HashMap::new(); let mut changed_something = true; + println!("CONSTRAINTS:"); + for constraint in constraint_db.iter() { + println!("{}", constraint); + } + // We want to run this inference endlessly, until either we have solved all of our // constraints. Internal to the loop, we have a check that will make sure that we // do (eventually) stop. @@ -275,6 +287,16 @@ pub fn solve_constraints( // Currently, all of our types are printable Constraint::Printable(_loc, _ty) => changed_something = true, + // If we're looking for a type to be something (anything!), and it's not a type + // variable, then yay, we've solved it. + Constraint::IsSomething(_, TypeOrVar::Function(_, _)) + | Constraint::IsSomething(_, TypeOrVar::Primitive(_)) => changed_something = true, + + // Otherwise, we'll keep looking for it. + Constraint::IsSomething(_, TypeOrVar::Variable(_, _)) => { + constraint_db.push(constraint); + } + // Case #1a: We have two primitive types. If they're equal, we've discharged this // constraint! We can just continue. If they're not equal, add an error and then // see what else we come up with.