λ Support functions! #5

Open
acw wants to merge 59 commits from awick/functions into develop
74 changed files with 18865 additions and 2976 deletions

View File

@@ -25,9 +25,9 @@ jobs:
toolchain: ${{ matrix.rust }} toolchain: ${{ matrix.rust }}
default: true default: true
override: true override: true
- name: Build
run: cargo build
- name: Format Check - name: Format Check
run: cargo fmt --check run: cargo fmt --check
- name: Build
run: cargo build
- name: Run tests - name: Run tests
run: cargo test run: cargo test

View File

@@ -9,25 +9,30 @@ name = "ngr"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
clap = { version = "^3.0.14", features = ["derive"] } clap = { version = "4.5.1", features = ["derive"] }
codespan = "0.11.1" codespan = "0.11.1"
codespan-reporting = "0.11.1" codespan-reporting = "0.11.1"
cranelift-codegen = "0.99.2" cranelift-codegen = "0.105.2"
cranelift-jit = "0.99.2" cranelift-jit = "0.105.2"
cranelift-frontend = "0.99.2" cranelift-frontend = "0.105.2"
cranelift-module = "0.99.2" cranelift-module = "0.105.2"
cranelift-native = "0.99.2" cranelift-native = "0.105.2"
cranelift-object = "0.99.2" cranelift-object = "0.105.2"
internment = { version = "0.7.0", default-features = false, features = ["arc"] } internment = { version = "0.7.4", default-features = false, features = ["arc"] }
lalrpop-util = "^0.20.0" lalrpop-util = "0.20.2"
lazy_static = "^1.4.0" lazy_static = "1.4.0"
logos = "^0.12.0" logos = "0.14.0"
pretty = { version = "^0.11.2", features = ["termcolor"] } pretty = { version = "0.12.3", features = ["termcolor"] }
proptest = "^1.0.0" proptest = "1.4.0"
rustyline = "^11.0.0" rand = "0.8.5"
target-lexicon = "^0.12.5" rustyline = "13.0.0"
tempfile = "^3.5.0" target-lexicon = "0.12.14"
thiserror = "^1.0.30" tempfile = "3.10.1"
thiserror = "1.0.57"
anyhow = "1.0.80"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["time", "json", "env-filter"] }
names = "0.14.0"
[build-dependencies] [build-dependencies]
lalrpop = "^0.20.0" lalrpop = "0.20.2"

View File

@@ -45,17 +45,20 @@ fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> {
writeln!(f, " let mut file_database = SimpleFiles::new();")?; writeln!(f, " let mut file_database = SimpleFiles::new();")?;
writeln!( writeln!(
f, f,
" let syntax = Syntax::parse_file(&mut file_database, {:?});", " let syntax = crate::syntax::parse_file(&mut file_database, {:?});",
entry.path().display() entry.path().display()
)?; )?;
if entry.path().to_string_lossy().contains("broken") { if entry.path().to_string_lossy().contains("broken") {
writeln!(f, " if syntax.is_err() {{")?; writeln!(f, " if syntax.is_err() {{")?;
writeln!(f, " return;")?; writeln!(f, " return;")?;
writeln!(f, " }}")?; writeln!(f, " }}")?;
writeln!(f, " let (errors, _) = syntax.unwrap().validate();")?;
writeln!( writeln!(
f, f,
" assert_ne!(errors.len(), 0, \"should have seen an error\");" " let mut validation_result = Syntax::validate(syntax.unwrap());"
)?;
writeln!(
f,
" assert!(validation_result.is_err(), \"should have seen an error\");"
)?; )?;
} else { } else {
// NOTE: Since the advent of defaulting rules and type checking, we // NOTE: Since the advent of defaulting rules and type checking, we
@@ -67,18 +70,38 @@ fn generate_tests(f: &mut File, path_so_far: PathBuf) -> std::io::Result<()> {
f, f,
" let syntax = syntax.expect(\"file should have parsed\");" " let syntax = syntax.expect(\"file should have parsed\");"
)?; )?;
writeln!(f, " let (errors, _) = syntax.validate();")?; writeln!(f, " let validation_result = Syntax::validate(syntax);")?;
writeln!( writeln!(
f, f,
" assert_eq!(errors.len(), 0, \"file should have no validation errors\");" " assert!(validation_result.is_ok(), \"file should have no validation errors\");"
)?; )?;
writeln!(
f,
" let syntax = validation_result.into_result().unwrap();"
)?;
writeln!(f, " let syntax_result = syntax.eval();")?;
writeln!( writeln!(
f, f,
" let ir = syntax.type_infer().expect(\"example is typed correctly\");" " let ir = syntax.type_infer().expect(\"example is typed correctly\");"
)?; )?;
writeln!(f, " let ir_result = ir.eval();")?; writeln!(f, " let ir_evaluator = crate::ir::Evaluator::default();")?;
writeln!(f, " let ir_result = ir_evaluator.eval(ir.clone());")?;
writeln!(f, " match (&syntax_result, &ir_result) {{")?;
writeln!(f, " (Err(e1), Err(e2)) => assert_eq!(e1, e2),")?;
writeln!(f, " (Ok((v1, o1)), Ok((v2, o2))) => {{")?;
writeln!(f, " assert_eq!(v1, v2);")?;
writeln!(f, " assert_eq!(o1, o2);")?;
writeln!(f, " }}")?;
writeln!(f, " _ => panic!(\"mismatched outputs, {{:?}} and {{:?}}\", syntax_result, ir_result)")?;
writeln!(f, " }}")?;
writeln!(f, " let compiled_result = Backend::<JITModule>::eval(ir);")?; writeln!(f, " let compiled_result = Backend::<JITModule>::eval(ir);")?;
writeln!(f, " assert_eq!(ir_result, compiled_result);")?; writeln!(f, " match (&compiled_result, &ir_result) {{")?;
writeln!(f, " (Err(e1), Err(e2)) => assert_eq!(e1, e2),")?;
writeln!(f, " (Ok(o1), Ok((_, o2))) => {{")?;
writeln!(f, " assert_eq!(o1, o2);")?;
writeln!(f, " }}")?;
writeln!(f, " _ => panic!(\"mismatched outputs, {{:?}} and {{:?}}\", compiled_result, ir_result)")?;
writeln!(f, " }}")?;
} }
writeln!(f, "}}")?; writeln!(f, "}}")?;
} }

View File

@@ -1,4 +0,0 @@
x = 5;
x = 4*x + 3;
print x;
print y;

View File

@@ -0,0 +1,7 @@
x = 1;
function add_x(y) x + y;
a = 3;
print x;
result = add_x(a);
print x;
print result;

View File

@@ -0,0 +1,10 @@
x = 1;
function add_x(y) x + y;
a = 3;
function add_x_twice(y) add_x(y) + x;
print x;
result = add_x(a);
print x;
print result;
result2 = add_x_twice(a);
print result2;

View File

@@ -0,0 +1,11 @@
x = 1u64;
function mean_x(y) {
base = x + y;
result = base / 2;
result
};
a = 3;
mean_x_and_a = mean_x(a);
mean_x_and_9 = mean_x(9);
print mean_x_and_a;
print mean_x_and_9;

View File

@@ -0,0 +1,10 @@
function make_adder(x)
function (y)
x + y;
add1 = make_adder(1);
add2 = make_adder(2);
one_plus_one = add1(1);
one_plus_three = add1(3);
print one_plus_one;
print one_plus_three;

View File

@@ -0,0 +1,3 @@
x = 4u64;
function f(y) (x + y);
print x;

View File

@@ -0,0 +1,7 @@
b = -7662558304906888395i64;
z = 1030390794u32;
v = z;
q = <i64>z;
s = -2115098981i32;
t = <i32>s;
print t;

View File

@@ -0,0 +1,4 @@
n = (49u8 + 155u8);
q = n;
function u (b) n + b;
v = n;

View File

@@ -0,0 +1,150 @@
d3347 = {
s3349 = v3348 = 17175152522826808410u64;
print s3349;
j3350 = -6926831316600240717i64;
g3351 = b3352 = print s3349;
g3351;
u3353 = v3348;
c3460 = p3354 = {
w3441 = u3355 = k3356 = v3357 = {
j3364 = b3358 = b3359 = {
e3363 = o3360 = a3361 = n3362 = v3348 * v3348;
print e3363;
a3361
};
v3365 = j3364;
o3376 = {
p3366 = 62081u16;
t3369 = {
k3367 = 3742184609455079849u64;
y3368 = print g3351;
k3367
};
e3371 = o3370 = p3366;
print s3349;
l3372 = 50u8;
g3373 = 1086766998u32;
u3374 = <u64>g3373;
p3375 = 13826883074707422152u64
};
h3379 = v3377 = h3378 = 1513207896u32;
x3382 = f3380 = i3381 = <i32>-72i8;
a3383 = g3351;
q3440 = q3384 = {
r3385 = v3365;
z3437 = n3386 = {
o3387 = -1428233008i32;
c3388 = s3349;
b3352;
c3389 = c3388;
a3383;
c3390 = 1056u16;
l3433 = {
b3392 = f3391 = -881200191i32;
print o3376;
print v3377;
h3395 = y3393 = j3394 = -2456592492064497053i64;
c3396 = c3388;
f3397 = 2442824079u32;
d3428 = {
n3400 = h3398 = m3399 = v3365;
j3401 = v3348;
t3402 = -10i8;
e3403 = g3404 = print r3385;
e3403;
d3425 = {
u3405 = 51313u16;
l3406 = 235u8;
l3407 = 7030u16;
i3413 = {
b3352;
v3408 = o3387;
i3409 = 42u8;
q3411 = {
u3410 = -70i8;
n3400
};
q3412 = print i3409;
x3382
};
z3414 = b3392;
j3417 = s3415 = o3416 = 7220082853233268797u64;
print i3381;
n3419 = a3418 = -1270109327i32;
o3420 = r3385;
o3421 = x3382;
j3422 = print i3381;
q3423 = -4497i16;
x3424 = -98995788i32;
f3391
};
{
x3426 = r3385 - 13032422114254415490u64;
e3427 = -51i8;
j3401
}
};
k3429 = <u64>c3396;
print j3350;
g3351;
f3430 = -12293i16;
v3431 = 4016608549u32;
t3432 = f3397;
f3391 - -716040069i32
};
r3434 = -12984i16;
s3435 = 293908485953501586u64;
{
print j3364;
h3436 = <u64>7732092399687242928u64
}
};
u3438 = v3377 * 3168795739u32;
p3439 = <i32>-12313i16;
u3438
};
v3348
};
x3442 = -5279281110772475785i64;
y3443 = v3348;
b3444 = 2732851783u32;
l3456 = l3445 = {
v3446 = -27499i16;
i3447 = 3517560837u32;
z3448 = u3353 * u3355;
u3449 = x3442;
t3450 = u3353;
m3452 = n3451 = u3355;
b3453 = m3452;
x3455 = t3454 = 2134443760u32;
print b3444;
b3444
};
c3457 = l3456;
g3351;
m3458 = 1822439673528019141u64;
d3459 = l3456
};
k3462 = f3461 = <i32>13u8;
g3351;
u3463 = -10083i16;
v3348
};
y3464 = {
y3465 = 163u8;
y3466 = -7760i16;
e3467 = d3347;
q3468 = 58708u16;
-426970972827051249i64
};
t3469 = 524885465u32;
function b3470 (y3471,c3472,u3473) <u32>1606677228u32;
y3464;
-240502590i32;
v3474 = z3475 = {
t3476 = y3464;
p3477 = t3469;
i3478 = p3477;
a3480 = e3479 = t3476;
p3477
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
struct Point {
x: u64;
y: u64;
}
v = 1u64;
function getX(p: Point) -> u64
p.x;
function getY(p: Point) -> u64
p.y;
function newPoint(x, y) -> Point
Point {
x: x;
y: y;
};
function slope(p1, p2) -> u64
(getY(p2) - p1.y) / (getX(p2) - p1.x);
origin = newPoint(0, 0);
farther = newPoint(2, 4);
mySlope = slope(origin, farther);
print mySlope;

View File

@@ -0,0 +1,13 @@
struct Point {
x: u64;
y: u64;
}
test = Point {
x: 1;
y: 2;
};
foo = test.x;
print foo;

View File

@@ -0,0 +1,27 @@
struct Point {
x: u64;
y: u64;
}
struct Line {
start: Point;
end: Point;
}
function slope(l) -> u64
(l.end.y - l.start.y) / (l.end.x - l.start.x);
test = Line {
start: Point {
x: 1;
y: 2;
};
end: Point {
x: 2;
y: 4;
};
};
foo = slope(test);
print foo;

View File

@@ -0,0 +1,4 @@
x = 5;
y = 4*x + 3;
print x;
print y;

View File

@@ -0,0 +1 @@
v = -2370389138213399653i64;

View File

@@ -0,0 +1,4 @@
function u (t,h,z,u,a,f,c) 27u8;
function s (p,y,k) 10318938949979263534u64;
o = -98i8;
print o;

View File

@@ -0,0 +1,2 @@
t = v = 5i64;
print t;

View File

@@ -1,7 +1,13 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <inttypes.h> #include <inttypes.h>
#define HEAP_SIZE (1024 * 1024)
void *__global_allocation_pointer__ = NULL;
void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) { void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) {
switch(vtype) { switch(vtype) {
case /* U8 = */ 10: case /* U8 = */ 10:
@@ -28,12 +34,29 @@ void print(char *_ignore, char *variable_name, int64_t vtype, int64_t value) {
case /* I64 = */ 23: case /* I64 = */ 23:
printf("%s = %" PRIi64 "i64\n", variable_name, value); printf("%s = %" PRIi64 "i64\n", variable_name, value);
break; break;
case /* void = */ 255:
printf("%s = <void>\n", variable_name);
break;
default:
printf("%s = UNKNOWN VTYPE %lld\n", variable_name, vtype);
} }
} }
extern void gogogo(); extern void gogogo();
int main(int argc, char **argv) { int main(int argc, char **argv) {
__global_allocation_pointer__ = malloc(HEAP_SIZE);
if(__global_allocation_pointer__ == NULL) {
printf("ERROR: Couldn't allocation heap space.");
return 1;
}
if(memset(__global_allocation_pointer__, 0, HEAP_SIZE) != __global_allocation_pointer__) {
printf("ERROR: Weird return trying to zero out heap.");
return 2;
}
gogogo(); gogogo();
return 0; return 0;
} }

View File

@@ -33,7 +33,10 @@ mod runtime;
pub use self::error::BackendError; pub use self::error::BackendError;
pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions}; pub use self::runtime::{RuntimeFunctionError, RuntimeFunctions};
use crate::syntax::ConstantType; use crate::ir::Name;
use crate::syntax::{ConstantType, Location};
use cranelift_codegen::entity::EntityRef;
use cranelift_codegen::ir::types;
use cranelift_codegen::settings::Configurable; use cranelift_codegen::settings::Configurable;
use cranelift_codegen::{isa, settings}; use cranelift_codegen::{isa, settings};
use cranelift_jit::{JITBuilder, JITModule}; use cranelift_jit::{JITBuilder, JITModule};
@@ -58,9 +61,11 @@ pub struct Backend<M: Module> {
data_ctx: DataDescription, data_ctx: DataDescription,
runtime_functions: RuntimeFunctions, runtime_functions: RuntimeFunctions,
defined_strings: HashMap<String, DataId>, defined_strings: HashMap<String, DataId>,
defined_symbols: HashMap<String, (DataId, ConstantType)>, defined_functions: HashMap<Name, FuncId>,
defined_symbols: HashMap<Name, (DataId, types::Type)>,
output_buffer: Option<String>, output_buffer: Option<String>,
platform: Triple, platform: Triple,
next_variable: usize,
} }
impl Backend<JITModule> { impl Backend<JITModule> {
@@ -84,15 +89,28 @@ impl Backend<JITModule> {
let mut module = JITModule::new(builder); let mut module = JITModule::new(builder);
let runtime_functions = RuntimeFunctions::new(&platform, &mut module)?; let runtime_functions = RuntimeFunctions::new(&platform, &mut module)?;
Ok(Backend { let mut retval = Backend {
module, module,
data_ctx: DataDescription::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_functions: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer, output_buffer,
platform: Triple::host(), platform: Triple::host(),
}) next_variable: 23,
};
let alloc = "__global_allocation_pointer__".to_string();
let id = retval
.module
.declare_data(&alloc, Linkage::Import, true, false)?;
retval.defined_symbols.insert(
Name::new(alloc, Location::manufactured()),
(id, retval.module.target_config().pointer_type()),
);
Ok(retval)
} }
/// Given a compiled function ID, get a pointer to where that function was written /// Given a compiled function ID, get a pointer to where that function was written
@@ -123,15 +141,28 @@ impl Backend<ObjectModule> {
let mut module = ObjectModule::new(object_builder); let mut module = ObjectModule::new(object_builder);
let runtime_functions = RuntimeFunctions::new(&platform, &mut module)?; let runtime_functions = RuntimeFunctions::new(&platform, &mut module)?;
Ok(Backend { let mut retval = Backend {
module, module,
data_ctx: DataDescription::new(), data_ctx: DataDescription::new(),
runtime_functions, runtime_functions,
defined_strings: HashMap::new(), defined_strings: HashMap::new(),
defined_functions: HashMap::new(),
defined_symbols: HashMap::new(), defined_symbols: HashMap::new(),
output_buffer: None, output_buffer: None,
platform, platform,
}) next_variable: 23,
};
let alloc = "__global_allocation_pointer__".to_string();
let id = retval
.module
.declare_data(&alloc, Linkage::Import, true, false)?;
retval.defined_symbols.insert(
Name::new(alloc, Location::manufactured()),
(id, retval.module.target_config().pointer_type()),
);
Ok(retval)
} }
/// Given all the functions defined, return the bytes the object file should contain. /// Given all the functions defined, return the bytes the object file should contain.
@@ -172,19 +203,45 @@ impl<M: Module> Backend<M> {
/// value will be null. /// value will be null.
pub fn define_variable( pub fn define_variable(
&mut self, &mut self,
name: String, name: Name,
ctype: ConstantType, ctype: ConstantType,
) -> Result<DataId, BackendError> { ) -> Result<DataId, BackendError> {
self.data_ctx.define(Box::new(EMPTY_DATUM)); self.data_ctx.define(Box::new(EMPTY_DATUM));
let id = self let id = self
.module .module
.declare_data(&name, Linkage::Export, true, false)?; .declare_data(name.current_name(), Linkage::Export, true, false)?;
self.module.define_data(id, &self.data_ctx)?; self.module.define_data(id, &self.data_ctx)?;
self.data_ctx.clear(); self.data_ctx.clear();
self.defined_symbols.insert(name, (id, ctype)); self.defined_symbols.insert(name, (id, ctype.into()));
Ok(id) Ok(id)
} }
/// Return the data ID associated with the given string constant.
///
/// If `string` already has an entry in `defined_strings`, that entry's ID is
/// returned as-is. Otherwise the request is handed to `define_string`, and
/// whatever it returns (ID or error) is passed straight back to the caller.
pub fn string_reference(&mut self, string: &str) -> Result<DataId, BackendError> {
    // Fast path: the string has been seen before, so reuse its ID.
    if let Some(existing) = self.defined_strings.get(string) {
        return Ok(*existing);
    }

    self.define_string(string)
}
/// Reset the local variable counter, because we're going to start a new
/// function.
///
/// After this call, the next index handed out by `generate_local` is 5.
///
/// NOTE(review): both `Backend` constructors initialize `next_variable` to 23,
/// while this resets it to 5. Presumably the low indices are reserved for
/// something set up by the function compiler -- confirm which starting value
/// is actually intended and whether the two should agree.
pub fn reset_local_variable_tracker(&mut self) {
self.next_variable = 5;
}
/// Allocate a fresh Cranelift variable for use in the current function.
///
/// Each call wraps the current value of the `next_variable` counter in a
/// `cranelift_frontend::Variable` and then advances the counter by one, so
/// successive calls never return the same variable until the counter is
/// reset with `reset_local_variable_tracker`.
///
/// The returned variable should only be used inside the function currently
/// being built. Referencing a variable from another function could do random
/// things; hopefully Cranelift will complain, but there is a small chance it
/// silently works and you end up referencing something unexpected.
pub fn generate_local(&mut self) -> cranelift_frontend::Variable {
    let index = self.next_variable;
    let fresh = cranelift_frontend::Variable::new(index);
    self.next_variable = index + 1;
    fresh
}
/// Get a pointer to the output buffer for `print`ing, or `null`. /// Get a pointer to the output buffer for `print`ing, or `null`.
/// ///
/// As suggested, returns `null` in the case where the user has not provided an /// As suggested, returns `null` in the case where the user has not provided an

View File

@@ -1,4 +1,5 @@
use crate::{backend::runtime::RuntimeFunctionError, eval::PrimitiveType, ir::Type}; use crate::backend::runtime::RuntimeFunctionError;
use crate::ir::{Name, Type};
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError}; use cranelift_codegen::{isa::LookupError, settings::SetError, CodegenError};
use cranelift_module::ModuleError; use cranelift_module::ModuleError;
@@ -30,7 +31,7 @@ pub enum BackendError {
#[error("Builtin function error: {0}")] #[error("Builtin function error: {0}")]
BuiltinError(#[from] RuntimeFunctionError), BuiltinError(#[from] RuntimeFunctionError),
#[error("Internal variable lookup error")] #[error("Internal variable lookup error")]
VariableLookupFailure(ArcIntern<String>), VariableLookupFailure(Name),
#[error(transparent)] #[error(transparent)]
CodegenError(#[from] CodegenError), CodegenError(#[from] CodegenError),
#[error(transparent)] #[error(transparent)]
@@ -40,15 +41,25 @@ pub enum BackendError {
#[error(transparent)] #[error(transparent)]
Write(#[from] cranelift_object::object::write::Error), Write(#[from] cranelift_object::object::write::Error),
#[error("Invalid type cast from {from} to {to}")] #[error("Invalid type cast from {from} to {to}")]
InvalidTypeCast { from: PrimitiveType, to: Type }, InvalidTypeCast {
from: cranelift_codegen::ir::types::Type,
to: Type,
},
#[error("Unknown string constant '{0}")]
UnknownString(ArcIntern<String>),
#[error("Compiler doesn't currently support function arguments")]
NoFunctionArguments {
function_name: String,
argument_name: String,
},
} }
impl From<BackendError> for Diagnostic<usize> { impl From<BackendError> for Diagnostic<usize> {
fn from(value: BackendError) -> Self { fn from(value: BackendError) -> Self {
match value { match value {
BackendError::Cranelift(me) => { BackendError::Cranelift(me) => Diagnostic::error()
Diagnostic::error().with_message(format!("Internal cranelift error: {}", me)) .with_message(format!("Internal cranelift error: {}", me))
} .with_notes(vec![format!("{:?}", me)]),
BackendError::BuiltinError(me) => { BackendError::BuiltinError(me) => {
Diagnostic::error().with_message(format!("Internal runtime function error: {}", me)) Diagnostic::error().with_message(format!("Internal runtime function error: {}", me))
} }
@@ -69,6 +80,15 @@ impl From<BackendError> for Diagnostic<usize> {
BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message( BackendError::InvalidTypeCast { from, to } => Diagnostic::error().with_message(
format!("Internal error trying to cast from {} to {}", from, to), format!("Internal error trying to cast from {} to {}", from, to),
), ),
BackendError::UnknownString(str) => Diagnostic::error()
.with_message(format!("Unknown string found trying to compile: '{}'", str)),
BackendError::NoFunctionArguments {
function_name,
argument_name,
} => Diagnostic::error().with_message(format!(
"Function {} takes a function argument ({}), which is not supported",
function_name, argument_name
)),
} }
} }
} }
@@ -119,6 +139,22 @@ impl PartialEq for BackendError {
} => from1 == from2 && to1 == to2, } => from1 == from2 && to1 == to2,
_ => false, _ => false,
}, },
BackendError::UnknownString(a) => match other {
BackendError::UnknownString(b) => a == b,
_ => false,
},
BackendError::NoFunctionArguments {
function_name: f1,
argument_name: a1,
} => match other {
BackendError::NoFunctionArguments {
function_name: f2,
argument_name: a2,
} => f1 == f2 && a1 == a2,
_ => false,
},
} }
} }
} }

View File

@@ -1,8 +1,7 @@
use crate::backend::Backend; use crate::backend::Backend;
use crate::eval::EvalError; use crate::eval::EvalError;
use crate::ir::Program; use crate::ir::{Expression, Name, Program, Type};
#[cfg(test)] use crate::syntax::Location;
use crate::syntax::arbitrary::GenerationEnvironment;
use cranelift_jit::JITModule; use cranelift_jit::JITModule;
use cranelift_object::ObjectModule; use cranelift_object::ObjectModule;
#[cfg(test)] #[cfg(test)]
@@ -24,9 +23,12 @@ impl Backend<JITModule> {
/// library do. So, if you're validating equivalence between them, you'll want to weed /// library do. So, if you're validating equivalence between them, you'll want to weed
/// out examples that overflow/underflow before checking equivalence. (This is the behavior /// out examples that overflow/underflow before checking equivalence. (This is the behavior
/// of the built-in test systems.) /// of the built-in test systems.)
pub fn eval(program: Program) -> Result<String, EvalError> { pub fn eval(program: Program<Type>) -> Result<String, EvalError<Expression<Type>>> {
let mut jitter = Backend::jit(Some(String::new()))?; let mut jitter = Backend::jit(Some(String::new()))?;
let function_id = jitter.compile_function("test", program)?; let function_id = jitter.compile_program(
Name::new("___test_jit_eval___", Location::manufactured()),
program,
)?;
jitter.module.finalize_definitions()?; jitter.module.finalize_definitions()?;
let compiled_bytes = jitter.bytes(function_id); let compiled_bytes = jitter.bytes(function_id);
let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; let compiled_function = unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
@@ -51,17 +53,16 @@ impl Backend<ObjectModule> {
/// library do. So, if you're validating equivalence between them, you'll want to weed /// library do. So, if you're validating equivalence between them, you'll want to weed
/// out examples that overflow/underflow before checking equivalence. (This is the behavior /// out examples that overflow/underflow before checking equivalence. (This is the behavior
/// of the built-in test systems.) /// of the built-in test systems.)
pub fn eval(program: Program) -> Result<String, EvalError> { pub fn eval(program: Program<Type>) -> Result<String, EvalError<Expression<Type>>> {
//use pretty::{Arena, Pretty}; //use pretty::{Arena, Pretty};
//let allocator = Arena::<()>::new(); //let allocator = Arena::<()>::new();
//program.pretty(&allocator).render(80, &mut std::io::stdout())?; //program.pretty(&allocator).render(80, &mut std::io::stdout())?;
let mut backend = Self::object_file(Triple::host())?; let mut backend = Self::object_file(Triple::host())?;
let my_directory = tempfile::tempdir()?; let my_directory = tempfile::tempdir()?;
let object_path = my_directory.path().join("object.o"); let object_path = my_directory.path().join("object.o");
let executable_path = my_directory.path().join("test_executable"); let executable_path = my_directory.path().join("test_executable");
backend.compile_function("gogogo", program)?; backend.compile_program(Name::new("gogogo", Location::manufactured()), program)?;
let bytes = backend.bytes()?; let bytes = backend.bytes()?;
std::fs::write(&object_path, bytes)?; std::fs::write(&object_path, bytes)?;
Self::link(&object_path, &executable_path)?; Self::link(&object_path, &executable_path)?;
@@ -91,7 +92,7 @@ impl Backend<ObjectModule> {
/// This function assumes that this compilation and linking should run without any /// This function assumes that this compilation and linking should run without any
/// output, so changes to the RTS should make 100% sure that they do not generate /// output, so changes to the RTS should make 100% sure that they do not generate
/// any compiler warnings. /// any compiler warnings.
fn link(object_file: &Path, executable_path: &Path) -> Result<(), EvalError> { fn link(object_file: &Path, executable_path: &Path) -> Result<(), EvalError<Expression<Type>>> {
use std::path::PathBuf; use std::path::PathBuf;
let output = std::process::Command::new("clang") let output = std::process::Command::new("clang")
@@ -100,12 +101,13 @@ impl Backend<ObjectModule> {
.join("runtime") .join("runtime")
.join("rts.c"), .join("rts.c"),
) )
.arg("-Wl,-ld_classic")
.arg(object_file) .arg(object_file)
.arg("-o") .arg("-o")
.arg(executable_path) .arg(executable_path)
.output()?; .output()?;
if !output.stderr.is_empty() { if !output.status.success() {
return Err(EvalError::Linker( return Err(EvalError::Linker(
std::string::String::from_utf8_lossy(&output.stderr).to_string(), std::string::String::from_utf8_lossy(&output.stderr).to_string(),
)); ));
@@ -120,10 +122,20 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that // without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow. // doesn't involve overflow or underflow.
#[test] #[test]
fn static_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { fn static_backend(program in Program::arbitrary()) {
use crate::eval::PrimOpError; use crate::eval::PrimOpError;
// use pretty::DocAllocator;
// let allocator = pretty::Arena::new();
// let result = allocator.text("-------------")
// .append(allocator.line())
// .append(program.pretty(&allocator))
// .append(allocator.line());
// result.render_raw(70, &mut pretty::IoWrite::new(std::io::stdout()))
// .expect("rendering works");
let basic_result = program.eval();
let ir_evaluator = crate::ir::Evaluator::default();
let basic_result = ir_evaluator.eval(program.clone()).map(|(_,o)| o);
// windows `printf` is going to terminate lines with "\r\n", so we need to adjust // windows `printf` is going to terminate lines with "\r\n", so we need to adjust
// our test result here. // our test result here.
@@ -131,14 +143,13 @@ proptest::proptest! {
let basic_result = basic_result.map(|x| x.replace('\n', "\r\n")); let basic_result = basic_result.map(|x| x.replace('\n', "\r\n"));
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
// use pretty::{DocAllocator, Pretty}; //use pretty::DocAllocator;
// let allocator = pretty::BoxAllocator; //let allocator = pretty::Arena::new();
// allocator //let result = allocator.text("-------------")
// .text("---------------") // .append(allocator.line())
// .append(allocator.hardline())
// .append(program.pretty(&allocator)) // .append(program.pretty(&allocator))
// .1 // .append(allocator.line());
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) //result.render_raw(70, &mut pretty::IoWrite::new(std::io::stdout()))
// .expect("rendering works"); // .expect("rendering works");
let compiled_result = Backend::<ObjectModule>::eval(program); let compiled_result = Backend::<ObjectModule>::eval(program);
@@ -150,7 +161,7 @@ proptest::proptest! {
// without error, assuming any possible input ... well, any possible input that // without error, assuming any possible input ... well, any possible input that
// doesn't involve overflow or underflow. // doesn't involve overflow or underflow.
#[test] #[test]
fn jit_backend(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { fn jit_backend(program in Program::arbitrary()) {
use crate::eval::PrimOpError; use crate::eval::PrimOpError;
// use pretty::{DocAllocator, Pretty}; // use pretty::{DocAllocator, Pretty};
// let allocator = pretty::BoxAllocator; // let allocator = pretty::BoxAllocator;
@@ -162,8 +173,8 @@ proptest::proptest! {
// .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto)) // .render_colored(70, pretty::termcolor::StandardStream::stdout(pretty::termcolor::ColorChoice::Auto))
// .expect("rendering works"); // .expect("rendering works");
let ir_evaluator = crate::ir::Evaluator::default();
let basic_result = program.eval(); let basic_result = ir_evaluator.eval(program.clone()).map(|(_,o)| o);
if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) { if !matches!(basic_result, Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))) {
let compiled_result = Backend::<JITModule>::eval(program); let compiled_result = Backend::<JITModule>::eval(program);

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@ use cranelift_codegen::ir::{types, AbiParam, FuncRef, Function, Signature};
use cranelift_codegen::isa::CallConv; use cranelift_codegen::isa::CallConv;
use cranelift_jit::JITBuilder; use cranelift_jit::JITBuilder;
use cranelift_module::{FuncId, Linkage, Module, ModuleResult}; use cranelift_module::{FuncId, Linkage, Module, ModuleResult};
use std::alloc::Layout;
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::CStr; use std::ffi::CStr;
use std::fmt::Write; use std::fmt::Write;
@@ -91,7 +92,19 @@ impl RuntimeFunctions {
/// one; both to reduce the chance that they deviate, and to reduce overall /// one; both to reduce the chance that they deviate, and to reduce overall
/// maintenance burden. /// maintenance burden.
pub fn register_jit_implementations(builder: &mut JITBuilder) { pub fn register_jit_implementations(builder: &mut JITBuilder) {
let allocation_pointer = unsafe {
std::alloc::alloc_zeroed(
Layout::from_size_align(1024 * 1024, 1024 * 1024)
.expect("reasonable layout is reasonable"),
)
};
let allocation_pointer_pointer = unsafe {
let res = std::alloc::alloc(Layout::for_value(&allocation_pointer)) as *mut *mut u8;
*res = allocation_pointer;
res as *const u8
};
builder.symbol("print", runtime_print as *const u8); builder.symbol("print", runtime_print as *const u8);
builder.symbol("__global_allocation_pointer__", allocation_pointer_pointer);
} }
} }
@@ -110,6 +123,7 @@ extern "C" fn runtime_print(
let reconstituted = cstr.to_string_lossy(); let reconstituted = cstr.to_string_lossy();
let output = match vtype_repr.try_into() { let output = match vtype_repr.try_into() {
Ok(ConstantType::Void) => format!("{} = <void>", reconstituted),
Ok(ConstantType::I8) => format!("{} = {}i8", reconstituted, value as i8), Ok(ConstantType::I8) => format!("{} = {}i8", reconstituted, value as i8),
Ok(ConstantType::I16) => format!("{} = {}i16", reconstituted, value as i16), Ok(ConstantType::I16) => format!("{} = {}i16", reconstituted, value as i16),
Ok(ConstantType::I32) => format!("{} = {}i32", reconstituted, value as i32), Ok(ConstantType::I32) => format!("{} = {}i32", reconstituted, value as i32),
@@ -118,7 +132,7 @@ extern "C" fn runtime_print(
Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16), Ok(ConstantType::U16) => format!("{} = {}u16", reconstituted, value as u16),
Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32), Ok(ConstantType::U32) => format!("{} = {}u32", reconstituted, value as u32),
Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64), Ok(ConstantType::U64) => format!("{} = {}u64", reconstituted, value as u64),
Err(_) => format!("{} = {}<unknown type>", reconstituted, value), Err(_) => format!("{} = {}<unknown type {}>", reconstituted, value, vtype_repr),
}; };
if let Some(output_buffer) = unsafe { output_buffer.as_mut() } { if let Some(output_buffer) = unsafe { output_buffer.as_mut() } {

23
src/bin/gen-program.rs Normal file
View File

@@ -0,0 +1,23 @@
use ngr::syntax::ProgramGenerator;
use ngr::util::pretty::Allocator;
use proptest::strategy::{Strategy, ValueTree};
use proptest::test_runner::{Config, TestRunner};
/// Generate a single program with `ProgramGenerator` (driven by a default
/// proptest `TestRunner`) and pretty-print each of its top-level items to
/// standard output at a width of 78 columns.
///
/// Fails if proptest cannot produce a value tree, or if writing to stdout
/// fails.
fn main() -> Result<(), anyhow::Error> {
    let generator = ProgramGenerator::default();
    let mut runner = TestRunner::new(Config::default());

    // proptest reports generation failures with its own reason type; wrap it
    // in an `anyhow` error so it can be returned from `main`.
    let tree = match generator.new_tree(&mut runner) {
        Ok(tree) => tree,
        Err(e) => return Err(anyhow::anyhow!("Couldn't generate test program: {}", e)),
    };

    let allocator = Allocator::new();
    let mut stdout = std::io::stdout();
    for item in tree.current().into_iter() {
        item.pretty(&allocator).render(78, &mut stdout)?;
    }

    Ok(())
}

View File

@@ -14,6 +14,7 @@ struct CommandLineArguments {
} }
fn main() { fn main() {
tracing_subscriber::fmt::init();
let args = CommandLineArguments::parse(); let args = CommandLineArguments::parse();
let mut compiler = ngr::Compiler::default(); let mut compiler = ngr::Compiler::default();

View File

@@ -3,6 +3,7 @@ use rustyline::error::ReadlineError;
use rustyline::DefaultEditor; use rustyline::DefaultEditor;
fn main() -> Result<(), BackendError> { fn main() -> Result<(), BackendError> {
tracing_subscriber::fmt::init();
let mut editor = DefaultEditor::new().expect("rustyline works"); let mut editor = DefaultEditor::new().expect("rustyline works");
let mut line_no = 0; let mut line_no = 0;
let mut state = ngr::REPL::default(); let mut state = ngr::REPL::default();
@@ -19,7 +20,7 @@ fn main() -> Result<(), BackendError> {
// it's not clear to me what this could be, but OK // it's not clear to me what this could be, but OK
Err(ReadlineError::Io(e)) => { Err(ReadlineError::Io(e)) => {
eprintln!("IO error: {}", e); tracing::error!(error = %e, "IO error");
break; break;
} }
@@ -31,7 +32,7 @@ fn main() -> Result<(), BackendError> {
// what would cause this, but ... // what would cause this, but ...
#[cfg(not(windows))] #[cfg(not(windows))]
Err(ReadlineError::Errno(e)) => { Err(ReadlineError::Errno(e)) => {
eprintln!("Unknown syscall error: {}", e); tracing::error!(error = %e, "Unknown syscall.");
break; break;
} }
@@ -41,7 +42,7 @@ fn main() -> Result<(), BackendError> {
// Why on earth are there so many error types? // Why on earth are there so many error types?
Err(e) => { Err(e) => {
eprintln!("Unknown internal error: {}", e); tracing::error!(error = %e, "Unknown internal error");
break; break;
} }
} }

117
src/bin/ngrun.rs Normal file
View File

@@ -0,0 +1,117 @@
use clap::Parser;
use codespan_reporting::files::SimpleFiles;
use ngr::backend::Backend;
use ngr::eval::Value;
use ngr::syntax::{self, Location, Name};
use ngr::type_infer::TypeInferenceResult;
use pretty::termcolor::StandardStream;
use tracing_subscriber::prelude::*;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct CommandLineArguments {
/// Which interpreter to use: syntax, ir, or jit
#[arg(value_enum)]
interpreter: Interpreter,
/// The file to parse
file: String,
}
// Selects which execution strategy `main` dispatches to. The `///` comments on
// the variants are left untouched because clap's derive reads doc comments.
// The variants `IR` and `JIT` are spelled as acronyms, hence the lint allowance.
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
enum Interpreter {
/// Run the syntax-level interpreter
Syntax,
/// Run the IR-level interpreter
IR,
/// Run the JIT backend
JIT,
}
/// Print an evaluation result: first the string half of the pair (presumably
/// the output the program produced -- confirm against the evaluators), then
/// the value half on a line prefixed with `RESULT: `.
fn print_result<E>(result: (Value<E>, String)) {
    let (value, output) = result;
    println!("{}", output);
    println!("RESULT: {}", value);
}
/// Compile `ir` with the JIT backend under the entry-point name `gogogo`,
/// finalize the module's definitions, and return the compiled entry point as
/// a callable function pointer.
///
/// Any error from creating the backend, compiling the program, or finalizing
/// the definitions is returned to the caller.
fn jit(ir: ngr::ir::Program<ngr::ir::Type>) -> Result<fn(), ngr::backend::BackendError> {
    let mut backend = Backend::jit(None)?;

    let entry_name = Name::new("gogogo", Location::manufactured());
    let function_id = backend.compile_program(entry_name, ir)?;
    backend.module.finalize_definitions()?;

    let entry_point = backend.bytes(function_id);
    // SAFETY: NOTE(review) -- this relies on `backend.bytes` returning the
    // address of the finalized code for `function_id`, on that code taking no
    // arguments and returning nothing, and on its calling convention being
    // compatible with a plain Rust `fn()`. None of that is visible from this
    // file; confirm against the backend before relying on it.
    let compiled = unsafe { std::mem::transmute::<_, fn()>(entry_point) };
    Ok(compiled)
}
/// Entry point: parse and validate the file named on the command line, then
/// run it with the interpreter the user selected.
///
/// Diagnostics are rendered to stdout through `codespan_reporting`; failures
/// while rendering are deliberately ignored. Any parse, validation, or type
/// inference failure reports what it found and returns early.
fn main() {
    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer())
        .with(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    let cli = CommandLineArguments::parse();
    let mut file_database = SimpleFiles::new();
    let mut console = StandardStream::stdout(pretty::termcolor::ColorChoice::Auto);
    let console_options = codespan_reporting::term::Config::default();

    // Parsing needs the file database mutably, so it has to happen before the
    // `emit` closure below takes its shared borrow of the same database.
    let parse_outcome = syntax::parse_file(&mut file_database, cli.file.as_ref());
    let mut emit = |diagnostic| {
        let _ = codespan_reporting::term::emit(
            &mut console,
            &console_options,
            &file_database,
            &diagnostic,
        );
    };

    let raw_syntax = match parse_outcome {
        Err(e) => {
            emit((&e).into());
            return;
        }
        Ok(parsed) => parsed,
    };

    // Report everything validation found (warnings and errors alike) before
    // deciding whether we are allowed to continue.
    let mut validation_result = syntax::Program::validate(raw_syntax);
    for item in validation_result.diagnostics() {
        emit(item);
    }
    if validation_result.is_err() {
        return;
    }
    let program = validation_result
        .into_result()
        .expect("we already checked this");

    // The syntax-level interpreter runs directly on the validated AST.
    if cli.interpreter == Interpreter::Syntax {
        match program.eval() {
            Ok(v) => print_result(v),
            Err(e) => tracing::error!(error = %e, "Evaluation error"),
        }
        return;
    }

    // Both remaining interpreters need the type-inferred IR.
    let ir = match program.type_infer() {
        TypeInferenceResult::Failure { errors, warnings } => {
            for warning in warnings {
                emit(warning.into());
            }
            for error in errors {
                emit(error.into());
            }
            return;
        }
        TypeInferenceResult::Success { result, warnings } => {
            for warning in warnings {
                emit(warning.into());
            }
            result
        }
    };

    if cli.interpreter == Interpreter::IR {
        let evaluator = ngr::ir::Evaluator::default();
        match evaluator.eval(ir) {
            Ok(v) => print_result(v),
            Err(e) => tracing::error!(error = %e, "Evaluation error"),
        }
    } else {
        // Only the JIT backend is left at this point.
        match jit(ir) {
            Ok(compiled_function) => compiled_function(),
            Err(e) => emit(e.into()),
        }
    }
}

View File

@@ -1,10 +1,12 @@
use crate::syntax::Program as Syntax; use crate::backend::Backend;
use crate::{backend::Backend, type_infer::TypeInferenceResult}; use crate::syntax::{Location, Name, Program as Syntax};
use crate::type_infer::TypeInferenceResult;
use codespan_reporting::{ use codespan_reporting::{
diagnostic::Diagnostic, diagnostic::Diagnostic,
files::SimpleFiles, files::SimpleFiles,
term::{self, Config}, term::{self, Config},
}; };
use cranelift_module::Module;
use pretty::termcolor::{ColorChoice, StandardStream}; use pretty::termcolor::{ColorChoice, StandardStream};
use target_lexicon::Triple; use target_lexicon::Triple;
@@ -75,29 +77,23 @@ impl Compiler {
fn compile_internal(&mut self, input_file: &str) -> Result<Option<Vec<u8>>, CompilerError> { fn compile_internal(&mut self, input_file: &str) -> Result<Option<Vec<u8>>, CompilerError> {
// Try to parse the file into our syntax AST. If we fail, emit the error // Try to parse the file into our syntax AST. If we fail, emit the error
// and then immediately return `None`. // and then immediately return `None`.
let syntax = Syntax::parse_file(&mut self.file_database, input_file)?; let raw_syntax = crate::syntax::parse_file(&mut self.file_database, input_file)?;
// Now validate the user's syntax AST. This can possibly find errors and/or // Now validate the user's syntax AST. This can possibly find errors and/or
// create warnings. We can continue if we only get warnings, but need to stop // create warnings. We can continue if we only get warnings, but need to stop
// if we get any errors. // if we get any errors.
let (mut errors, mut warnings) = syntax.validate(); let mut validation_result = Syntax::validate(raw_syntax);
let stop = !errors.is_empty();
let messages = errors
.drain(..)
.map(Into::into)
.chain(warnings.drain(..).map(Into::into));
// emit all the messages we receive; warnings *and* errors // emit all the messages we receive; warnings *and* errors
for message in messages { for message in validation_result.diagnostics() {
self.emit(message); self.emit(message);
} }
// we got errors, so just stop right now. perhaps oddly, this is Ok(None); // we got errors, so just stop right now. perhaps oddly, this is Ok(None);
// we've already said all we're going to say in the messags above, so there's // we've already said all we're going to say in the messags above, so there's
// no need to provide another `Err` result. // no need to provide another `Err` result.
if stop { let Some(syntax) = validation_result.into_result() else {
return Ok(None); return Ok(None);
} };
// Now that we've validated it, let's do type inference, potentially turning // Now that we've validated it, let's do type inference, potentially turning
// into IR while we're at it. // into IR while we're at it.
@@ -134,7 +130,15 @@ impl Compiler {
// Finally, send all this to Cranelift for conversion into an object file. // Finally, send all this to Cranelift for conversion into an object file.
let mut backend = Backend::object_file(Triple::host())?; let mut backend = Backend::object_file(Triple::host())?;
backend.compile_function("gogogo", ir)?; let unknown = "<unknown>".to_string();
backend.compile_program(Name::new("gogogo", Location::manufactured()), ir)?;
for (_, decl) in backend.module.declarations().get_functions() {
tracing::debug!(name = %decl.name.as_ref().unwrap_or(&unknown), linkage = ?decl.linkage, "function definition");
}
for (_, decl) in backend.module.declarations().get_data_objects() {
tracing::debug!(name = %decl.name.as_ref().unwrap_or(&unknown), linkage = ?decl.linkage, "data definition");
}
Ok(Some(backend.bytes()?)) Ok(Some(backend.bytes()?))
} }

View File

@@ -33,19 +33,21 @@
//! because the implementation of some parts of these primitives is really //! because the implementation of some parts of these primitives is really
//! awful to look at. //! awful to look at.
//! //!
mod env; //mod env;
mod primop; mod primop;
mod primtype; mod primtype;
mod value; mod value;
use crate::syntax::Name;
use cranelift_module::ModuleError; use cranelift_module::ModuleError;
pub use env::{EvalEnvironment, LookupError};
pub use primop::PrimOpError; pub use primop::PrimOpError;
pub use primtype::PrimitiveType; pub use primtype::PrimitiveType;
pub use value::Value; pub use value::Value;
use crate::backend::BackendError; use crate::backend::BackendError;
use self::primtype::UnknownPrimType;
/// All of the errors that can happen trying to evaluate an NGR program. /// All of the errors that can happen trying to evaluate an NGR program.
/// ///
/// This is yet another standard [`thiserror::Error`] type, but with the /// This is yet another standard [`thiserror::Error`] type, but with the
@@ -54,11 +56,9 @@ use crate::backend::BackendError;
/// of converting those errors to strings and then seeing if they're the /// of converting those errors to strings and then seeing if they're the
/// same. /// same.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum EvalError { pub enum EvalError<IR> {
#[error(transparent)] #[error(transparent)]
Lookup(#[from] LookupError), PrimOp(#[from] PrimOpError<IR>),
#[error(transparent)]
PrimOp(#[from] PrimOpError),
#[error(transparent)] #[error(transparent)]
Backend(#[from] BackendError), Backend(#[from] BackendError),
#[error("IO error: {0}")] #[error("IO error: {0}")]
@@ -71,16 +71,29 @@ pub enum EvalError {
ExitCode(std::process::ExitStatus), ExitCode(std::process::ExitStatus),
#[error("Unexpected output at runtime: {0}")] #[error("Unexpected output at runtime: {0}")]
RuntimeOutput(String), RuntimeOutput(String),
#[error("Cannot cast to function type: {0}")]
CastToFunction(String),
#[error(transparent)]
UnknownPrimType(#[from] UnknownPrimType),
#[error("Variable lookup failed for {1} at {0:?}")]
LookupFailed(crate::syntax::Location, String),
#[error("Attempted to call something that wasn't a function at {0:?} (it was a {1})")]
NotAFunction(crate::syntax::Location, Value<IR>),
#[error("Wrong argument call for function ({1:?}) at {0:?}; expected {2}, saw {3}")]
WrongArgCount(crate::syntax::Location, Option<Name>, usize, usize),
#[error("Value has no fields {1} (attempt to get field {2} at {0:?})")]
NoFieldForValue(crate::syntax::Location, Value<IR>, Name),
#[error("Bad field {2} for structure {1:?} at {0:?}")]
BadFieldForStructure(crate::syntax::Location, Option<Name>, Name),
} }
impl PartialEq for EvalError { impl<IR1: Clone, IR2: Clone> PartialEq<EvalError<IR1>> for EvalError<IR2> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &EvalError<IR1>) -> bool {
match self { match self {
EvalError::Lookup(a) => match other { EvalError::LookupFailed(a, b) => match other {
EvalError::Lookup(b) => a == b, EvalError::LookupFailed(x, y) => a == x && b == y,
_ => false, _ => false,
}, },
EvalError::PrimOp(a) => match other { EvalError::PrimOp(a) => match other {
EvalError::PrimOp(b) => a == b, EvalError::PrimOp(b) => a == b,
_ => false, _ => false,
@@ -115,6 +128,36 @@ impl PartialEq for EvalError {
EvalError::RuntimeOutput(b) => a == b, EvalError::RuntimeOutput(b) => a == b,
_ => false, _ => false,
}, },
EvalError::CastToFunction(a) => match other {
EvalError::CastToFunction(b) => a == b,
_ => false,
},
EvalError::UnknownPrimType(a) => match other {
EvalError::UnknownPrimType(b) => a == b,
_ => false,
},
EvalError::NotAFunction(a, b) => match other {
EvalError::NotAFunction(x, y) => a == x && b == y,
_ => false,
},
EvalError::WrongArgCount(a, b, c, d) => match other {
EvalError::WrongArgCount(w, x, y, z) => a == w && b == x && c == y && d == z,
_ => false,
},
EvalError::NoFieldForValue(a, b, c) => match other {
EvalError::NoFieldForValue(x, y, z) => a == x && b == y && c == z,
_ => false,
},
EvalError::BadFieldForStructure(a, b, c) => match other {
EvalError::BadFieldForStructure(x, y, z) => a == x && b == y && c == z,
_ => false,
},
} }
} }
} }

View File

@@ -1,20 +1,22 @@
use crate::eval::primtype::PrimitiveType; use crate::eval::primtype::PrimitiveType;
use crate::eval::value::Value; use crate::eval::value::Value;
use super::primtype::{UnknownPrimType, ValuePrimitiveTypeError};
/// Errors that can occur running primitive operations in the evaluators. /// Errors that can occur running primitive operations in the evaluators.
#[derive(Clone, Debug, PartialEq, thiserror::Error)] #[derive(Clone, Debug, thiserror::Error)]
pub enum PrimOpError { pub enum PrimOpError<IR> {
#[error("Math error (underflow or overflow) computing {0} operator")] #[error("Math error (underflow or overflow) computing {0} operator")]
MathFailure(&'static str), MathFailure(&'static str),
/// This particular variant covers the case in which a primitive /// This particular variant covers the case in which a primitive
/// operator takes two arguments that are supposed to be the same, /// operator takes two arguments that are supposed to be the same,
/// but they differ. (So, like, all the math operators.) /// but they differ. (So, like, all the math operators.)
#[error("Type mismatch ({1} vs {2}) computing {0} operator")] #[error("Type mismatch ({1} vs {2}) computing {0} operator")]
TypeMismatch(String, Value, Value), TypeMismatch(String, Value<IR>, Value<IR>),
/// This variant covers when an operator must take a particular /// This variant covers when an operator must take a particular
/// type, but the user has provided a different one. /// type, but the user has provided a different one.
#[error("Bad type for operator {0}: {1}")] #[error("Bad type for operator {0}: {1}")]
BadTypeFor(&'static str, Value), BadTypeFor(String, Value<IR>),
/// Probably obvious from the name, but just to be very clear: this /// Probably obvious from the name, but just to be very clear: this
/// happens when you pass three arguments to a two argument operator, /// happens when you pass three arguments to a two argument operator,
/// etc. Technically that's a type error of some sort, but we split /// etc. Technically that's a type error of some sort, but we split
@@ -28,8 +30,35 @@ pub enum PrimOpError {
from: PrimitiveType, from: PrimitiveType,
to: PrimitiveType, to: PrimitiveType,
}, },
#[error("Unknown primitive type {0}")] #[error(transparent)]
UnknownPrimType(String), UnknownPrimType(#[from] UnknownPrimType),
#[error(transparent)]
ValuePrimitiveTypeError(#[from] ValuePrimitiveTypeError),
}
impl<IR1: Clone, IR2: Clone> PartialEq<PrimOpError<IR2>> for PrimOpError<IR1> {
fn eq(&self, other: &PrimOpError<IR2>) -> bool {
match (self, other) {
(PrimOpError::MathFailure(a), PrimOpError::MathFailure(b)) => a == b,
(PrimOpError::TypeMismatch(a, b, c), PrimOpError::TypeMismatch(x, y, z)) => {
a == x && b.strip() == y.strip() && c.strip() == z.strip()
}
(PrimOpError::BadTypeFor(a, b), PrimOpError::BadTypeFor(x, y)) => {
a == x && b.strip() == y.strip()
}
(PrimOpError::BadArgCount(a, b), PrimOpError::BadArgCount(x, y)) => a == x && b == y,
(PrimOpError::UnknownPrimOp(a), PrimOpError::UnknownPrimOp(x)) => a == x,
(
PrimOpError::UnsafeCast { from: a, to: b },
PrimOpError::UnsafeCast { from: x, to: y },
) => a == x && b == y,
(PrimOpError::UnknownPrimType(a), PrimOpError::UnknownPrimType(x)) => a == x,
(PrimOpError::ValuePrimitiveTypeError(a), PrimOpError::ValuePrimitiveTypeError(x)) => {
a == x
}
_ => false,
}
}
} }
// Implementing primitives in an interpreter like this is *super* tedious, // Implementing primitives in an interpreter like this is *super* tedious,
@@ -55,24 +84,29 @@ macro_rules! run_op {
}; };
} }
impl Value { impl<IR: Clone> Value<IR> {
fn unary_op(operation: &str, value: &Value) -> Result<Value, PrimOpError> { fn unary_op(operation: &str, value: &Value<IR>) -> Result<Value<IR>, PrimOpError<IR>> {
match operation { match operation {
"-" => match value { "negate" => match value {
Value::I8(x) => Ok(Value::I8(x.wrapping_neg())), Value::I8(x) => Ok(Value::I8(x.wrapping_neg())),
Value::I16(x) => Ok(Value::I16(x.wrapping_neg())), Value::I16(x) => Ok(Value::I16(x.wrapping_neg())),
Value::I32(x) => Ok(Value::I32(x.wrapping_neg())), Value::I32(x) => Ok(Value::I32(x.wrapping_neg())),
Value::I64(x) => Ok(Value::I64(x.wrapping_neg())), Value::I64(x) => Ok(Value::I64(x.wrapping_neg())),
_ => Err(PrimOpError::BadTypeFor("-", value.clone())), _ => Err(PrimOpError::BadTypeFor("-".to_string(), value.clone())),
}, },
_ => Err(PrimOpError::BadArgCount(operation.to_owned(), 1)), _ => Err(PrimOpError::BadArgCount(operation.to_owned(), 1)),
} }
} }
fn binary_op(operation: &str, left: &Value, right: &Value) -> Result<Value, PrimOpError> { fn binary_op(
operation: &str,
left: &Value<IR>,
right: &Value<IR>,
) -> Result<Value<IR>, PrimOpError<IR>> {
match left { match left {
Value::I8(x) => match right { Value::I8(x) => match right {
Value::I8(y) => run_op!(operation, x, *y), Value::I8(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y as i8),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
@@ -81,6 +115,7 @@ impl Value {
}, },
Value::I16(x) => match right { Value::I16(x) => match right {
Value::I16(y) => run_op!(operation, x, *y), Value::I16(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y as i16),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
@@ -89,6 +124,7 @@ impl Value {
}, },
Value::I32(x) => match right { Value::I32(x) => match right {
Value::I32(y) => run_op!(operation, x, *y), Value::I32(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y as i32),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
@@ -97,6 +133,7 @@ impl Value {
}, },
Value::I64(x) => match right { Value::I64(x) => match right {
Value::I64(y) => run_op!(operation, x, *y), Value::I64(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y as i64),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
@@ -105,6 +142,7 @@ impl Value {
}, },
Value::U8(x) => match right { Value::U8(x) => match right {
Value::U8(y) => run_op!(operation, x, *y), Value::U8(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y as u8),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
@@ -113,6 +151,7 @@ impl Value {
}, },
Value::U16(x) => match right { Value::U16(x) => match right {
Value::U16(y) => run_op!(operation, x, *y), Value::U16(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y as u16),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
@@ -121,6 +160,7 @@ impl Value {
}, },
Value::U32(x) => match right { Value::U32(x) => match right {
Value::U32(y) => run_op!(operation, x, *y), Value::U32(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y as u32),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
@@ -129,12 +169,33 @@ impl Value {
}, },
Value::U64(x) => match right { Value::U64(x) => match right {
Value::U64(y) => run_op!(operation, x, *y), Value::U64(y) => run_op!(operation, x, *y),
Value::Number(y) => run_op!(operation, x, *y),
_ => Err(PrimOpError::TypeMismatch( _ => Err(PrimOpError::TypeMismatch(
operation.to_string(), operation.to_string(),
left.clone(), left.clone(),
right.clone(), right.clone(),
)), )),
}, },
Value::Number(x) => match right {
Value::Number(y) => run_op!(operation, x, *y),
Value::U8(y) => run_op!(operation, (*x as u8), *y),
Value::U16(y) => run_op!(operation, (*x as u16), *y),
Value::U32(y) => run_op!(operation, (*x as u32), *y),
Value::U64(y) => run_op!(operation, x, *y),
Value::I8(y) => run_op!(operation, (*x as i8), *y),
Value::I16(y) => run_op!(operation, (*x as i16), *y),
Value::I32(y) => run_op!(operation, (*x as i32), *y),
Value::I64(y) => run_op!(operation, (*x as i64), *y),
_ => Err(PrimOpError::TypeMismatch(
operation.to_string(),
left.clone(),
right.clone(),
)),
},
Value::Closure(_, _, _, _)
| Value::Structure(_, _)
| Value::Primitive(_)
| Value::Void => Err(PrimOpError::BadTypeFor(operation.to_string(), left.clone())),
} }
} }
@@ -146,7 +207,10 @@ impl Value {
/// implementation catches and raises an error on overflow or underflow, so /// implementation catches and raises an error on overflow or underflow, so
/// its worth being careful to make sure that your inputs won't cause either /// its worth being careful to make sure that your inputs won't cause either
/// condition. /// condition.
pub fn calculate(operation: &str, values: Vec<Value>) -> Result<Value, PrimOpError> { pub fn calculate(
operation: &str,
values: Vec<Value<IR>>,
) -> Result<Value<IR>, PrimOpError<IR>> {
match values.len() { match values.len() {
1 => Value::unary_op(operation, &values[0]), 1 => Value::unary_op(operation, &values[0]),
2 => Value::binary_op(operation, &values[0], &values[1]), 2 => Value::binary_op(operation, &values[0], &values[1]),

View File

@@ -2,10 +2,12 @@ use crate::{
eval::{PrimOpError, Value}, eval::{PrimOpError, Value},
syntax::ConstantType, syntax::ConstantType,
}; };
use pretty::{Arena, DocAllocator, DocBuilder};
use std::{fmt::Display, str::FromStr}; use std::{fmt::Display, str::FromStr};
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PrimitiveType { pub enum PrimitiveType {
Void,
U8, U8,
U16, U16,
U32, U32,
@@ -16,32 +18,68 @@ pub enum PrimitiveType {
I64, I64,
} }
impl Display for PrimitiveType { impl PrimitiveType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { pub fn pretty<'a>(&self, allocator: &'a Arena<'a, ()>) -> DocBuilder<'a, Arena<'a, ()>> {
match self { match self {
PrimitiveType::I8 => write!(f, "i8"), PrimitiveType::Void => allocator.text("void"),
PrimitiveType::I16 => write!(f, "i16"), PrimitiveType::I8 => allocator.text("i8"),
PrimitiveType::I32 => write!(f, "i32"), PrimitiveType::I16 => allocator.text("i16"),
PrimitiveType::I64 => write!(f, "i64"), PrimitiveType::I32 => allocator.text("i32"),
PrimitiveType::U8 => write!(f, "u8"), PrimitiveType::I64 => allocator.text("i64"),
PrimitiveType::U16 => write!(f, "u16"), PrimitiveType::U8 => allocator.text("u8"),
PrimitiveType::U32 => write!(f, "u32"), PrimitiveType::U16 => allocator.text("u16"),
PrimitiveType::U64 => write!(f, "u64"), PrimitiveType::U32 => allocator.text("u32"),
PrimitiveType::U64 => allocator.text("u64"),
} }
} }
} }
impl<'a> From<&'a Value> for PrimitiveType { impl Display for PrimitiveType {
fn from(value: &Value) -> Self { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let arena = Arena::new();
let doc = self.pretty(&arena);
doc.render_fmt(72, f)
}
}
#[allow(clippy::enum_variant_names)]
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
pub enum ValuePrimitiveTypeError {
#[error("Could not convert function value to primitive type (possible function name: {0:?}")]
CannotConvertFunction(Option<String>),
#[error("Could not convert structure value to primitive type (possible function name: {0:?}")]
CannotConvertStructure(Option<String>),
#[error(
"Could not convert primitive operator to primitive type (possible function name: {0:?}"
)]
CannotConvertPrimitive(String),
}
impl<'a, IR> TryFrom<&'a Value<IR>> for PrimitiveType {
type Error = ValuePrimitiveTypeError;
fn try_from(value: &'a Value<IR>) -> Result<Self, Self::Error> {
match value { match value {
Value::I8(_) => PrimitiveType::I8, Value::Void => Ok(PrimitiveType::Void),
Value::I16(_) => PrimitiveType::I16, Value::I8(_) => Ok(PrimitiveType::I8),
Value::I32(_) => PrimitiveType::I32, Value::I16(_) => Ok(PrimitiveType::I16),
Value::I64(_) => PrimitiveType::I64, Value::I32(_) => Ok(PrimitiveType::I32),
Value::U8(_) => PrimitiveType::U8, Value::I64(_) => Ok(PrimitiveType::I64),
Value::U16(_) => PrimitiveType::U16, Value::U8(_) => Ok(PrimitiveType::U8),
Value::U32(_) => PrimitiveType::U32, Value::U16(_) => Ok(PrimitiveType::U16),
Value::U64(_) => PrimitiveType::U64, Value::U32(_) => Ok(PrimitiveType::U32),
Value::U64(_) => Ok(PrimitiveType::U64),
// not sure this is the right call
Value::Number(_) => Ok(PrimitiveType::U64),
Value::Closure(name, _, _, _) => Err(ValuePrimitiveTypeError::CannotConvertFunction(
name.as_ref().map(|x| x.current_name().to_string()),
)),
Value::Structure(name, _) => Err(ValuePrimitiveTypeError::CannotConvertStructure(
name.as_ref().map(|x| x.current_name().to_string()),
)),
Value::Primitive(prim) => Err(ValuePrimitiveTypeError::CannotConvertPrimitive(
prim.clone(),
)),
} }
} }
} }
@@ -49,6 +87,7 @@ impl<'a> From<&'a Value> for PrimitiveType {
impl From<ConstantType> for PrimitiveType { impl From<ConstantType> for PrimitiveType {
fn from(value: ConstantType) -> Self { fn from(value: ConstantType) -> Self {
match value { match value {
ConstantType::Void => PrimitiveType::Void,
ConstantType::I8 => PrimitiveType::I8, ConstantType::I8 => PrimitiveType::I8,
ConstantType::I16 => PrimitiveType::I16, ConstantType::I16 => PrimitiveType::I16,
ConstantType::I32 => PrimitiveType::I32, ConstantType::I32 => PrimitiveType::I32,
@@ -61,8 +100,30 @@ impl From<ConstantType> for PrimitiveType {
} }
} }
/// Every primitive type has a constant type of the same name; this is the
/// one-to-one mapping from the former to the latter.
impl From<PrimitiveType> for ConstantType {
    fn from(value: PrimitiveType) -> Self {
        match value {
            PrimitiveType::Void => Self::Void,
            // unsigned integers
            PrimitiveType::U8 => Self::U8,
            PrimitiveType::U16 => Self::U16,
            PrimitiveType::U32 => Self::U32,
            PrimitiveType::U64 => Self::U64,
            // signed integers
            PrimitiveType::I8 => Self::I8,
            PrimitiveType::I16 => Self::I16,
            PrimitiveType::I32 => Self::I32,
            PrimitiveType::I64 => Self::I64,
        }
    }
}
#[derive(thiserror::Error, Debug, Clone, PartialEq)]
pub enum UnknownPrimType {
#[error("Could not convert '{0}' into a primitive type")]
UnknownPrimType(String),
}
impl FromStr for PrimitiveType { impl FromStr for PrimitiveType {
type Err = PrimOpError; type Err = UnknownPrimType;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
match s { match s {
@@ -74,49 +135,50 @@ impl FromStr for PrimitiveType {
"u16" => Ok(PrimitiveType::U16), "u16" => Ok(PrimitiveType::U16),
"u32" => Ok(PrimitiveType::U32), "u32" => Ok(PrimitiveType::U32),
"u64" => Ok(PrimitiveType::U64), "u64" => Ok(PrimitiveType::U64),
_ => Err(PrimOpError::UnknownPrimType(s.to_string())), "void" => Ok(PrimitiveType::Void),
_ => Err(UnknownPrimType::UnknownPrimType(s.to_owned())),
} }
} }
} }
impl PrimitiveType { impl PrimitiveType {
/// Return the set of types that this type can be safely cast to
pub fn allowed_casts(&self) -> &'static [PrimitiveType] {
match self {
PrimitiveType::Void => &[PrimitiveType::Void],
PrimitiveType::U8 => &[
PrimitiveType::U8,
PrimitiveType::U16,
PrimitiveType::U32,
PrimitiveType::U64,
PrimitiveType::I16,
PrimitiveType::I32,
PrimitiveType::I64,
],
PrimitiveType::U16 => &[
PrimitiveType::U16,
PrimitiveType::U32,
PrimitiveType::U64,
PrimitiveType::I32,
PrimitiveType::I64,
],
PrimitiveType::U32 => &[PrimitiveType::U32, PrimitiveType::U64, PrimitiveType::I64],
PrimitiveType::U64 => &[PrimitiveType::U64],
PrimitiveType::I8 => &[
PrimitiveType::I8,
PrimitiveType::I16,
PrimitiveType::I32,
PrimitiveType::I64,
],
PrimitiveType::I16 => &[PrimitiveType::I16, PrimitiveType::I32, PrimitiveType::I64],
PrimitiveType::I32 => &[PrimitiveType::I32, PrimitiveType::I64],
PrimitiveType::I64 => &[PrimitiveType::I64],
}
}
/// Return true if this type can be safely cast into the target type. /// Return true if this type can be safely cast into the target type.
pub fn can_cast_to(&self, target: &PrimitiveType) -> bool { pub fn can_cast_to(&self, target: &PrimitiveType) -> bool {
match self { self.allowed_casts().contains(target)
PrimitiveType::U8 => matches!(
target,
PrimitiveType::U8
| PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I16
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U16 => matches!(
target,
PrimitiveType::U16
| PrimitiveType::U32
| PrimitiveType::U64
| PrimitiveType::I32
| PrimitiveType::I64
),
PrimitiveType::U32 => matches!(
target,
PrimitiveType::U32 | PrimitiveType::U64 | PrimitiveType::I64
),
PrimitiveType::U64 => target == &PrimitiveType::U64,
PrimitiveType::I8 => matches!(
target,
PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I16 => matches!(
target,
PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
),
PrimitiveType::I32 => matches!(target, PrimitiveType::I32 | PrimitiveType::I64),
PrimitiveType::I64 => target == &PrimitiveType::I64,
}
} }
/// Try to cast the given value to this type, returning the new value. /// Try to cast the given value to this type, returning the new value.
@@ -127,7 +189,7 @@ impl PrimitiveType {
/// type to the target type. (So, for example, "1i64" is a number that could /// type to the target type. (So, for example, "1i64" is a number that could
/// work as a "u64", but since negative numbers wouldn't work, a cast from /// work as a "u64", but since negative numbers wouldn't work, a cast from
/// "1i64" to "u64" will fail.) /// "1i64" to "u64" will fail.)
pub fn safe_cast(&self, source: &Value) -> Result<Value, PrimOpError> { pub fn safe_cast<IR>(&self, source: &Value<IR>) -> Result<Value<IR>, PrimOpError<IR>> {
match (self, source) { match (self, source) {
(PrimitiveType::U8, Value::U8(x)) => Ok(Value::U8(*x)), (PrimitiveType::U8, Value::U8(x)) => Ok(Value::U8(*x)),
(PrimitiveType::U16, Value::U8(x)) => Ok(Value::U16(*x as u16)), (PrimitiveType::U16, Value::U8(x)) => Ok(Value::U16(*x as u16)),
@@ -151,23 +213,45 @@ impl PrimitiveType {
(PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)), (PrimitiveType::I64, Value::I32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)), (PrimitiveType::I64, Value::I64(x)) => Ok(Value::I64(*x)),
(PrimitiveType::I16, Value::U8(x)) => Ok(Value::I16(*x as i16)),
(PrimitiveType::I32, Value::U8(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I64, Value::U8(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I32, Value::U16(x)) => Ok(Value::I32(*x as i32)),
(PrimitiveType::I64, Value::U16(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::I64, Value::U32(x)) => Ok(Value::I64(*x as i64)),
(PrimitiveType::Void, Value::Void) => Ok(Value::Void),
_ => Err(PrimOpError::UnsafeCast { _ => Err(PrimOpError::UnsafeCast {
from: source.into(), from: PrimitiveType::try_from(source)?,
to: *self, to: *self,
}), }),
} }
} }
pub fn max_value(&self) -> u64 { pub fn max_value(&self) -> Option<u64> {
match self { match self {
PrimitiveType::U8 => u8::MAX as u64, PrimitiveType::Void => None,
PrimitiveType::U16 => u16::MAX as u64, PrimitiveType::U8 => Some(u8::MAX as u64),
PrimitiveType::U32 => u32::MAX as u64, PrimitiveType::U16 => Some(u16::MAX as u64),
PrimitiveType::U64 => u64::MAX, PrimitiveType::U32 => Some(u32::MAX as u64),
PrimitiveType::I8 => i8::MAX as u64, PrimitiveType::U64 => Some(u64::MAX),
PrimitiveType::I16 => i16::MAX as u64, PrimitiveType::I8 => Some(i8::MAX as u64),
PrimitiveType::I32 => i32::MAX as u64, PrimitiveType::I16 => Some(i16::MAX as u64),
PrimitiveType::I64 => i64::MAX as u64, PrimitiveType::I32 => Some(i32::MAX as u64),
PrimitiveType::I64 => Some(i64::MAX as u64),
}
}
pub fn valid_operators(&self) -> &'static [(&'static str, usize)] {
match self {
PrimitiveType::Void => &[],
PrimitiveType::U8 | PrimitiveType::U16 | PrimitiveType::U32 | PrimitiveType::U64 => {
&[("+", 2), ("-", 2), ("*", 2), ("/", 2)]
}
PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64 => {
&[("+", 2), ("negate", 1), ("-", 2), ("*", 2), ("/", 2)]
}
} }
} }
} }

View File

@@ -1,12 +1,16 @@
use std::fmt::Display; use crate::syntax::Name;
use crate::util::scoped_map::ScopedMap;
use std::collections::HashMap;
use std::fmt;
/// Values in the interpreter. /// Values in the interpreter.
/// ///
/// Yes, this is yet another definition of a structure called `Value`, which /// Yes, this is yet another definition of a structure called `Value`, which
/// are almost entirely identical. However, it's nice to have them separated /// are almost entirely identical. However, it's nice to have them separated
/// by type so that we don't mix them up. /// by type so that we don't mix them up.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone)]
pub enum Value { pub enum Value<IR> {
Void,
I8(i8), I8(i8),
I16(i16), I16(i16),
I32(i32), I32(i32),
@@ -15,11 +19,51 @@ pub enum Value {
U16(u16), U16(u16),
U32(u32), U32(u32),
U64(u64), U64(u64),
// a number of unknown type
Number(u64),
Closure(Option<Name>, ScopedMap<Name, Value<IR>>, Vec<Name>, IR),
Structure(Option<Name>, HashMap<Name, Value<IR>>),
Primitive(String),
} }
impl Display for Value { impl<IR: Clone> Value<IR> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { /// Given a Value associated with some expression type, just strip out
/// the expressions and replace them with unit.
///
/// Doing this transformation will likely make this value useless for
/// computation, but is very useful in allowing equivalence checks.
pub fn strip(&self) -> Value<()> {
match self { match self {
Value::Void => Value::Void,
Value::U8(x) => Value::U8(*x),
Value::U16(x) => Value::U16(*x),
Value::U32(x) => Value::U32(*x),
Value::U64(x) => Value::U64(*x),
Value::I8(x) => Value::I8(*x),
Value::I16(x) => Value::I16(*x),
Value::I32(x) => Value::I32(*x),
Value::I64(x) => Value::I64(*x),
Value::Number(x) => Value::Number(*x),
Value::Closure(name, env, args, _) => {
let new_env = env.clone().map_values(|x| x.strip());
Value::Closure(name.clone(), new_env, args.clone(), ())
}
Value::Structure(name, fields) => Value::Structure(
name.clone(),
fields.iter().map(|(n, v)| (n.clone(), v.strip())).collect(),
),
Value::Primitive(name) => Value::Primitive(name.clone()),
}
}
pub fn primitive<S: ToString>(name: S) -> Self {
Value::Primitive(name.to_string())
}
}
fn format_value<IR>(value: &Value<IR>, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match value {
Value::Void => write!(f, "<void>"),
Value::I8(x) => write!(f, "{}i8", x), Value::I8(x) => write!(f, "{}i8", x),
Value::I16(x) => write!(f, "{}i16", x), Value::I16(x) => write!(f, "{}i16", x),
Value::I32(x) => write!(f, "{}i32", x), Value::I32(x) => write!(f, "{}i32", x),
@@ -28,53 +72,149 @@ impl Display for Value {
Value::U16(x) => write!(f, "{}u16", x), Value::U16(x) => write!(f, "{}u16", x),
Value::U32(x) => write!(f, "{}u32", x), Value::U32(x) => write!(f, "{}u32", x),
Value::U64(x) => write!(f, "{}u64", x), Value::U64(x) => write!(f, "{}u64", x),
Value::Number(x) => write!(f, "{}", x),
Value::Closure(Some(name), _, _, _) => write!(f, "<function {}>", name),
Value::Closure(None, _, _, _) => write!(f, "<function>"),
Value::Structure(on, fields) => {
if let Some(n) = on {
write!(f, "{}", n.current_name())?;
}
write!(f, "{{")?;
for (n, v) in fields.iter() {
write!(f, " {}: {},", n, v)?;
}
write!(f, " }}")
}
Value::Primitive(n) => write!(f, "{}", n),
}
}
/// Debug rendering for interpreter values.
///
/// The output is produced by the shared `format_value` helper, so no bound
/// on `IR` is required.
impl<IR> fmt::Debug for Value<IR> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        format_value(self, formatter)
    }
}
/// User-facing rendering for interpreter values.
///
/// The output is produced by the shared `format_value` helper, so no bound
/// on `IR` is required.
impl<IR> fmt::Display for Value<IR> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        format_value(self, formatter)
    }
}
impl<IR1, IR2> PartialEq<Value<IR2>> for Value<IR1> {
fn eq(&self, other: &Value<IR2>) -> bool {
match self {
Value::Void => matches!(other, Value::Void),
Value::I8(x) => match other {
Value::I8(y) => x == y,
_ => false,
},
Value::I16(x) => match other {
Value::I16(y) => x == y,
_ => false,
},
Value::I32(x) => match other {
Value::I32(y) => x == y,
_ => false,
},
Value::I64(x) => match other {
Value::I64(y) => x == y,
_ => false,
},
Value::U8(x) => match other {
Value::U8(y) => x == y,
_ => false,
},
Value::U16(x) => match other {
Value::U16(y) => x == y,
_ => false,
},
Value::U32(x) => match other {
Value::U32(y) => x == y,
_ => false,
},
Value::U64(x) => match other {
Value::U64(y) => x == y,
_ => false,
},
Value::Number(x) => match other {
Value::I8(y) => (*x as i8) == *y,
Value::I16(y) => (*x as i16) == *y,
Value::I32(y) => (*x as i32) == *y,
Value::I64(y) => (*x as i64) == *y,
Value::U8(y) => (*x as u8) == *y,
Value::U16(y) => (*x as u16) == *y,
Value::U32(y) => (*x as u32) == *y,
Value::U64(y) => x == y,
Value::Number(y) => x == y,
_ => false,
},
Value::Closure(_, _, _, _) => false,
Value::Structure(on1, fields1) => match other {
Value::Structure(on2, fields2) => {
on1 == on2 && {
let left = fields1.keys().all(|x| fields2.contains_key(x));
let right = fields2.keys().all(|x| fields1.contains_key(x));
left && right
&& fields1
.iter()
.all(|(k, v)| fields2.get(k).map(|v2| v == v2).unwrap_or(false))
}
}
_ => false,
},
Value::Primitive(n1) => match other {
Value::Primitive(n2) => n1 == n2,
_ => false,
},
} }
} }
} }
impl From<i8> for Value { impl<IR> From<i8> for Value<IR> {
fn from(value: i8) -> Self { fn from(value: i8) -> Self {
Value::I8(value) Value::I8(value)
} }
} }
impl From<i16> for Value { impl<IR> From<i16> for Value<IR> {
fn from(value: i16) -> Self { fn from(value: i16) -> Self {
Value::I16(value) Value::I16(value)
} }
} }
impl From<i32> for Value { impl<IR> From<i32> for Value<IR> {
fn from(value: i32) -> Self { fn from(value: i32) -> Self {
Value::I32(value) Value::I32(value)
} }
} }
impl From<i64> for Value { impl<IR> From<i64> for Value<IR> {
fn from(value: i64) -> Self { fn from(value: i64) -> Self {
Value::I64(value) Value::I64(value)
} }
} }
impl From<u8> for Value { impl<IR> From<u8> for Value<IR> {
fn from(value: u8) -> Self { fn from(value: u8) -> Self {
Value::U8(value) Value::U8(value)
} }
} }
impl From<u16> for Value { impl<IR> From<u16> for Value<IR> {
fn from(value: u16) -> Self { fn from(value: u16) -> Self {
Value::U16(value) Value::U16(value)
} }
} }
impl From<u32> for Value { impl<IR> From<u32> for Value<IR> {
fn from(value: u32) -> Self { fn from(value: u32) -> Self {
Value::U32(value) Value::U32(value)
} }
} }
impl From<u64> for Value { impl<IR> From<u64> for Value<IR> {
fn from(value: u64) -> Self { fn from(value: u64) -> Self {
Value::U64(value) Value::U64(value)
} }

View File

@@ -12,8 +12,13 @@
//! validating syntax, and then figuring out how to turn it into Cranelift //! validating syntax, and then figuring out how to turn it into Cranelift
//! and object code. After that point, however, this will be the module to //! and object code. After that point, however, this will be the module to
//! come to for analysis and optimization work. //! come to for analysis and optimization work.
mod arbitrary;
pub mod ast; pub mod ast;
mod eval; mod eval;
mod fields;
mod pretty;
mod strings; mod strings;
mod top_level;
pub use ast::*; pub use ast::*;
pub use eval::Evaluator;

466
src/ir/arbitrary.rs Normal file
View File

@@ -0,0 +1,466 @@
use crate::eval::PrimitiveType;
use crate::ir::{Expression, Name, Primitive, Program, Type, TypeWithVoid, Value, ValueOrRef};
use crate::syntax::Location;
use crate::util::scoped_map::ScopedMap;
use proptest::strategy::{NewTree, Strategy, ValueTree};
use proptest::test_runner::{TestRng, TestRunner};
use rand::distributions::{Distribution, WeightedIndex};
use rand::seq::SliceRandom;
use rand::Rng;
use std::collections::HashMap;
use std::str::FromStr;
// Weighted distributions used by the random IR generator.
//
// For the three length/arity tables below, the sampled *index* is the
// value being chosen (index 3 == "length 3"), and the weights in each table
// sum to 100 so each weight can be read directly as a percentage.
//
// The remaining distributions are built from the `*_FREQUENCIES` tables; a
// sampled index is meant to be used to index back into the matching table.
lazy_static::lazy_static! {
    // How many top-level items a generated program should contain.
    static ref PROGRAM_LENGTH_DISTRIBUTION: WeightedIndex<usize> = WeightedIndex::new([
        0, // % chance of 0
        10, // % chance of 1
        10, // % chance of 2
        15, // % chance of 3
        10, // % chance of 4
        10, // % chance of 5
        10, // % chance of 6
        5, // % chance of 7
        5, // % chance of 8
        5, // % chance of 9
        5, // % chance of 10
        5, // % chance of 11
        3, // % chance of 12
        3, // % chance of 13
        3, // % chance of 14
        1, // % chance of 15
    ]).unwrap();
    // How many statements a generated block expression should contain.
    static ref BLOCK_LENGTH_DISTRIBUTION: WeightedIndex<usize> = WeightedIndex::new([
        1, // % chance of 0
        10, // % chance of 1
        20, // % chance of 2
        15, // % chance of 3
        10, // % chance of 4
        8, // % chance of 5
        8, // % chance of 6
        5, // % chance of 7
        5, // % chance of 8
        4, // % chance of 9
        3, // % chance of 10
        3, // % chance of 11
        3, // % chance of 12
        2, // % chance of 13
        2, // % chance of 14
        1, // % chance of 15
    ]).unwrap();
    // How many arguments a generated function should take.
    static ref FUNCTION_ARGUMENTS_DISTRIBUTION: WeightedIndex<usize> = WeightedIndex::new([
        5, // % chance of 0
        20, // % chance of 1
        20, // % chance of 2
        20, // % chance of 3
        15, // % chance of 4
        10, // % chance of 5
        5, // % chance of 6
        2, // % chance of 7
        1, // % chance of 8
        1, // % chance of 9
        1, // % chance of 10
    ]).unwrap();
    // Index distribution over `STATEMENT_TYPE_FREQUENCIES`.
    static ref STATEMENT_TYPE_DISTRIBUTION: WeightedIndex<usize> = WeightedIndex::new(
        STATEMENT_TYPE_FREQUENCIES.iter().map(|x| x.1)
    ).unwrap();
    // Index distribution over `EXPRESSION_TYPE_FREQUENCIES`.
    static ref EXPRESSION_TYPE_DISTRIBUTION: WeightedIndex<usize> = WeightedIndex::new(
        EXPRESSION_TYPE_FREQUENCIES.iter().map(|x| x.1)
    ).unwrap();
    // Index distribution over `ARGUMENT_TYPE_FREQUENCIES`.
    static ref ARGUMENT_TYPE_DISTRIBUTION: WeightedIndex<usize> = WeightedIndex::new(
        ARGUMENT_TYPE_FREQUENCIES.iter().map(|x| x.1)
    ).unwrap();
    // Index distribution over `VALUE_TYPE_FREQUENCIES`.
    static ref VALUE_TYPE_DISTRIBUTION: WeightedIndex<usize> = WeightedIndex::new(
        VALUE_TYPE_FREQUENCIES.iter().map(|x| x.1)
    ).unwrap();
}
/// Relative weights for each kind of top-level statement.
///
/// The second tuple element is a relative weight (not a percentage); it feeds
/// `STATEMENT_TYPE_DISTRIBUTION`.
static STATEMENT_TYPE_FREQUENCIES: &[(StatementType, usize)] = &[
    (StatementType::Binding, 3),
    (StatementType::Function, 1),
    (StatementType::Expression, 2),
];
/// Relative weights for each kind of expression the generator can emit.
///
/// The second tuple element is a relative weight; it feeds
/// `EXPRESSION_TYPE_DISTRIBUTION`, whose sampled index selects a row here.
static EXPRESSION_TYPE_FREQUENCIES: &[(ExpressionType, usize)] = &[
    (ExpressionType::Atomic, 50),
    (ExpressionType::Cast, 5),
    (ExpressionType::Primitive, 5),
    (ExpressionType::Block, 10),
    (ExpressionType::Print, 10),
    (ExpressionType::Bind, 20),
];
/// Relative weights for the types a generated function argument may have.
///
/// Every integer primitive type is listed with equal weight; it feeds
/// `ARGUMENT_TYPE_DISTRIBUTION`.
static ARGUMENT_TYPE_FREQUENCIES: &[(Type, usize)] = &[
    (Type::Primitive(PrimitiveType::U8), 1),
    (Type::Primitive(PrimitiveType::U16), 1),
    (Type::Primitive(PrimitiveType::U32), 1),
    (Type::Primitive(PrimitiveType::U64), 1),
    (Type::Primitive(PrimitiveType::I8), 1),
    (Type::Primitive(PrimitiveType::I16), 1),
    (Type::Primitive(PrimitiveType::I32), 1),
    (Type::Primitive(PrimitiveType::I64), 1),
];
/// The kinds of top-level item the program generator can choose between.
enum StatementType {
    /// A variable binding.
    Binding,
    /// A function definition.
    Function,
    /// A bare expression.
    Expression,
}
/// The kinds of expression `generate_random_expression` can choose between.
enum ExpressionType {
    /// A constant or a variable reference.
    Atomic,
    /// A cast of an atomic value to another primitive type.
    Cast,
    /// A call to a primitive operator.
    Primitive,
    /// A block of expressions.
    Block,
    /// A call to the `Print` primitive on a variable in scope.
    Print,
    /// A binding of a fresh name to a random expression.
    Bind,
}
/// Relative weights for the type of a randomly generated constant.
///
/// This knowingly excludes `ValueType::Void`: a void constant is only
/// produced when a caller explicitly asks for that type.
static VALUE_TYPE_FREQUENCIES: &[(ValueType, usize)] = &[
    (ValueType::I8, 1),
    (ValueType::I16, 1),
    (ValueType::I32, 1),
    (ValueType::I64, 1),
    (ValueType::U8, 1),
    (ValueType::U16, 1),
    (ValueType::U32, 1),
    (ValueType::U64, 1),
];
/// The type tag of a constant to generate; one variant per `PrimitiveType`
/// handled by the `From<PrimitiveType>` conversion.
#[derive(Copy, Clone)]
enum ValueType {
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    /// Not listed in `VALUE_TYPE_FREQUENCIES`, so never chosen at random.
    Void,
}
/// One-to-one mapping from the evaluator's primitive types onto the
/// generator's constant-type tags.
impl From<PrimitiveType> for ValueType {
    fn from(value: PrimitiveType) -> Self {
        match value {
            // signed integers
            PrimitiveType::I8 => Self::I8,
            PrimitiveType::I16 => Self::I16,
            PrimitiveType::I32 => Self::I32,
            PrimitiveType::I64 => Self::I64,
            // unsigned integers
            PrimitiveType::U8 => Self::U8,
            PrimitiveType::U16 => Self::U16,
            PrimitiveType::U32 => Self::U32,
            PrimitiveType::U64 => Self::U64,
            // the unit-like type
            PrimitiveType::Void => Self::Void,
        }
    }
}
/// A proptest [`Strategy`] that produces IR programs (`Program<Type>`).
///
/// The generator carries no configuration.
#[derive(Debug, Default)]
pub struct ProgramGenerator {}
impl Strategy for ProgramGenerator {
    type Tree = ProgramTree;
    type Value = Program<Type>;

    /// Create a new program tree seeded with a fresh RNG from the runner.
    /// This never fails.
    fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
        let rng = runner.new_rng();
        Ok(ProgramTree::new(rng))
    }
}
/// The proptest value tree for generated programs.
pub struct ProgramTree {
    // The RNG handed over by the test runner; stored but not read by any of
    // the code visible in this file (hence the underscore).
    _rng: TestRng,
    // The program handed back by `ValueTree::current`.
    current: Program<Type>,
}
impl ProgramTree {
    /// Build a new tree from the given RNG.
    ///
    /// NOTE: random program generation is currently disabled (see the
    /// commented-out code below). The tree always holds an empty program:
    /// no functions, no type definitions, and a body that is an empty,
    /// void-typed block. The RNG is stored but not used.
    fn new(_rng: TestRng) -> Self {
        // let mut items = vec![];
        //let program_length = PROGRAM_LENGTH_DISTRIBUTION.sample(&mut rng);
        // let mut env = ScopedMap::new();
        // for _ in 0..program_length {
        // match STATEMENT_TYPE_FREQUENCIES[STATEMENT_TYPE_DISTRIBUTION.sample(&mut rng)].0 {
        // StatementType::Binding => {
        // let binding = generate_random_binding(&mut rng, &mut env);
        // items.push(TopLevel::Statement(binding));
        // }
        // StatementType::Expression => {
        // let expr = generate_random_expression(&mut rng, &mut env);
        // items.push(TopLevel::Statement(expr));
        // }
        // StatementType::Function => {
        // env.new_scope();
        // let name = generate_random_name(&mut rng);
        // let mut args = vec![];
        // let arg_count = FUNCTION_ARGUMENTS_DISTRIBUTION.sample(&mut rng);
        // for _ in 0..arg_count {
        // let name = generate_random_name(&mut rng);
        // let ty = generate_random_argument_type(&mut rng);
        // args.push((name, ty));
        // }
        // let body = generate_random_expression(&mut rng, &mut env);
        // let rettype = body.type_of();
        // env.release_scope();
        // items.push(TopLevel::Function(name, args, rettype, body))
        // }
        // }
        // }
        let current = Program {
            functions: HashMap::new(),
            type_definitions: HashMap::new(),
            body: Expression::Block(Location::manufactured(), Type::void(), vec![]),
        };
        ProgramTree { _rng, current }
    }
}
impl ValueTree for ProgramTree {
    type Value = Program<Type>;

    /// Return a clone of the stored program.
    fn current(&self) -> Self::Value {
        self.current.clone()
    }

    /// Shrinking is not implemented: always reports that no simpler value
    /// is available.
    fn simplify(&mut self) -> bool {
        false
    }

    /// Counterpart of `simplify`; since nothing is ever simplified there is
    /// nothing to undo, so this always returns `false`.
    fn complicate(&mut self) -> bool {
        false
    }
}
/// A proptest [`Strategy`] that produces single random IR expressions
/// (`Expression<Type>`). It carries no configuration.
#[derive(Debug)]
struct ExpressionGenerator {}
impl Strategy for ExpressionGenerator {
    type Tree = ExpressionTree;
    type Value = Expression<Type>;

    /// Create a new expression tree seeded with a fresh RNG from the runner.
    /// This never fails.
    fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
        let rng = runner.new_rng();
        Ok(ExpressionTree::new(rng))
    }
}
/// The proptest value tree for a single generated expression.
struct ExpressionTree {
    // The RNG used to generate `current`; kept after construction but not
    // read again by the code visible in this file.
    _rng: TestRng,
    // The expression handed back by `ValueTree::current`.
    current: Expression<Type>,
}
impl ValueTree for ExpressionTree {
    type Value = Expression<Type>;

    /// Return a clone of the generated expression.
    fn current(&self) -> Self::Value {
        self.current.clone()
    }

    /// Shrinking is not supported for generated expressions.
    ///
    /// proptest calls this as soon as a test case fails, so it must not
    /// panic: returning `false` tells the runner that no simpler value is
    /// available and lets it report the original failing case.
    fn simplify(&mut self) -> bool {
        false
    }

    /// Counterpart of `simplify`; since nothing is ever simplified there is
    /// nothing to undo, so this always returns `false`.
    fn complicate(&mut self) -> bool {
        false
    }
}
impl ExpressionTree {
    /// Generate one random expression in a fresh, empty environment and
    /// wrap it, together with the RNG that produced it, in a tree.
    fn new(mut rng: TestRng) -> Self {
        let current = generate_random_expression(&mut rng, &mut ScopedMap::new());
        Self { _rng: rng, current }
    }
}
fn generate_random_expression(
rng: &mut TestRng,
env: &mut ScopedMap<Name, Type>,
) -> Expression<Type> {
match EXPRESSION_TYPE_FREQUENCIES[EXPRESSION_TYPE_DISTRIBUTION.sample(rng)].0 {
ExpressionType::Atomic => Expression::Atomic(generate_random_valueref(rng, env, None)),
ExpressionType::Bind => generate_random_binding(rng, env),
ExpressionType::Block => {
let num_stmts = BLOCK_LENGTH_DISTRIBUTION.sample(rng);
let mut stmts = Vec::new();
if num_stmts == 0 {
return Expression::Block(Location::manufactured(), Type::void(), stmts);
}
env.new_scope();
for _ in 1..num_stmts {
let mut next = generate_random_expression(rng, env);
let next_type = next.type_of();
if !next_type.is_void() {
let name = generate_random_name(rng);
env.insert(name.clone(), next_type.clone());
next =
Expression::Bind(Location::manufactured(), name, next_type, Box::new(next));
}
stmts.push(next);
}
let last_expr = generate_random_expression(rng, env);
let last_type = last_expr.type_of();
stmts.push(last_expr);
env.release_scope();
Expression::Block(Location::manufactured(), last_type, stmts)
}
ExpressionType::Cast => {
let inner = generate_random_valueref(rng, env, None);
match inner.type_of() {
// nevermind
Type::Function(_, _) => Expression::Atomic(inner),
Type::Primitive(primty) => {
let to_type = primty
.allowed_casts()
.choose(rng)
.expect("actually chose type");
Expression::Cast(Location::manufactured(), Type::Primitive(*to_type), inner)
}
Type::Structure(_) => unimplemented!(),
}
}
ExpressionType::Primitive => {
let base_expr = generate_random_valueref(rng, env, None);
let out_type = base_expr.type_of();
match out_type {
Type::Function(_, _) => Expression::Atomic(base_expr),
Type::Primitive(primty) => match primty.valid_operators().choose(rng) {
None => Expression::Atomic(base_expr),
Some((operator, arg_count)) => {
let primop = Primitive::from_str(operator).expect("chose valid primitive");
let mut args = vec![base_expr];
let mut argtys = vec![];
while args.len() < *arg_count {
args.push(generate_random_valueref(rng, env, Some(primty)));
argtys.push(Type::Primitive(primty));
}
let primtype = Type::Function(argtys, Box::new(Type::Primitive(primty)));
Expression::Call(
Location::manufactured(),
out_type,
ValueOrRef::Primitive(Location::manufactured(), primtype, primop),
args,
)
}
},
Type::Structure(_) => unimplemented!(),
}
}
ExpressionType::Print => {
let possible_variables = env
.bindings()
.iter()
.filter_map(|(variable, ty)| {
if ty.is_printable() {
Some((variable.clone(), ty.clone()))
} else {
None
}
})
.collect::<Vec<_>>();
if possible_variables.is_empty() {
generate_random_binding(rng, env)
} else {
let (variable, var_type) = possible_variables.choose(rng).unwrap();
Expression::Call(
Location::manufactured(),
Type::void(),
ValueOrRef::Primitive(Location::manufactured(), Type::void(), Primitive::Print),
vec![ValueOrRef::Ref(
Location::manufactured(),
var_type.clone(),
variable.clone(),
)],
)
}
}
}
}
/// Bind a freshly generated name to a random expression.
///
/// The new variable is recorded in `env` with the type of the bound
/// expression, and the resulting `Expression::Bind` carries that same type.
fn generate_random_binding(rng: &mut TestRng, env: &mut ScopedMap<Name, Type>) -> Expression<Type> {
    // The name is drawn before the expression so the RNG is consumed in a
    // fixed order.
    let bound_name = generate_random_name(rng);
    let bound_value = Box::new(generate_random_expression(rng, env));
    let bound_type = bound_value.type_of();

    env.insert(bound_name.clone(), bound_type.clone());

    Expression::Bind(Location::manufactured(), bound_name, bound_type, bound_value)
}
/// Produce a random atomic operand: either a fresh constant or a reference
/// to a variable currently in scope.
///
/// When `target_type` is given, only variables of exactly that primitive
/// type are eligible for a reference, and any constant generated has that
/// type. Otherwise every variable is eligible and the constant's type is
/// drawn from `VALUE_TYPE_FREQUENCIES`. A constant is always produced when
/// no variable is eligible.
fn generate_random_valueref(
    rng: &mut TestRng,
    env: &mut ScopedMap<Name, Type>,
    target_type: Option<PrimitiveType>,
) -> ValueOrRef<Type> {
    let mut candidates = env.bindings();
    if let Some(wanted) = target_type {
        let wanted = Type::Primitive(wanted);
        candidates.retain(|_, ty| *ty == wanted);
    }

    // The coin is flipped unconditionally, even when there are no candidates.
    let prefer_constant: bool = rng.gen();

    if !prefer_constant && !candidates.is_empty() {
        // generate a reference, every candidate being equally likely
        let weighted = candidates.keys().map(|key| (1, key)).collect::<Vec<_>>();
        let picker = WeightedIndex::new(weighted.iter().map(|entry| entry.0)).unwrap();
        let chosen = weighted[picker.sample(rng)].1.clone();
        let chosen_type = candidates
            .remove(&chosen)
            .expect("chose unbound variable somehow?");

        return ValueOrRef::Ref(Location::manufactured(), chosen_type, chosen);
    }

    // generate a constant
    let kind = match target_type {
        Some(primty) => ValueType::from(primty),
        None => VALUE_TYPE_FREQUENCIES[VALUE_TYPE_DISTRIBUTION.sample(rng)].0,
    };

    let constant = match kind {
        ValueType::Void => Value::Void,
        ValueType::U8 => Value::U8(None, rng.gen()),
        ValueType::U16 => Value::U16(None, rng.gen()),
        ValueType::U32 => Value::U32(None, rng.gen()),
        ValueType::U64 => Value::U64(None, rng.gen()),
        ValueType::I8 => Value::I8(None, rng.gen()),
        ValueType::I16 => Value::I16(None, rng.gen()),
        ValueType::I32 => Value::I32(None, rng.gen()),
        ValueType::I64 => Value::I64(None, rng.gen()),
    };
    let constant_type = constant.type_of();

    ValueOrRef::Value(Location::manufactured(), constant_type, constant)
}
/// Create a fresh variable name by passing a random lowercase ASCII letter
/// (`'a'..='z'`) to `Name::gensym`.
fn generate_random_name(rng: &mut TestRng) -> Name {
    Name::gensym(rng.gen_range('a'..='z'))
}
//fn generate_random_argument_type(rng: &mut TestRng) -> Type {
// ARGUMENT_TYPE_FREQUENCIES[ARGUMENT_TYPE_DISTRIBUTION.sample(rng)]
// .0
// .clone()
//}

View File

@@ -1,108 +1,54 @@
use crate::{ use crate::eval::PrimitiveType;
eval::PrimitiveType, pub use crate::ir::fields::Fields;
syntax::{self, ConstantType, Location}, pub use crate::syntax::Name;
}; use crate::syntax::{ConstantType, Location};
use internment::ArcIntern; use proptest::arbitrary::Arbitrary;
use pretty::{BoxAllocator, DocAllocator, Pretty}; use std::collections::HashMap;
use proptest::{ use std::convert::TryFrom;
prelude::Arbitrary, use std::str::FromStr;
strategy::{BoxedStrategy, Strategy},
};
use std::{fmt, str::FromStr};
/// We're going to represent variables as interned strings. use super::arbitrary::ProgramGenerator;
///
/// These should be fast enough for comparison that it's OK, since it's going to end up
/// being pretty much the pointer to the string.
type Variable = ArcIntern<String>;
/// The representation of a program within our IR. For now, this is exactly one file. /// The representation of a program within our IR. For now, this is exactly one file.
/// ///
/// In addition, for the moment there's not really much of interest to hold here besides /// A program consists of a series of statements and functions. The statements should
/// the list of statements read from the file. Order is important. In the future, you /// be executed in order. The functions currently may not reference any variables
/// could imagine caching analysis information in this structure. /// at the top level, so their order only matters in relation to each other (functions
/// may not be referenced before they are defined).
/// ///
/// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used /// `Program` implements both [`Pretty`] and [`Arbitrary`]. The former should be used
/// to print the structure whenever possible, especially if you value your or your /// to print the structure whenever possible, especially if you value your or your
/// user's time. The latter is useful for testing that conversions of `Program` retain /// user's time. The latter is useful for testing that conversions of `Program` retain
/// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be /// their meaning. All `Program`s generated through [`Arbitrary`] are guaranteed to be
/// syntactically valid, although they may contain runtime issue like over- or underflow. /// syntactically valid, although they may contain runtime issue like over- or underflow.
#[derive(Debug)]
pub struct Program {
// For now, a program is just a vector of statements. In the future, we'll probably
// extend this to include a bunch of other information, but for now: just a list.
pub(crate) statements: Vec<Statement>,
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let mut result = allocator.nil();
for stmt in self.statements.iter() {
// there's probably a better way to do this, rather than constantly
// adding to the end, but this works.
result = result
.append(stmt.pretty(allocator))
.append(allocator.text(";"))
.append(allocator.hardline());
}
result
}
}
impl Arbitrary for Program {
type Parameters = crate::syntax::arbitrary::GenerationEnvironment;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
crate::syntax::Program::arbitrary_with(args)
.prop_map(|x| {
x.type_infer()
.expect("arbitrary_with should generate type-correct programs")
})
.boxed()
}
}
/// The representation of a statement in the language.
/// ///
/// For now, this is either a binding site (`x = 4`) or a print statement /// The type variable is, somewhat confusingly, the current definition of a type within
/// (`print x`). Someday, though, more! /// the IR. Since the makeup of this structure may change over the life of the compiler,
/// /// it's easiest to just make it an argument.
/// As with `Program`, this type implements [`Pretty`], which should #[derive(Clone, Debug)]
/// be used to display the structure whenever possible. It does not pub struct Program<Type> {
/// implement [`Arbitrary`], though, mostly because it's slightly // The set of functions declared in this program.
/// complicated to do so. pub functions: HashMap<Name, FunctionDefinition<Type>>,
/// // The set of types declared in this program.
#[derive(Debug)] pub type_definitions: HashMap<Name, Type>,
pub enum Statement { // The thing to evaluate in the end.
Binding(Location, Variable, Type, Expression), pub body: Expression<Type>,
Print(Location, Type, Variable),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement #[derive(Clone, Debug)]
where pub struct FunctionDefinition<Type> {
A: 'a, pub name: Name,
D: ?Sized + DocAllocator<'a, A>, pub arguments: Vec<(Name, Type)>,
{ pub return_type: Type,
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { pub body: Expression<Type>,
match self {
Statement::Binding(_, var, _, expr) => allocator
.text(var.as_ref().to_string())
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
Statement::Print(_, _, var) => allocator
.text("print")
.append(allocator.space())
.append(allocator.text(var.as_ref().to_string())),
} }
impl Arbitrary for Program<Type> {
type Parameters = ();
type Strategy = ProgramGenerator;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
ProgramGenerator::default()
} }
} }
@@ -118,22 +64,29 @@ where
/// a primitive), any subexpressions have been bound to variables so /// a primitive), any subexpressions have been bound to variables so
/// that the referenced data will always either be a constant or a /// that the referenced data will always either be a constant or a
/// variable reference. /// variable reference.
#[derive(Debug)] #[derive(Clone, Debug)]
pub enum Expression { pub enum Expression<Type> {
Atomic(ValueOrRef), Atomic(ValueOrRef<Type>),
Cast(Location, Type, ValueOrRef), Cast(Location, Type, ValueOrRef<Type>),
Primitive(Location, Type, Primitive, Vec<ValueOrRef>), Construct(Location, Type, Name, HashMap<Name, ValueOrRef<Type>>),
FieldRef(Location, Type, Type, ValueOrRef<Type>, Name),
Block(Location, Type, Vec<Expression<Type>>),
Call(Location, Type, ValueOrRef<Type>, Vec<ValueOrRef<Type>>),
Bind(Location, Name, Type, Box<Expression<Type>>),
} }
impl Expression { impl<Type: Clone + TypeWithVoid> Expression<Type> {
/// Return a reference to the type of the expression, as inferred or recently /// Return the type of the expression, as inferred or recently
/// computed. /// computed.
pub fn type_of(&self) -> &Type { pub fn type_of(&self) -> Type {
match self { match self {
Expression::Atomic(ValueOrRef::Ref(_, t, _)) => t, Expression::Atomic(x) => x.type_of(),
Expression::Atomic(ValueOrRef::Value(_, t, _)) => t, Expression::Cast(_, t, _) => t.clone(),
Expression::Cast(_, t, _) => t, Expression::Construct(_, t, _, _) => t.clone(),
Expression::Primitive(_, t, _, _) => t, Expression::FieldRef(_, t, _, _, _) => t.clone(),
Expression::Block(_, t, _) => t.clone(),
Expression::Call(_, t, _, _) => t.clone(),
Expression::Bind(_, _, t, _) => t.clone(),
} }
} }
@@ -142,41 +95,13 @@ impl Expression {
match self { match self {
Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l, Expression::Atomic(ValueOrRef::Ref(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Value(l, _, _)) => l, Expression::Atomic(ValueOrRef::Value(l, _, _)) => l,
Expression::Atomic(ValueOrRef::Primitive(l, _, _)) => l,
Expression::Cast(l, _, _) => l, Expression::Cast(l, _, _) => l,
Expression::Primitive(l, _, _, _) => l, Expression::Construct(l, _, _, _) => l,
} Expression::FieldRef(l, _, _, _, _) => l,
} Expression::Block(l, _, _) => l,
} Expression::Call(l, _, _, _) => l,
Expression::Bind(l, _, _, _) => l,
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Expression::Atomic(x) => x.pretty(allocator),
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
Expression::Primitive(_, _, op, exprs) if exprs.len() == 1 => {
op.pretty(allocator).append(exprs[0].pretty(allocator))
}
Expression::Primitive(_, _, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
left.append(allocator.space())
.append(op.pretty(allocator))
.append(allocator.space())
.append(right)
.parens()
}
Expression::Primitive(_, _, op, exprs) => {
allocator.text(format!("!!{:?} with {} arguments!!", op, exprs.len()))
}
} }
} }
} }
@@ -193,6 +118,8 @@ pub enum Primitive {
Minus, Minus,
Times, Times,
Divide, Divide,
Print,
Negate,
} }
impl FromStr for Primitive { impl FromStr for Primitive {
@@ -204,58 +131,37 @@ impl FromStr for Primitive {
"-" => Ok(Primitive::Minus), "-" => Ok(Primitive::Minus),
"*" => Ok(Primitive::Times), "*" => Ok(Primitive::Times),
"/" => Ok(Primitive::Divide), "/" => Ok(Primitive::Divide),
"print" => Ok(Primitive::Print),
"negate" => Ok(Primitive::Negate),
_ => Err(format!("Illegal primitive {}", value)), _ => Err(format!("Illegal primitive {}", value)),
} }
} }
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Primitive
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self {
Primitive::Plus => allocator.text("+"),
Primitive::Minus => allocator.text("-"),
Primitive::Times => allocator.text("*"),
Primitive::Divide => allocator.text("/"),
}
}
}
impl fmt::Display for Primitive {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<&Primitive as Pretty<'_, BoxAllocator, ()>>::pretty(self, &BoxAllocator).render_fmt(72, f)
}
}
/// An expression that is always either a value or a reference. /// An expression that is always either a value or a reference.
/// ///
/// This is the type used to guarantee that we don't nest expressions /// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one /// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference. /// of these, which can only be a constant or a reference.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum ValueOrRef { pub enum ValueOrRef<Type> {
Value(Location, Type, Value), Value(Location, Type, Value),
Ref(Location, Type, ArcIntern<String>), Ref(Location, Type, Name),
Primitive(Location, Type, Primitive),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef impl<Type: Clone> ValueOrRef<Type> {
where pub fn type_of(&self) -> Type {
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
match self { match self {
ValueOrRef::Value(_, _, v) => v.pretty(allocator), ValueOrRef::Ref(_, t, _) => t.clone(),
ValueOrRef::Ref(_, _, v) => allocator.text(v.as_ref().to_string()), ValueOrRef::Value(_, t, _) => t.clone(),
ValueOrRef::Primitive(_, t, _) => t.clone(),
} }
} }
} }
impl From<ValueOrRef> for Expression { impl<Type> From<ValueOrRef<Type>> for Expression<Type> {
fn from(value: ValueOrRef) -> Self { fn from(value: ValueOrRef<Type>) -> Self {
Expression::Atomic(value) Expression::Atomic(value)
} }
} }
@@ -276,6 +182,7 @@ pub enum Value {
U16(Option<u8>, u16), U16(Option<u8>, u16),
U32(Option<u8>, u32), U32(Option<u8>, u32),
U64(Option<u8>, u64), U64(Option<u8>, u64),
Void,
} }
impl Value { impl Value {
@@ -290,49 +197,7 @@ impl Value {
Value::U16(_, _) => Type::Primitive(PrimitiveType::U16), Value::U16(_, _) => Type::Primitive(PrimitiveType::U16),
Value::U32(_, _) => Type::Primitive(PrimitiveType::U32), Value::U32(_, _) => Type::Primitive(PrimitiveType::U32),
Value::U64(_, _) => Type::Primitive(PrimitiveType::U64), Value::U64(_, _) => Type::Primitive(PrimitiveType::U64),
} Value::Void => Type::void(),
}
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
let pretty_internal = |opt_base: &Option<u8>, x, t| {
syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
};
let pretty_internal_signed = |opt_base, x: i64, t| {
let base = pretty_internal(opt_base, x.unsigned_abs(), t);
allocator.text("-").append(base)
};
match self {
Value::I8(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
}
Value::I16(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
}
Value::I32(opt_base, value) => {
pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
}
Value::I64(opt_base, value) => {
pretty_internal_signed(opt_base, *value, ConstantType::I64)
}
Value::U8(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U8)
}
Value::U16(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U16)
}
Value::U32(opt_base, value) => {
pretty_internal(opt_base, *value as u64, ConstantType::U32)
}
Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
} }
} }
} }
@@ -340,24 +205,254 @@ where
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type { pub enum Type {
Primitive(PrimitiveType), Primitive(PrimitiveType),
Function(Vec<Type>, Box<Type>),
Structure(Fields<Type>),
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type impl Type {
where /// Returns true if this variable can reasonably be passed to the print
A: 'a, /// expression for printing.
D: ?Sized + DocAllocator<'a, A>, pub fn is_printable(&self) -> bool {
{ matches!(self, Type::Primitive(_))
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> { }
/// Returns true if the variable is signed.
pub fn is_signed(&self) -> bool {
matches!(
self,
Type::Primitive(
PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 | PrimitiveType::I64
)
)
}
}
impl From<PrimitiveType> for Type {
fn from(value: PrimitiveType) -> Self {
Type::Primitive(value)
}
}
impl<'a> TryInto<ConstantType> for &'a Type {
type Error = ();
fn try_into(self) -> Result<ConstantType, Self::Error> {
match self { match self {
Type::Primitive(pt) => allocator.text(format!("{}", pt)), Type::Primitive(pt) => Ok((*pt).into()),
Type::Function(_, _) => Err(()),
Type::Structure(_) => Err(()),
} }
} }
} }
impl fmt::Display for Type { #[derive(Clone, Debug, Eq, PartialEq)]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { pub enum TypeOrVar {
Primitive(PrimitiveType),
Variable(Location, Name),
Function(Vec<TypeOrVar>, Box<TypeOrVar>),
Structure(Fields<TypeOrVar>),
}
impl Default for TypeOrVar {
fn default() -> Self {
TypeOrVar::new()
}
}
impl TypeOrVar {
/// Generate a fresh type variable that is different from all previous type variables.
///
/// This type variable is guaranteed to be unique across the process lifetime. Overuse
/// of this function could potentially cause overflow problems, but you're going to have
/// to try really hard (like, 2^64 times) to make that happen. The location bound to
/// this address will be purely manufactured; if you want to specify a location, use
/// [`TypeOrVar::new_located`].
pub fn new() -> Self {
Self::new_located(Location::manufactured())
}
/// Generate a fresh type variable that is different from all previous type variables.
///
/// This type variable is guaranteed to be unique across the process lifetime. Overuse
/// of this function could potentially cause overflow problems, but you're going to have
/// to try really hard (like, 2^64 times) to make that happen.
pub fn new_located(loc: Location) -> Self {
TypeOrVar::Variable(loc.clone(), Name::located_gensym(loc, "t"))
}
/// Try replacing the given type variable with the given type, returning true if anything
/// was changed.
pub fn replace(&mut self, name: &Name, replace_with: &TypeOrVar) -> bool {
match self { match self {
Type::Primitive(pt) => pt.fmt(f), TypeOrVar::Variable(_, var_name) if name == var_name => {
*self = replace_with.clone();
true
}
TypeOrVar::Variable(_, _) => false,
TypeOrVar::Function(args, ret) => {
ret.replace(name, replace_with)
| args.iter_mut().any(|x| x.replace(name, replace_with))
}
TypeOrVar::Primitive(_) => false,
TypeOrVar::Structure(fields) => {
fields.types_mut().any(|x| x.replace(name, replace_with))
} }
} }
} }
/// Returns whether or not this type is resolved (meaning that it contains no type
/// variables.)
pub fn is_resolved(&self) -> bool {
match self {
TypeOrVar::Variable(_, _) => false,
TypeOrVar::Primitive(_) => true,
TypeOrVar::Function(args, ret) => {
args.iter().all(TypeOrVar::is_resolved) && ret.is_resolved()
}
TypeOrVar::Structure(fields) => fields.types().all(TypeOrVar::is_resolved),
}
}
}
/// Structural comparison between a possibly-unresolved type and a concrete one.
///
/// A type variable never equals a concrete type. Structures are compared by
/// field name, so two structures with the same named fields are equal
/// regardless of the order the fields are stored in.
impl PartialEq<Type> for TypeOrVar {
    fn eq(&self, other: &Type) -> bool {
        match (self, other) {
            (TypeOrVar::Function(x, y), Type::Function(a, b)) => {
                x == a && y.as_ref() == b.as_ref()
            }
            (TypeOrVar::Primitive(x), Type::Primitive(a)) => a == x,
            (TypeOrVar::Structure(fields2), Type::Structure(fields1)) => {
                // Same number of fields, and every field of `other` has an
                // equal counterpart with the same name in `self`.
                fields1.count() == fields2.count()
                    && fields1
                        .iter()
                        .all(|(name, subtype)| matches!(fields2.get(name), Some(x) if x == subtype))
            }
            // Mismatched shapes, or `self` is still a type variable.
            _ => false,
        }
    }
}
/// Implemented by type representations that have a distinguished "void" type.
pub trait TypeWithVoid {
/// Builds the void type for this representation.
fn void() -> Self;
/// Returns true if this value is the void type.
fn is_void(&self) -> bool;
}
/// For concrete types, void is the `PrimitiveType::Void` primitive.
impl TypeWithVoid for Type {
    fn void() -> Self {
        Self::Primitive(PrimitiveType::Void)
    }

    fn is_void(&self) -> bool {
        *self == Self::void()
    }
}
/// For possibly-unresolved types, void is the `PrimitiveType::Void` primitive.
impl TypeWithVoid for TypeOrVar {
    fn void() -> Self {
        Self::Primitive(PrimitiveType::Void)
    }

    fn is_void(&self) -> bool {
        *self == Self::void()
    }
}
/// Implemented by type representations that can describe function types.
pub trait TypeWithFunction: Sized {
/// Builds the function type taking `arg_types` and returning `ret_type`.
fn build_function_type(arg_types: Vec<Self>, ret_type: Self) -> Self;
}
/// Concrete function types box their return type inside `Type::Function`.
impl TypeWithFunction for Type {
    fn build_function_type(arg_types: Vec<Self>, ret_type: Self) -> Self {
        let boxed_return = Box::new(ret_type);
        Self::Function(arg_types, boxed_return)
    }
}
/// Anything convertible to a concrete [`Type`] can be widened into a
/// `TypeOrVar`; the result never contains a type variable.
impl<T: Into<Type>> From<T> for TypeOrVar {
    fn from(value: T) -> Self {
        let concrete: Type = value.into();
        match concrete {
            Type::Structure(fields) => TypeOrVar::Structure(fields.map(|field| field.into())),
            Type::Function(args, ret) => {
                // Widen each argument and the return type recursively.
                let widened_args = args.into_iter().map(|arg| arg.into()).collect();
                let widened_ret = Box::new((*ret).into());
                TypeOrVar::Function(widened_args, widened_ret)
            }
            Type::Primitive(p) => TypeOrVar::Primitive(p),
        }
    }
}
/// Narrows a `TypeOrVar` into a concrete [`Type`].
///
/// Succeeds only when the input contains no type variables. On failure the
/// original, unmodified `TypeOrVar` is handed back as the error so the caller
/// can keep working with it.
impl TryFrom<TypeOrVar> for Type {
    type Error = TypeOrVar;

    fn try_from(value: TypeOrVar) -> Result<Self, Self::Error> {
        match value {
            TypeOrVar::Primitive(t) => Ok(Type::Primitive(t)),
            TypeOrVar::Variable(_, _) => Err(value),
            TypeOrVar::Function(args, ret) => {
                // Convert clones so the original pieces are still available to
                // rebuild the error value. The return type is only attempted
                // once every argument has converted.
                let converted = args
                    .iter()
                    .cloned()
                    .map(Type::try_from)
                    .collect::<Result<Vec<_>, _>>()
                    .and_then(|args| {
                        Type::try_from((*ret).clone())
                            .map(|ret| Type::Function(args, Box::new(ret)))
                    });
                // The inner error is only the failing sub-type; report the
                // whole function type instead.
                converted.map_err(|_| TypeOrVar::Function(args, ret))
            }
            TypeOrVar::Structure(fields) => {
                let mut new_fields = Fields::new(fields.ordering());
                let mut errored = false;
                for (name, field) in fields.iter() {
                    match Type::try_from(field.clone()) {
                        Ok(new_field) => new_fields.insert(name.clone(), new_field),
                        Err(_) => {
                            // One unresolved field fails the whole structure;
                            // converting the remaining fields would be wasted work.
                            errored = true;
                            break;
                        }
                    }
                }
                // `fields` is moved only after the loop, once the iterator
                // borrowing it is gone.
                if errored {
                    return Err(TypeOrVar::Structure(fields));
                }
                Ok(Type::Structure(new_fields))
            }
        }
    }
}
// Guards against accidental growth of the core IR types by pinning the value
// `std::mem::size_of` reports for each of them.
// NOTE(review): the hard-coded byte counts presumably assume a 64-bit target;
// confirm before running this on other platforms.
#[test]
fn struct_sizes_are_rational() {
assert_eq!(8, std::mem::size_of::<Name>());
assert_eq!(24, std::mem::size_of::<Location>());
assert_eq!(1, std::mem::size_of::<Primitive>());
assert_eq!(1, std::mem::size_of::<PrimitiveType>());
assert_eq!(32, std::mem::size_of::<Fields<Type>>());
assert_eq!(40, std::mem::size_of::<Type>());
assert_eq!(40, std::mem::size_of::<TypeOrVar>());
// The sizes below are asserted to be the same whether the type parameter is
// `Type` or `TypeOrVar`.
assert_eq!(80, std::mem::size_of::<ValueOrRef<Type>>());
assert_eq!(80, std::mem::size_of::<ValueOrRef<TypeOrVar>>());
assert_eq!(200, std::mem::size_of::<Expression<Type>>());
assert_eq!(200, std::mem::size_of::<Expression<TypeOrVar>>());
assert_eq!(72, std::mem::size_of::<Program<Type>>());
assert_eq!(72, std::mem::size_of::<Program<TypeOrVar>>());
}

View File

@@ -1,100 +1,208 @@
use crate::eval::{EvalEnvironment, EvalError, Value}; use super::{FunctionDefinition, Type, ValueOrRef};
use crate::ir::{Expression, Program, Statement}; use crate::eval::{EvalError, Value};
use crate::ir::{Expression, Name, Program};
use crate::util::scoped_map::ScopedMap;
use std::collections::HashMap;
use std::fmt::Display;
use super::{Primitive, Type, ValueOrRef}; type IRValue<T> = Value<Expression<T>>;
type IREvalError<T> = EvalError<Expression<T>>;
impl Program { pub struct Evaluator<T> {
/// Evaluate the program, returning either an error or a string containing everything env: ScopedMap<Name, IRValue<T>>,
/// the program printed out. functions: HashMap<Name, FunctionDefinition<T>>,
stdout: String,
}
/// A default evaluator starts with an empty variable environment, no known
/// functions, and no captured output.
impl<T> Default for Evaluator<T> {
    fn default() -> Self {
        let env = ScopedMap::new();
        let functions = HashMap::new();
        let stdout = String::new();
        Self {
            env,
            functions,
            stdout,
        }
    }
}
impl<T> Evaluator<T>
where
T: Clone + Into<Type>,
Expression<T>: Display,
{
/// Evaluate the program, returning either an error or the result of the final
/// statement and the complete contents of the console output.
/// ///
/// The print outs will be newline separated, with one print out per line. /// The print outs will be newline separated, with one print out per line.
pub fn eval(&self) -> Result<String, EvalError> { pub fn eval(mut self, program: Program<T>) -> Result<(IRValue<T>, String), IREvalError<T>> {
let mut env = EvalEnvironment::empty(); self.functions.extend(program.functions);
let mut stdout = String::new(); let retval = self.eval_expr(program.body)?;
Ok((retval, self.stdout))
for stmt in self.statements.iter() {
match stmt {
Statement::Binding(_, name, _, value) => {
let actual_value = value.eval(&env)?;
env = env.extend(name.clone(), actual_value);
} }
Statement::Print(_, _, name) => { /// Get the current output of the evaluated program.
let value = env.lookup(name.clone())?; pub fn stdout(self) -> String {
let line = format!("{} = {}\n", name, value); self.stdout
stdout.push_str(&line);
}
}
} }
Ok(stdout) fn eval_expr(&mut self, expr: Expression<T>) -> Result<IRValue<T>, IREvalError<T>> {
} match expr {
} Expression::Atomic(x) => self.eval_atomic(x),
impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> {
match self {
Expression::Atomic(x) => x.eval(env),
Expression::Cast(_, t, valref) => { Expression::Cast(_, t, valref) => {
let value = valref.eval(env)?; let value = self.eval_atomic(valref)?;
let ty = t.into();
match t { match ty {
Type::Primitive(pt) => Ok(pt.safe_cast(&value)?), Type::Primitive(pt) => Ok(pt.safe_cast(&value)?),
Type::Function(_, _) => Err(EvalError::CastToFunction(ty.to_string())),
Type::Structure(_) => unimplemented!(),
} }
} }
Expression::Primitive(_, _, op, args) => { Expression::Construct(_, _, name, fields) => {
let arg_values = args let mut result_fields = HashMap::with_capacity(fields.len());
.iter()
.map(|x| x.eval(env))
.collect::<Result<Vec<Value>, EvalError>>()?;
// and then finally we call `calculate` to run them. trust me, it's nice for (name, subexpr) in fields.into_iter() {
// to not have to deal with all the nonsense hidden under `calculate`. result_fields.insert(name.clone(), self.eval_atomic(subexpr)?);
match op {
Primitive::Plus => Ok(Value::calculate("+", arg_values)?),
Primitive::Minus => Ok(Value::calculate("-", arg_values)?),
Primitive::Times => Ok(Value::calculate("*", arg_values)?),
Primitive::Divide => Ok(Value::calculate("/", arg_values)?),
}
}
}
}
} }
impl ValueOrRef { Ok(Value::Structure(Some(name.clone()), result_fields))
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> { }
match self {
ValueOrRef::Value(_, _, v) => match v { Expression::FieldRef(loc, _, _, valref, field) => match self.eval_atomic(valref)? {
super::Value::I8(_, v) => Ok(Value::I8(*v)), Value::Structure(oname, mut fields) => match fields.remove(&field) {
super::Value::I16(_, v) => Ok(Value::I16(*v)), None => Err(EvalError::NoFieldForValue(
super::Value::I32(_, v) => Ok(Value::I32(*v)), loc.clone(),
super::Value::I64(_, v) => Ok(Value::I64(*v)), Value::Structure(oname, fields),
super::Value::U8(_, v) => Ok(Value::U8(*v)), field.clone(),
super::Value::U16(_, v) => Ok(Value::U16(*v)), )),
super::Value::U32(_, v) => Ok(Value::U32(*v)), Some(value) => Ok(value),
super::Value::U64(_, v) => Ok(Value::U64(*v)),
}, },
ValueOrRef::Ref(_, _, n) => Ok(env.lookup(n.clone())?), x => Err(EvalError::NoFieldForValue(loc.clone(), x, field.clone())),
},
Expression::Block(_, _, stmts) => {
let mut result = Value::Void;
for stmt in stmts.into_iter() {
result = self.eval_expr(stmt)?;
}
Ok(result)
}
Expression::Bind(_, name, _, value) => {
let value = self.eval_expr(*value)?;
self.env.insert(name.clone(), value.clone());
Ok(value)
}
Expression::Call(loc, _, fun, args) => {
let function = self.eval_atomic(fun)?;
match function {
Value::Closure(name, mut closure_env, arguments, body) => {
if args.len() != arguments.len() {
return Err(EvalError::WrongArgCount(
loc.clone(),
name,
arguments.len(),
args.len(),
));
}
closure_env.new_scope();
for (name, value) in arguments.into_iter().zip(args) {
let value = self.eval_atomic(value)?;
closure_env.insert(name, value);
}
let temp_ref = &mut closure_env;
std::mem::swap(&mut self.env, temp_ref);
let result = self.eval_expr(body);
std::mem::swap(&mut self.env, temp_ref);
closure_env.release_scope();
result
}
Value::Primitive(name) if name == "print" => {
if let [ValueOrRef::Ref(loc, ty, name)] = &args[..] {
let value = self.eval_atomic(ValueOrRef::Ref(
loc.clone(),
ty.clone(),
name.clone(),
))?;
let addendum = format!("{} = {}\n", name, value);
self.stdout.push_str(&addendum);
Ok(Value::Void)
} else {
panic!("Non-reference/non-singleton argument to 'print'");
}
}
Value::Primitive(name) => {
let values = args
.into_iter()
.map(|x| self.eval_atomic(x))
.collect::<Result<_, _>>()?;
Value::calculate(name.as_str(), values).map_err(Into::into)
}
_ => Err(EvalError::NotAFunction(loc.clone(), function)),
}
}
}
}
fn eval_atomic(&self, value: ValueOrRef<T>) -> Result<IRValue<T>, IREvalError<T>> {
match value {
ValueOrRef::Value(_, _, v) => match v {
super::Value::I8(_, v) => Ok(Value::I8(v)),
super::Value::I16(_, v) => Ok(Value::I16(v)),
super::Value::I32(_, v) => Ok(Value::I32(v)),
super::Value::I64(_, v) => Ok(Value::I64(v)),
super::Value::U8(_, v) => Ok(Value::U8(v)),
super::Value::U16(_, v) => Ok(Value::U16(v)),
super::Value::U32(_, v) => Ok(Value::U32(v)),
super::Value::U64(_, v) => Ok(Value::U64(v)),
super::Value::Void => Ok(Value::Void),
},
ValueOrRef::Ref(loc, _, n) => self
.env
.get(&n)
.cloned()
.ok_or_else(|| EvalError::LookupFailed(loc.clone(), n.to_string())),
ValueOrRef::Primitive(_, _, prim) => Ok(Value::primitive(prim)),
} }
} }
} }
#[test] #[test]
fn two_plus_three() { fn two_plus_three() {
let input = crate::syntax::Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let input = crate::syntax::parse_string(0, "x = 2 + 3; print x;").expect("parse works");
let ir = input.type_infer().expect("test should be type-valid"); let program = crate::syntax::Program::validate(input)
let output = ir.eval().expect("runs successfully"); .into_result()
assert_eq!("x = 5u64\n", &output); .unwrap();
let ir = program.type_infer().expect("test should be type-valid");
let evaluator = Evaluator::default();
let (_, result) = evaluator.eval(ir).expect("runs successfully");
assert_eq!("x = 5u64\n", &result);
} }
#[test] #[test]
fn lotsa_math() { fn lotsa_math() {
let input = let input =
crate::syntax::Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); crate::syntax::parse_string(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let ir = input.type_infer().expect("test should be type-valid"); let program = crate::syntax::Program::validate(input)
let output = ir.eval().expect("runs successfully"); .into_result()
assert_eq!("x = 7u64\n", &output); .unwrap();
let ir = program.type_infer().expect("test should be type-valid");
let evaluator = Evaluator::default();
let (_, result) = evaluator.eval(ir).expect("runs successfully");
assert_eq!("x = 7u64\n", &result);
} }

144
src/ir/fields.rs Normal file
View File

@@ -0,0 +1,144 @@
use crate::syntax::Name;
use cranelift_module::DataDescription;
use std::fmt;
/// An ordered collection of named fields, each carrying a value of type `T`.
#[derive(Clone)]
pub struct Fields<T> {
// The layout ordering this field set was created with.
ordering: FieldOrdering,
// Running byte size of the object; `insert` adds 8 for every field.
total_size: usize,
// The fields themselves, kept in insertion order.
fields: Vec<(Name, T)>,
}
/// Two field sets are equal when they use the same ordering and hold the same
/// `(name, value)` pairs in the same sequence. `total_size` is not compared.
impl<T: PartialEq> PartialEq for Fields<T> {
    fn eq(&self, other: &Self) -> bool {
        if self.ordering != other.ordering {
            return false;
        }
        self.fields == other.fields
    }
}
// Marker impl: `Fields<T>` is `Eq` whenever its element type is.
impl<T: Eq> Eq for Fields<T> {}
/// Debug output is the literal prefix `Fields:` followed by the debug
/// rendering of the underlying `(name, value)` list.
impl<T: fmt::Debug> fmt::Debug for Fields<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Fields:")?;
        fmt::Debug::fmt(&self.fields, f)
    }
}
/// The layout ordering of a [`Fields`] collection. Only one ordering exists.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FieldOrdering {
/// The only ordering currently defined; `Fields::default()` uses it.
Standard,
}
impl<T> Default for Fields<T> {
fn default() -> Self {
Self::new(FieldOrdering::Standard)
}
}
impl<T> Fields<T> {
    /// Creates an empty field set that uses the given ordering.
    pub fn new(ordering: FieldOrdering) -> Fields<T> {
        Fields {
            ordering,
            total_size: 0,
            fields: vec![],
        }
    }

    /// Returns the ordering this field set was created with.
    pub fn ordering(&self) -> FieldOrdering {
        self.ordering
    }

    /// Appends a field. Every field occupies one 8-byte slot, so the object
    /// size grows by 8. No check is made for an existing field of the same name.
    pub fn insert(&mut self, name: Name, t: T) {
        self.total_size += 8;
        self.fields.push((name, t));
    }

    /// Returns the value of the first field called `name`, if there is one.
    pub fn get(&self, name: &Name) -> Option<&T> {
        self.fields
            .iter()
            .find(|(current, _)| current == name)
            .map(|(_, value)| value)
    }

    /// Transforms every field value with `f`, keeping names, order, ordering
    /// and size unchanged.
    pub fn map<T2, F: Fn(T) -> T2>(self, f: F) -> Fields<T2> {
        Fields {
            ordering: self.ordering,
            total_size: self.total_size,
            fields: self.fields.into_iter().map(|(n, t)| (n, f(t))).collect(),
        }
    }

    /// Returns the number of fields.
    pub fn count(&self) -> usize {
        self.fields.len()
    }

    /// Returns true if a field called `name` exists.
    pub fn has_field(&self, name: &Name) -> bool {
        self.fields.iter().any(|(current, _)| current == name)
    }

    /// Removes the first field called `name` and returns its value, or `None`
    /// if no such field exists.
    ///
    /// The object size shrinks by the removed field's 8-byte slot, so
    /// [`Fields::object_size`] and [`Fields::blank_data`] stay consistent with
    /// the offsets reported by [`Fields::field_type_and_offset`].
    pub fn remove_field(&mut self, name: &Name) -> Option<T> {
        let index = self
            .fields
            .iter()
            .position(|(current, _)| current == name)?;
        // `insert` added 8 for this field, so this cannot underflow.
        self.total_size -= 8;
        Some(self.fields.remove(index).1)
    }

    /// Iterates over `(name, value)` pairs in stored order.
    pub fn iter(&self) -> impl Iterator<Item = (&Name, &T)> {
        self.fields.iter().map(|(x, y)| (x, y))
    }

    /// Iterates over the field names in stored order.
    pub fn field_names(&self) -> impl Iterator<Item = &Name> {
        self.fields.iter().map(|(n, _)| n)
    }

    /// Iterates over the field values in stored order.
    pub fn types(&self) -> impl Iterator<Item = &T> {
        self.fields.iter().map(|(_, x)| x)
    }

    /// Iterates mutably over the field values in stored order.
    pub fn types_mut(&mut self) -> impl Iterator<Item = &mut T> {
        self.fields.iter_mut().map(|(_, x)| x)
    }

    /// Builds a zero-initialised, 8-byte-aligned Cranelift data description
    /// whose size is the current object size.
    ///
    /// Only reads from `self`, so a shared borrow is enough.
    pub fn blank_data(&self) -> DataDescription {
        let mut cranelift_description = DataDescription::new();
        cranelift_description.set_align(8);
        cranelift_description.define_zeroinit(self.total_size);
        cranelift_description
    }

    /// Returns the value of the field called `field` together with its byte
    /// offset within the object, or `None` if there is no such field.
    pub fn field_type_and_offset(&self, field: &Name) -> Option<(&T, i32)> {
        let mut offset = 0;
        for (current, ty) in self.fields.iter() {
            if current == field {
                return Some((ty, offset));
            }
            // Each preceding field occupies one 8-byte slot.
            offset += 8;
        }
        None
    }

    /// Returns the total size of the object in bytes (8 per field).
    pub fn object_size(&self) -> usize {
        self.total_size
    }
}
impl<T> IntoIterator for Fields<T> {
type Item = (Name, T);
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
self.fields.into_iter()
}
}

248
src/ir/pretty.rs Normal file
View File

@@ -0,0 +1,248 @@
use crate::ir::{Expression, Primitive, Program, Type, TypeOrVar, Value, ValueOrRef};
use crate::syntax::{self, ConstantType};
use crate::util::pretty::{derived_display, pretty_function_type, Allocator};
use pretty::{Arena, DocAllocator, DocBuilder};
impl Program<Type> {
/// Builds the pretty-printer document for a whole program: its type
/// definitions, then its functions, then the program body.
pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
let mut result = allocator.nil();
// One `type <name> = <type>` entry per type definition, each ending in a
// hard line break.
for (name, ty) in self.type_definitions.iter() {
result = result
.append(allocator.text("type"))
.append(allocator.space())
.append(allocator.text(name.current_name().to_string()))
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(ty.pretty(allocator))
.append(allocator.hardline());
}
// Extra line break after the type definitions, only if there were any.
if !self.type_definitions.is_empty() {
result = result.append(allocator.hardline());
}
// One `function <name>(<arg>: <type>,...) -> <return type>` header per
// function, followed by the function body.
for function in self.functions.values() {
result = result
.append(allocator.text("function"))
.append(allocator.space())
.append(allocator.text(function.name.current_name().to_string()))
.append(allocator.text("("))
.append(allocator.intersperse(
// NOTE(review): arguments are printed with `original_name()` while every
// other name in this function uses `current_name()` - confirm this is intended.
function.arguments.iter().map(|(x, t)| {
allocator
.text(x.original_name().to_string())
.append(allocator.text(":"))
.append(allocator.space())
.append(t.pretty(allocator))
}),
allocator.text(","),
))
.append(allocator.text(")"))
.append(allocator.space())
.append(allocator.text("->"))
.append(allocator.space())
.append(function.return_type.pretty(allocator))
.append(allocator.softline())
.append(function.body.pretty(allocator))
.append(allocator.hardline());
}
// Extra line break after the functions, only if there were any.
if !self.functions.is_empty() {
result = result.append(allocator.hardline());
}
// The program body always comes last.
result.append(self.body.pretty(allocator))
}
}
impl Expression<Type> {
/// Builds the pretty-printer document for a single expression, recursing
/// into sub-expressions.
pub fn pretty<'a>(
&self,
allocator: &'a Arena<'a, ()>,
) -> pretty::DocBuilder<'a, Arena<'a, ()>, ()> {
match self {
Expression::Atomic(x) => x.pretty(allocator),
// Casts print the target type in angle brackets before the operand.
Expression::Cast(_, t, e) => allocator
.text("<")
.append(t.pretty(allocator))
.append(allocator.text(">"))
.append(e.pretty(allocator)),
// Constructors print the name followed by a braced list of
// `field: value;` entries.
Expression::Construct(_, _, name, fields) => {
let inner = allocator
.intersperse(
fields.iter().map(|(k, v)| {
allocator
.text(k.to_string())
.append(":")
.append(allocator.space())
.append(v.pretty(allocator))
.append(allocator.text(";"))
}),
allocator.line(),
)
.indent(2)
.braces();
allocator.text(name.to_string()).append(inner)
}
// Field references print as `<value>.<field>`.
Expression::FieldRef(_, _, _, val, field) => val.pretty(allocator).append(
allocator
.text(".")
.append(allocator.text(field.to_string())),
),
// Calls print the callee followed by its comma-separated arguments in
// parentheses.
Expression::Call(_, _, fun, args) => {
let args = args.iter().map(|x| x.pretty(allocator));
let comma_sepped_args = allocator.intersperse(args, allocator.text(","));
fun.pretty(allocator).append(comma_sepped_args.parens())
}
Expression::Block(_, _, exprs) => match exprs.split_last() {
// An empty block prints as `()`.
None => allocator.text("()"),
// A block holding a single expression prints as that expression, with
// no braces.
Some((last, &[])) => last.pretty(allocator),
// Otherwise: braces around the expressions, where every expression but
// the last is followed by `;`.
Some((last, start)) => {
let mut result = allocator.text("{").append(allocator.hardline());
let starts = start.iter().map(|x| {
x.pretty(allocator)
.append(allocator.text(";"))
.append(allocator.hardline())
.indent(4)
});
let last = last
.pretty(allocator)
.append(allocator.hardline())
.indent(4);
for start in starts {
result = result.append(start);
}
result.append(last).append(allocator.text("}"))
}
},
// Bindings print as `<var> : <type> = <expr>`.
Expression::Bind(_, var, ty, expr) => allocator
.text(var.current_name().to_string())
.append(allocator.space())
.append(allocator.text(":"))
.append(allocator.space())
.append(ty.pretty(allocator))
.append(allocator.space())
.append(allocator.text("="))
.append(allocator.space())
.append(expr.pretty(allocator)),
}
}
}
impl Primitive {
    /// Builds the document for a primitive: its operator symbol or keyword.
    pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
        let symbol = match self {
            Primitive::Plus => "+",
            Primitive::Minus => "-",
            Primitive::Times => "*",
            Primitive::Divide => "/",
            Primitive::Print => "print",
            Primitive::Negate => "negate",
        };
        allocator.text(symbol)
    }
}
impl ValueOrRef<Type> {
pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
match self {
ValueOrRef::Value(_, _, v) => v.pretty(allocator),
ValueOrRef::Ref(_, _, v) => allocator.text(v.current_name().to_string()),
ValueOrRef::Primitive(_, _, p) => p.pretty(allocator),
}
}
}
impl Value {
    /// Builds the document for a constant value.
    ///
    /// Numbers are rendered through `syntax::Value::Number`, carrying over
    /// the optional base and the constant's type. Signed values are rendered
    /// from their magnitude, with a leading `-` only when the value is
    /// actually negative. `Void` prints as `<void>`.
    pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
        let pretty_internal = |opt_base: &Option<u8>, x, t| {
            syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
        };

        let pretty_internal_signed = |opt_base, x: i64, t| {
            // `unsigned_abs` also handles `i64::MIN`, whose magnitude does
            // not fit in an `i64`.
            let magnitude = pretty_internal(opt_base, x.unsigned_abs(), t);
            if x < 0 {
                allocator.text("-").append(magnitude)
            } else {
                // Non-negative signed values must not gain a minus sign.
                magnitude
            }
        };

        match self {
            Value::I8(opt_base, value) => {
                pretty_internal_signed(opt_base, i64::from(*value), ConstantType::I8)
            }
            Value::I16(opt_base, value) => {
                pretty_internal_signed(opt_base, i64::from(*value), ConstantType::I16)
            }
            Value::I32(opt_base, value) => {
                pretty_internal_signed(opt_base, i64::from(*value), ConstantType::I32)
            }
            Value::I64(opt_base, value) => {
                pretty_internal_signed(opt_base, *value, ConstantType::I64)
            }
            Value::U8(opt_base, value) => {
                pretty_internal(opt_base, u64::from(*value), ConstantType::U8)
            }
            Value::U16(opt_base, value) => {
                pretty_internal(opt_base, u64::from(*value), ConstantType::U16)
            }
            Value::U32(opt_base, value) => {
                pretty_internal(opt_base, u64::from(*value), ConstantType::U32)
            }
            Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
            Value::Void => allocator.text("<void>"),
        }
    }
}
impl Type {
    /// Builds the document for a concrete type.
    ///
    /// Structures print as `struct` followed by braces holding one
    /// `name: type;` entry per field.
    pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
        match self {
            Type::Primitive(prim) => prim.pretty(allocator),
            Type::Function(args, rettype) => pretty_function_type!(allocator, args, rettype),
            Type::Structure(fields) => {
                let members = fields.iter().map(|(field_name, field_type)| {
                    allocator
                        .text(field_name.to_string())
                        .append(allocator.text(":"))
                        .append(allocator.space())
                        .append(field_type.pretty(allocator))
                        .append(allocator.text(";"))
                });
                let body = allocator.concat(members).braces();
                allocator.text("struct").append(body)
            }
        }
    }
}
impl TypeOrVar {
    /// Builds the document for a possibly-unresolved type.
    ///
    /// Type variables print as their name. Structures print as `struct`
    /// followed by braces holding one `name: type;` entry per field.
    pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
        match self {
            TypeOrVar::Primitive(prim) => prim.pretty(allocator),
            TypeOrVar::Variable(_, name) => allocator.text(name.to_string()),
            TypeOrVar::Function(args, rettype) => pretty_function_type!(allocator, args, rettype),
            TypeOrVar::Structure(fields) => {
                let members = fields.iter().map(|(field_name, field_type)| {
                    allocator
                        .text(field_name.to_string())
                        .append(allocator.text(":"))
                        .append(allocator.space())
                        .append(field_type.pretty(allocator))
                        .append(allocator.text(";"))
                });
                let body = allocator.concat(members).braces();
                allocator.text("struct").append(body)
            }
        }
    }
}
// NOTE(review): `derived_display!` comes from `crate::util::pretty`; presumably it
// implements `Display` for each type below in terms of its `pretty` method -
// confirm against the macro definition.
derived_display!(Program<Type>);
derived_display!(Expression<Type>);
derived_display!(Primitive);
derived_display!(Type);
derived_display!(TypeOrVar);
derived_display!(ValueOrRef<Type>);
derived_display!(Value);

View File

@@ -1,8 +1,8 @@
use super::ast::{Expression, Program, Statement}; use super::ast::{Expression, Program};
use internment::ArcIntern; use internment::ArcIntern;
use std::collections::HashSet; use std::collections::HashSet;
impl Program { impl<T> Program<T> {
/// Get the complete list of strings used within the program. /// Get the complete list of strings used within the program.
/// ///
/// For the purposes of this function, strings are the variables used in /// For the purposes of this function, strings are the variables used in
@@ -10,31 +10,17 @@ impl Program {
pub fn strings(&self) -> HashSet<ArcIntern<String>> { pub fn strings(&self) -> HashSet<ArcIntern<String>> {
let mut result = HashSet::new(); let mut result = HashSet::new();
for stmt in self.statements.iter() { for function in self.functions.values() {
stmt.register_strings(&mut result); function.body.register_strings(&mut result);
} }
result result
} }
} }
impl Statement { impl<T> Expression<T> {
fn register_strings(&self, string_set: &mut HashSet<ArcIntern<String>>) {
match self {
Statement::Binding(_, name, _, expr) => {
string_set.insert(name.clone());
expr.register_strings(string_set);
}
Statement::Print(_, _, name) => {
string_set.insert(name.clone());
}
}
}
}
impl Expression {
fn register_strings(&self, _string_set: &mut HashSet<ArcIntern<String>>) { fn register_strings(&self, _string_set: &mut HashSet<ArcIntern<String>>) {
// nothing has a string in here, at the moment // nothing has a string in here, at the moment
unimplemented!()
} }
} }

25
src/ir/top_level.rs Normal file
View File

@@ -0,0 +1,25 @@
use crate::ir::{Expression, Name, Program, TypeWithFunction, TypeWithVoid};
use std::collections::HashMap;
impl<T: Clone + TypeWithVoid + TypeWithFunction> Program<T> {
/// Retrieve the complete set of variables that are defined at the top level of
/// this program.
pub fn get_top_level_variables(&self) -> HashMap<Name, T> {
self.body.get_top_level_variables()
}
}
impl<T: Clone> Expression<T> {
    /// Returns the variables defined at the top level of this expression,
    /// each mapped to its declared type.
    ///
    /// Only `Bind` expressions contribute: the bound variable is added to
    /// whatever the bound expression itself defines. Every other kind of
    /// expression yields an empty map.
    // NOTE(review): a `Block` therefore reports no variables even when it
    // contains binds - confirm this is intended.
    pub fn get_top_level_variables(&self) -> HashMap<Name, T> {
        let Expression::Bind(_, name, ty, expr) = self else {
            return HashMap::new();
        };

        let mut variables = expr.get_top_level_variables();
        variables.insert(name.clone(), ty.clone());
        variables
    }
}

View File

@@ -68,6 +68,7 @@ mod examples;
pub mod ir; pub mod ir;
pub mod syntax; pub mod syntax;
pub mod type_infer; pub mod type_infer;
pub mod util;
/// Implementation module for the high-level compiler. /// Implementation module for the high-level compiler.
mod compiler; mod compiler;

View File

@@ -1,13 +1,13 @@
use crate::backend::{Backend, BackendError}; use crate::backend::{Backend, BackendError};
use crate::syntax::{ConstantType, Location, ParserError, Statement}; use crate::syntax::{ConstantType, Expression, Location, Name, ParserError, Program, TopLevel};
use crate::type_infer::TypeInferenceResult; use crate::type_infer::TypeInferenceResult;
use crate::util::scoped_map::ScopedMap;
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use codespan_reporting::files::SimpleFiles; use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::{self, Config}; use codespan_reporting::term::{self, Config};
use cranelift_jit::JITModule; use cranelift_jit::JITModule;
use cranelift_module::ModuleError; use cranelift_module::ModuleError;
use pretty::termcolor::{ColorChoice, StandardStream}; use pretty::termcolor::{ColorChoice, StandardStream};
use std::collections::HashMap;
/// A high-level REPL helper for NGR. /// A high-level REPL helper for NGR.
/// ///
@@ -23,7 +23,7 @@ use std::collections::HashMap;
pub struct REPL { pub struct REPL {
file_database: SimpleFiles<String, String>, file_database: SimpleFiles<String, String>,
jitter: Backend<JITModule>, jitter: Backend<JITModule>,
variable_binding_sites: HashMap<String, Location>, variable_binding_sites: ScopedMap<String, Location>,
console: StandardStream, console: StandardStream,
console_config: Config, console_config: Config,
} }
@@ -70,7 +70,7 @@ impl REPL {
Ok(REPL { Ok(REPL {
file_database: SimpleFiles::new(), file_database: SimpleFiles::new(),
jitter: Backend::jit(None)?, jitter: Backend::jit(None)?,
variable_binding_sites: HashMap::new(), variable_binding_sites: ScopedMap::new(),
console, console,
console_config, console_config,
}) })
@@ -106,9 +106,9 @@ impl REPL {
pub fn process_input(&mut self, line_no: usize, command: String) { pub fn process_input(&mut self, line_no: usize, command: String) {
if let Err(err) = self.process(line_no, command) { if let Err(err) = self.process(line_no, command) {
if let Err(e) = self.emit_diagnostic(Diagnostic::from(err)) { if let Err(e) = self.emit_diagnostic(Diagnostic::from(err)) {
eprintln!( tracing::error!(
"WOAH! System having trouble printing error messages. This is very bad. ({})", error = %e,
e "WOAH! System having trouble printing error messages. This is very bad.",
); );
} }
} }
@@ -127,48 +127,43 @@ impl REPL {
.get(entry) .get(entry)
.expect("entry exists") .expect("entry exists")
.source(); .source();
let syntax = Statement::parse(entry, source)?; let syntax = TopLevel::parse(entry, source)?;
let top_levels = match syntax {
let program = match syntax { TopLevel::Expression(Expression::Binding(loc, name, expr)) => {
Statement::Binding(loc, name, expr) => {
// if this is a variable binding, and we've never defined this variable before, // if this is a variable binding, and we've never defined this variable before,
// we should tell cranelift about it. this is optimistic; if we fail to compile, // we should tell cranelift about it. this is optimistic; if we fail to compile,
// then we won't use this definition until someone tries again. // then we won't use this definition until someone tries again.
if !self.variable_binding_sites.contains_key(&name.name) { if !self
self.jitter.define_string(&name.name)?; .variable_binding_sites
.contains_key(&name.current_name().to_string())
{
self.jitter.define_string(name.current_name())?;
self.jitter self.jitter
.define_variable(name.to_string(), ConstantType::U64)?; .define_variable(name.clone(), ConstantType::U64)?;
} }
crate::syntax::Program { vec![
statements: vec![ TopLevel::Expression(Expression::Binding(loc.clone(), name.clone(), expr)),
Statement::Binding(loc.clone(), name.clone(), expr), TopLevel::Expression(Expression::Call(
Statement::Print(loc, name), loc.clone(),
], Box::new(Expression::Primitive(
} loc.clone(),
crate::syntax::Name::manufactured("print"),
)),
vec![Expression::Reference(name.clone())],
)),
]
} }
nonbinding => crate::syntax::Program { x => vec![x],
statements: vec![nonbinding],
},
}; };
let mut validation_result =
let (mut errors, mut warnings) = Program::validate_with_bindings(top_levels, &mut self.variable_binding_sites);
program.validate_with_bindings(&mut self.variable_binding_sites); for message in validation_result.diagnostics() {
let stop = !errors.is_empty();
let messages = errors
.drain(..)
.map(Into::into)
.chain(warnings.drain(..).map(Into::into));
for message in messages {
self.emit_diagnostic(message)?; self.emit_diagnostic(message)?;
} }
if stop { if let Some(program) = validation_result.into_result() {
return Ok(());
}
match program.type_infer() { match program.type_infer() {
TypeInferenceResult::Failure { TypeInferenceResult::Failure {
mut errors, mut errors,
@@ -182,8 +177,6 @@ impl REPL {
for message in messages { for message in messages {
self.emit_diagnostic(message)?; self.emit_diagnostic(message)?;
} }
Ok(())
} }
TypeInferenceResult::Success { TypeInferenceResult::Success {
@@ -193,15 +186,18 @@ impl REPL {
for message in warnings.drain(..).map(Into::into) { for message in warnings.drain(..).map(Into::into) {
self.emit_diagnostic(message)?; self.emit_diagnostic(message)?;
} }
let name = format!("line{}", line_no);
let function_id = self.jitter.compile_function(&name, result)?; let name = Name::new(format!("line{}", line_no), Location::manufactured());
let function_id = self.jitter.compile_program(name, result)?;
self.jitter.module.finalize_definitions()?; self.jitter.module.finalize_definitions()?;
let compiled_bytes = self.jitter.bytes(function_id); let compiled_bytes = self.jitter.bytes(function_id);
let compiled_function = let compiled_function =
unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) }; unsafe { std::mem::transmute::<_, fn() -> ()>(compiled_bytes) };
compiled_function(); compiled_function();
}
}
}
Ok(()) Ok(())
} }
} }
}
}

View File

@@ -8,7 +8,7 @@
//! //!
//! * Turning the string into a series of language-specific [`Token`]s. //! * Turning the string into a series of language-specific [`Token`]s.
//! * Taking those tokens, and computing a basic syntax tree from them, //! * Taking those tokens, and computing a basic syntax tree from them,
//! using our parser ([`ProgramParser`] or [`StatementParser`], generated //! using our parser ([`ProgramParser`] or [`TopLevelParser`], generated
//! by [`lalrpop`](https://lalrpop.github.io/lalrpop/)). //! by [`lalrpop`](https://lalrpop.github.io/lalrpop/)).
//! * Validating the tree we have parsed, using [`Program::validate`], //! * Validating the tree we have parsed, using [`Program::validate`],
//! returning any warnings or errors we have found. //! returning any warnings or errors we have found.
@@ -29,30 +29,28 @@ use logos::Logos;
pub mod arbitrary; pub mod arbitrary;
mod ast; mod ast;
mod eval; pub mod eval;
mod free_variables;
mod location; mod location;
mod name;
mod replace_references;
mod tokens; mod tokens;
lalrpop_mod!( lalrpop_mod!(
#[allow(clippy::just_underscores_and_digits, clippy::clone_on_copy)] #[allow(clippy::just_underscores_and_digits, clippy::clone_on_copy)]
parser, parser,
"/syntax/parser.rs" "/syntax/parser.rs"
); );
mod pretty; pub mod pretty;
mod validate; mod validate;
#[cfg(test)] pub use crate::syntax::arbitrary::ProgramGenerator;
use crate::syntax::arbitrary::GenerationEnvironment;
pub use crate::syntax::ast::*; pub use crate::syntax::ast::*;
pub use crate::syntax::location::Location; pub use crate::syntax::location::Location;
pub use crate::syntax::parser::{ProgramParser, StatementParser}; pub use crate::syntax::name::Name;
pub use crate::syntax::parser::{ExpressionParser, ProgramParser, TopLevelParser};
pub use crate::syntax::tokens::{LexerError, Token}; pub use crate::syntax::tokens::{LexerError, Token};
#[cfg(test)]
use ::pretty::{Arena, Pretty};
use lalrpop_util::ParseError; use lalrpop_util::ParseError;
#[cfg(test)] use std::ops::Range;
use proptest::{arbitrary::Arbitrary, prop_assert, prop_assert_eq};
#[cfg(test)]
use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
/// One of the many errors that can occur when processing text input. /// One of the many errors that can occur when processing text input.
@@ -105,7 +103,7 @@ impl ParserError {
/// closely. The major thing we do here is convert [`lalrpop`]'s notion of a location, /// closely. The major thing we do here is convert [`lalrpop`]'s notion of a location,
/// which is just an offset that it got from the lexer, into an actual location that /// which is just an offset that it got from the lexer, into an actual location that
/// we can use in our [`Diagnostic`]s. /// we can use in our [`Diagnostic`]s.
fn convert(file_idx: usize, err: ParseError<usize, Token, LexerError>) -> Self { fn convert(file_idx: usize, err: ParseError<usize, Token, ParserError>) -> Self {
match err { match err {
ParseError::InvalidToken { location } => { ParseError::InvalidToken { location } => {
ParserError::InvalidToken(Location::new(file_idx, location..location + 1)) ParserError::InvalidToken(Location::new(file_idx, location..location + 1))
@@ -123,11 +121,7 @@ impl ParserError {
ParseError::ExtraToken { ParseError::ExtraToken {
token: (start, token, end), token: (start, token, end),
} => ParserError::ExtraToken(Location::new(file_idx, start..end), token), } => ParserError::ExtraToken(Location::new(file_idx, start..end), token),
ParseError::User { error } => match error { ParseError::User { error } => error,
LexerError::LexFailure(offset) => {
ParserError::LexFailure(Location::new(file_idx, offset..offset + 1))
}
},
} }
} }
} }
@@ -209,7 +203,6 @@ impl<'a> From<&'a ParserError> for Diagnostic<usize> {
} }
} }
impl Program {
/// Parse the given file, adding it to the database as part of the process. /// Parse the given file, adding it to the database as part of the process.
/// ///
/// This operation reads the file from disk and adds it to the database for future /// This operation reads the file from disk and adds it to the database for future
@@ -221,11 +214,11 @@ impl Program {
pub fn parse_file( pub fn parse_file(
file_database: &mut SimpleFiles<String, String>, file_database: &mut SimpleFiles<String, String>,
file_name: &str, file_name: &str,
) -> Result<Self, ParserError> { ) -> Result<Vec<TopLevel>, ParserError> {
let file_contents = std::fs::read_to_string(file_name)?; let file_contents = std::fs::read_to_string(file_name)?;
let file_handle = file_database.add(file_name.to_string(), file_contents); let file_handle = file_database.add(file_name.to_string(), file_contents);
let file_db_info = file_database.get(file_handle)?; let file_db_info = file_database.get(file_handle)?;
Program::parse(file_handle, file_db_info.source()) parse_string(file_handle, file_db_info.source())
} }
/// Parse a block of text you have in memory, using the given index for [`Location`]s. /// Parse a block of text you have in memory, using the given index for [`Location`]s.
@@ -233,39 +226,58 @@ impl Program {
/// If you use a nonsensical file index, everything will work fine until you try to /// If you use a nonsensical file index, everything will work fine until you try to
/// report an error, at which point [`codespan_reporting`] may have some nasty things /// report an error, at which point [`codespan_reporting`] may have some nasty things
/// to say to you. /// to say to you.
pub fn parse(file_idx: usize, buffer: &str) -> Result<Program, ParserError> { pub fn parse_string(file_idx: usize, buffer: &str) -> Result<Vec<TopLevel>, ParserError> {
let lexer = Token::lexer(buffer) let lexer = Token::lexer(buffer)
.spanned() .spanned()
.map(|(token, range)| (range.start, token, range.end)); .map(|x| permute_lexer_result(file_idx, x));
ProgramParser::new() ProgramParser::new()
.parse(file_idx, lexer) .parse(file_idx, lexer)
.map_err(|e| ParserError::convert(file_idx, e)) .map_err(|e| ParserError::convert(file_idx, e))
} }
}
impl Statement { impl TopLevel {
/// Parse a statement that you have in memory, using the given index for [`Location`]s. /// Parse a top-level item that you have in memory, using the given index for [`Location`]s.
/// ///
/// As with [`Program::parse`], if you use a bad file index, you'll get weird behaviors /// As with [`Program::parse`], if you use a bad file index, you'll get weird behaviors
/// when you try to print errors, but things should otherwise work fine. This function /// when you try to print errors, but things should otherwise work fine. This function
/// will only parse a single statement, which is useful in the REPL, but probably shouldn't /// will only parse a single statement, which is useful in the REPL, but probably shouldn't
/// be used when reading in whole files. /// be used when reading in whole files.
pub fn parse(file_idx: usize, buffer: &str) -> Result<Statement, ParserError> { pub fn parse(file_idx: usize, buffer: &str) -> Result<TopLevel, ParserError> {
let lexer = Token::lexer(buffer) let lexer = Token::lexer(buffer)
.spanned() .spanned()
.map(|(token, range)| (range.start, token, range.end)); .map(|x| permute_lexer_result(file_idx, x));
StatementParser::new() TopLevelParser::new()
.parse(file_idx, lexer) .parse(file_idx, lexer)
.map_err(|e| ParserError::convert(file_idx, e)) .map_err(|e| ParserError::convert(file_idx, e))
} }
} }
#[cfg(test)] impl Expression {
impl FromStr for Program { /// Parse an expression from a string, using the given index for [`Location`]s.
type Err = ParserError; ///
/// As with [`Program::parse`], if you use a bad file index, you'll get weird behaviors
/// when you try to print errors, but things should otherwise work fine. This function
/// will only parse a single expression, which is useful for testing, but probably shouldn't
/// be used when reading in whole files.
pub fn parse(file_idx: usize, buffer: &str) -> Result<Expression, ParserError> {
let lexer = Token::lexer(buffer)
.spanned()
.map(|x| permute_lexer_result(file_idx, x));
ExpressionParser::new()
.parse(file_idx, lexer)
.map_err(|e| ParserError::convert(file_idx, e))
}
}
fn from_str(s: &str) -> Result<Program, ParserError> { fn permute_lexer_result(
Program::parse(0, s) file_idx: usize,
result: (Result<Token, ()>, Range<usize>),
) -> Result<(usize, Token, usize), ParserError> {
let (token, range) = result;
match token {
Ok(v) => Ok((range.start, v, range.end)),
Err(()) => Err(ParserError::LexFailure(Location::new(file_idx, range))),
} }
} }
@@ -274,22 +286,24 @@ fn order_of_operations() {
let muladd1 = "x = 1 + 2 * 3;"; let muladd1 = "x = 1 + 2 * 3;";
let testfile = 0; let testfile = 0;
assert_eq!( assert_eq!(
Program::from_str(muladd1).unwrap(), parse_string(0, muladd1).unwrap(),
Program { vec![TopLevel::Expression(Expression::Binding(
statements: vec![Statement::Binding(
Location::new(testfile, 0..1), Location::new(testfile, 0..1),
Name::manufactured("x"), Name::manufactured("x"),
Expression::Primitive( Box::new(Expression::Call(
Location::new(testfile, 6..7), Location::new(testfile, 6..7),
"+".to_string(), Box::new(Expression::Primitive(
Location::new(testfile, 6..7),
Name::manufactured("+")
)),
vec![ vec![
Expression::Value( Expression::Value(Location::new(testfile, 4..5), Value::Number(None, None, 1),),
Location::new(testfile, 4..5), Expression::Call(
Value::Number(None, None, 1),
),
Expression::Primitive(
Location::new(testfile, 10..11), Location::new(testfile, 10..11),
"*".to_string(), Box::new(Expression::Primitive(
Location::new(testfile, 10..11),
Name::manufactured("*")
)),
vec![ vec![
Expression::Value( Expression::Value(
Location::new(testfile, 8..9), Location::new(testfile, 8..9),
@@ -302,43 +316,44 @@ fn order_of_operations() {
] ]
) )
] ]
) ))
),], ))],
}
); );
} }
proptest::proptest! { proptest::proptest! {
#[test] #[test]
fn random_render_parses_equal(program: Program) { fn syntax_asts_roundtrip(program in self::arbitrary::ProgramGenerator::default()) {
let mut file_database = SimpleFiles::new(); use crate::util::pretty::Allocator;
let writer = ::pretty::termcolor::StandardStream::stderr(::pretty::termcolor::ColorChoice::Auto);
let config = codespan_reporting::term::Config::default();
let allocator = Arena::<()>::new();
let mut out_vector = vec![]; let allocator = Allocator::new();
prop_assert!(program.pretty(&allocator).render(80, &mut out_vector).is_ok()); let mut outbytes = Vec::new();
let string = std::str::from_utf8(&out_vector).expect("emitted valid string");
let file_handle = file_database.add("test", string);
let file_db_info = file_database.get(file_handle).expect("find thing just inserted");
let parsed = Program::parse(file_handle, file_db_info.source());
if let Err(e) = &parsed { for top_level in program.iter() {
eprintln!("failed to parse:\n{}", string); let docbuilder = top_level.pretty(&allocator);
codespan_reporting::term::emit(&mut writer.lock(), &config, &file_database, &e.into()).unwrap(); docbuilder.render(78, &mut outbytes).expect("can write program text");
} }
prop_assert_eq!(program, parsed.unwrap());
let outstr = std::str::from_utf8(&outbytes).expect("generated utf8");
println!("---------------- GENERATED TEXT -------------------");
println!("{}", outstr);
println!("---------------------------------------------------");
let syntax = crate::syntax::parse_string(0, &outstr).expect("generated text parses");
assert_eq!(program, syntax);
} }
#[test] #[test]
fn random_syntaxes_validate(program: Program) { fn random_syntaxes_validate(program in self::arbitrary::ProgramGenerator::default()) {
let (errors, _) = program.validate(); let result = Program::validate(program);
prop_assert!(errors.is_empty()); proptest::prop_assert!(result.is_ok());
} }
#[test] #[test]
fn generated_run_or_overflow(program in Program::arbitrary_with(GenerationEnvironment::new(false))) { fn generated_run_or_overflow(program in self::arbitrary::ProgramGenerator::default()) {
use crate::eval::{EvalError, PrimOpError}; use crate::eval::{EvalError, PrimOpError};
prop_assert!(matches!(program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_))))); let validated = Program::validate(program);
let actual_program = validated.into_result().expect("got a valid result");
proptest::prop_assert!(matches!(actual_program.eval(), Ok(_) | Err(EvalError::PrimOp(PrimOpError::MathFailure(_)))));
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,9 @@
use std::fmt; use crate::syntax::name::Name;
use std::hash::Hash;
use internment::ArcIntern;
pub use crate::syntax::tokens::ConstantType; pub use crate::syntax::tokens::ConstantType;
use crate::syntax::Location; use crate::syntax::Location;
use std::collections::HashMap;
use super::location::Located;
/// A structure representing a parsed program. /// A structure representing a parsed program.
/// ///
@@ -16,87 +15,75 @@ use crate::syntax::Location;
/// `validate` and it comes back without errors. /// `validate` and it comes back without errors.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct Program { pub struct Program {
pub statements: Vec<Statement>, pub functions: HashMap<Name, FunctionDefinition>,
pub structures: HashMap<Name, StructureDefinition>,
pub body: Expression,
} }
/// A Name. /// A function that we want to compile.
/// ///
/// This is basically a string, but annotated with the place the string /// Later, when we've done a lot of analysis, the `Option<Type>`s
/// is in the source file. /// will turn into concrete types. For now, though, we stick with
#[derive(Clone, Debug)] /// the surface syntax and leave them as optional. The name of the
pub struct Name { /// function is intentionally duplicated, to make our life easier.
pub name: String, #[derive(Clone, Debug, PartialEq)]
pub struct FunctionDefinition {
pub name: Name,
pub arguments: Vec<(Name, Option<Type>)>,
pub return_type: Option<Type>,
pub body: Expression,
}
impl FunctionDefinition {
pub fn new(
name: Name,
arguments: Vec<(Name, Option<Type>)>,
return_type: Option<Type>,
body: Expression,
) -> Self {
FunctionDefinition {
name,
arguments,
return_type,
body,
}
}
}
/// A structure type that we might want to reference in the future.
#[derive(Clone, Debug, PartialEq)]
pub struct StructureDefinition {
pub name: Name,
pub location: Location, pub location: Location,
pub fields: Vec<(Name, Option<Type>)>,
} }
impl Name { impl StructureDefinition {
pub fn new<S: ToString>(n: S, location: Location) -> Name { pub fn new(location: Location, name: Name, fields: Vec<(Name, Option<Type>)>) -> Self {
Name { StructureDefinition {
name: n.to_string(), name,
location, location,
fields,
}
} }
} }
pub fn manufactured<S: ToString>(n: S) -> Name { /// A thing that can sit at the top level of a file.
Name {
name: n.to_string(),
location: Location::manufactured(),
}
}
pub fn intern(self) -> ArcIntern<String> {
ArcIntern::new(self.name)
}
}
impl PartialEq for Name {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
impl Eq for Name {}
impl Hash for Name {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state)
}
}
impl fmt::Display for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.name.fmt(f)
}
}
/// A parsed statement.
/// ///
/// Statements are guaranteed to be syntactically valid, but may be /// For the moment, these are statements and functions. Other things
/// complete nonsense at the semantic level. Which is to say, all the /// will likely be added in the future, but for now: just statements
/// print statements were correctly formatted, and all the variables /// and functions
/// referenced are definitely valid symbols, but they may not have #[derive(Clone, Debug, PartialEq)]
/// been defined or anything. pub enum TopLevel {
/// Expression(Expression),
/// Note that equivalence testing on statements is independent of Structure(Location, Name, Vec<(Name, Option<Type>)>),
/// source location; it is testing if the two statements say the same
/// thing, not if they are the exact same statement.
#[derive(Clone, Debug)]
pub enum Statement {
Binding(Location, Name, Expression),
Print(Location, Name),
} }
impl PartialEq for Statement { impl Located for TopLevel {
fn eq(&self, other: &Self) -> bool { fn location(&self) -> &Location {
match self { match self {
Statement::Binding(_, name1, expr1) => match other { TopLevel::Expression(exp) => exp.location(),
Statement::Binding(_, name2, expr2) => name1 == name2 && expr1 == expr2, TopLevel::Structure(loc, _, _) => loc,
_ => false,
},
Statement::Print(_, name1) => match other {
Statement::Print(_, name2) => name1 == name2,
_ => false,
},
} }
} }
} }
@@ -110,9 +97,34 @@ impl PartialEq for Statement {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Expression { pub enum Expression {
Value(Location, Value), Value(Location, Value),
Reference(Location, String), Constructor(Location, Name, Vec<(Name, Expression)>),
Reference(Name),
FieldRef(Location, Box<Expression>, Name),
Cast(Location, String, Box<Expression>), Cast(Location, String, Box<Expression>),
Primitive(Location, String, Vec<Expression>), Primitive(Location, Name),
Call(Location, Box<Expression>, Vec<Expression>),
Block(Location, Vec<Expression>),
Binding(Location, Name, Box<Expression>),
Function(
Location,
Option<Name>,
Vec<(Name, Option<Type>)>,
Option<Type>,
Box<Expression>,
),
}
impl Expression {
pub fn primitive(loc: Location, name: &str, args: Vec<Expression>) -> Expression {
Expression::Call(
loc.clone(),
Box::new(Expression::Primitive(
loc.clone(),
Name::new(name, loc.clone()),
)),
args,
)
}
} }
impl PartialEq for Expression { impl PartialEq for Expression {
@@ -122,18 +134,62 @@ impl PartialEq for Expression {
Expression::Value(_, val2) => val1 == val2, Expression::Value(_, val2) => val1 == val2,
_ => false, _ => false,
}, },
Expression::Reference(_, var1) => match other { Expression::Constructor(_, name1, fields1) => match other {
Expression::Reference(_, var2) => var1 == var2, Expression::Constructor(_, name2, fields2) => name1 == name2 && fields1 == fields2,
_ => false,
},
Expression::Reference(var1) => match other {
Expression::Reference(var2) => var1 == var2,
_ => false,
},
Expression::FieldRef(_, exp1, field1) => match other {
Expression::FieldRef(_, exp2, field2) => exp1 == exp2 && field1 == field2,
_ => false, _ => false,
}, },
Expression::Cast(_, t1, e1) => match other { Expression::Cast(_, t1, e1) => match other {
Expression::Cast(_, t2, e2) => t1 == t2 && e1 == e2, Expression::Cast(_, t2, e2) => t1 == t2 && e1 == e2,
_ => false, _ => false,
}, },
Expression::Primitive(_, prim1, args1) => match other { Expression::Primitive(_, prim1) => match other {
Expression::Primitive(_, prim2, args2) => prim1 == prim2 && args1 == args2, Expression::Primitive(_, prim2) => prim1 == prim2,
_ => false, _ => false,
}, },
Expression::Call(_, f1, a1) => match other {
Expression::Call(_, f2, a2) => f1 == f2 && a1 == a2,
_ => false,
},
Expression::Block(_, stmts1) => match other {
Expression::Block(_, stmts2) => stmts1 == stmts2,
_ => false,
},
Expression::Binding(_, name1, expr1) => match other {
Expression::Binding(_, name2, expr2) => name1 == name2 && expr1 == expr2,
_ => false,
},
Expression::Function(_, mname1, args1, mret1, body1) => match other {
Expression::Function(_, mname2, args2, mret2, body2) => {
mname1 == mname2 && args1 == args2 && mret1 == mret2 && body1 == body2
}
_ => false,
},
}
}
}
impl Located for Expression {
/// Get the location of the expression in the source file (if there is one).
fn location(&self) -> &Location {
match self {
Expression::Value(loc, _) => loc,
Expression::Constructor(loc, _, _) => loc,
Expression::Reference(n) => n.location(),
Expression::FieldRef(loc, _, _) => loc,
Expression::Cast(loc, _, _) => loc,
Expression::Primitive(loc, _) => loc,
Expression::Call(loc, _, _) => loc,
Expression::Block(loc, _) => loc,
Expression::Binding(loc, _, _) => loc,
Expression::Function(loc, _, _, _, _) => loc,
} }
} }
} }
@@ -149,4 +205,12 @@ pub enum Value {
/// operation "-" on the number 4. We'll translate this into a type-specific /// operation "-" on the number 4. We'll translate this into a type-specific
/// number at a later time. /// number at a later time.
Number(Option<u8>, Option<ConstantType>, u64), Number(Option<u8>, Option<ConstantType>, u64),
/// The empty value
Void,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Type {
Named(Name),
Struct(Vec<(Name, Option<Type>)>),
} }

View File

@@ -1,11 +1,12 @@
use internment::ArcIntern; use crate::eval::{EvalError, PrimitiveType, Value};
use crate::syntax::{ConstantType, Expression, Name, Program};
use crate::eval::{EvalEnvironment, EvalError, PrimitiveType, Value}; use crate::util::scoped_map::ScopedMap;
use crate::syntax::{ConstantType, Expression, Program, Statement}; use std::collections::HashMap;
use std::str::FromStr; use std::str::FromStr;
impl Program { impl Program {
/// Evaluate the program, returning either an error or what it prints out when run. /// Evaluate the program, returning either an error or a pair of the final value
/// produced and the output printed to the console.
/// ///
/// Doing this evaluation is particularly useful for testing, to ensure that if we /// Doing this evaluation is particularly useful for testing, to ensure that if we
/// modify a program in some way it does the same thing on both sides of the /// modify a program in some way it does the same thing on both sides of the
@@ -15,38 +16,26 @@ impl Program {
/// Note that the errors here are slightly more strict than we enforce at runtime. /// Note that the errors here are slightly more strict than we enforce at runtime.
/// For example, we check for overflow and underflow errors during evaluation, and /// For example, we check for overflow and underflow errors during evaluation, and
/// we don't check for those in the compiled code. /// we don't check for those in the compiled code.
pub fn eval(&self) -> Result<String, EvalError> { pub fn eval(&self) -> Result<(Value<Expression>, String), EvalError<Expression>> {
let mut env = EvalEnvironment::empty(); let mut env = ScopedMap::new();
let mut stdout = String::new(); let mut stdout = String::new();
let result = self.body.eval(&mut stdout, &mut env)?;
for stmt in self.statements.iter() { Ok((result, stdout))
// at this point, evaluation is pretty simple. just walk through each
// statement, in order, and record printouts as we come to them.
match stmt {
Statement::Binding(_, name, value) => {
let actual_value = value.eval(&env)?;
env = env.extend(name.clone().intern(), actual_value);
}
Statement::Print(_, name) => {
let value = env.lookup(name.clone().intern())?;
let line = format!("{} = {}\n", name, value);
stdout.push_str(&line);
}
}
}
Ok(stdout)
} }
} }
impl Expression { impl Expression {
fn eval(&self, env: &EvalEnvironment) -> Result<Value, EvalError> { fn eval(
&self,
stdout: &mut String,
env: &mut ScopedMap<Name, Value<Expression>>,
) -> Result<Value<Expression>, EvalError<Expression>> {
match self { match self {
Expression::Value(_, v) => match v { Expression::Value(_, v) => match v {
super::Value::Number(_, ty, v) => match ty { super::Value::Number(_, ty, v) => match ty {
None => Ok(Value::U64(*v)), None => Ok(Value::Number(*v)),
// FIXME: make these types validate their input size // FIXME: make these types validate their input size
Some(ConstantType::Void) => Ok(Value::Void),
Some(ConstantType::I8) => Ok(Value::I8(*v as i8)), Some(ConstantType::I8) => Ok(Value::I8(*v as i8)),
Some(ConstantType::I16) => Ok(Value::I16(*v as i16)), Some(ConstantType::I16) => Ok(Value::I16(*v as i16)),
Some(ConstantType::I32) => Ok(Value::I32(*v as i32)), Some(ConstantType::I32) => Ok(Value::I32(*v as i32)),
@@ -56,25 +45,141 @@ impl Expression {
Some(ConstantType::U32) => Ok(Value::U32(*v as u32)), Some(ConstantType::U32) => Ok(Value::U32(*v as u32)),
Some(ConstantType::U64) => Ok(Value::U64(*v)), Some(ConstantType::U64) => Ok(Value::U64(*v)),
}, },
super::Value::Void => Ok(Value::Void),
}, },
Expression::Reference(_, n) => Ok(env.lookup(ArcIntern::new(n.clone()))?), Expression::Constructor(_, on, fields) => {
let mut map = HashMap::with_capacity(fields.len());
for (k, v) in fields.iter() {
map.insert(k.clone(), v.eval(stdout, env)?);
}
Ok(Value::Structure(Some(on.clone()), map))
}
Expression::Reference(n) => env
.get(n)
.ok_or_else(|| {
EvalError::LookupFailed(n.location().clone(), n.current_name().to_string())
})
.cloned(),
Expression::FieldRef(loc, expr, field) => {
let struck = expr.eval(stdout, env)?;
if let Value::Structure(on, mut fields) = struck {
if let Some(value) = fields.remove(&field.clone()) {
Ok(value)
} else {
Err(EvalError::BadFieldForStructure(
loc.clone(),
on,
field.clone(),
))
}
} else {
Err(EvalError::NoFieldForValue(
loc.clone(),
struck,
field.clone(),
))
}
}
Expression::Cast(_, target, expr) => { Expression::Cast(_, target, expr) => {
let target_type = PrimitiveType::from_str(target)?; let target_type = PrimitiveType::from_str(target)?;
let value = expr.eval(env)?; let value = expr.eval(stdout, env)?;
Ok(target_type.safe_cast(&value)?) Ok(target_type.safe_cast(&value)?)
} }
Expression::Primitive(_, op, args) => { Expression::Primitive(_, op) => Ok(Value::primitive(op.original_name().to_string())),
let mut arg_values = Vec::with_capacity(args.len());
for arg in args.iter() { Expression::Call(loc, fun, args) => {
// yay, recursion! makes this pretty straightforward let function = fun.eval(stdout, env)?;
arg_values.push(arg.eval(env)?);
match function {
Value::Closure(name, mut closure_env, arguments, body) => {
if args.len() != arguments.len() {
return Err(EvalError::WrongArgCount(
loc.clone(),
name,
arguments.len(),
args.len(),
));
} }
Ok(Value::calculate(op, arg_values)?) closure_env.new_scope();
for (name, value) in arguments.into_iter().zip(args.iter()) {
let value = value.eval(stdout, env)?;
closure_env.insert(name, value);
}
let result = body.eval(stdout, &mut closure_env)?;
closure_env.release_scope();
Ok(result)
}
Value::Primitive(name) if name == "print" => {
if let [Expression::Reference(name)] = &args[..] {
let value = Expression::Reference(name.clone()).eval(stdout, env)?;
let value = match value {
Value::Number(x) => Value::U64(x),
x => x,
};
let addendum = format!("{} = {}\n", name, value);
stdout.push_str(&addendum);
Ok(Value::Void)
} else {
panic!(
"Non-reference/non-singleton argument to 'print': {:?}",
args
);
}
}
Value::Primitive(name) => {
let values = args
.iter()
.map(|x| x.eval(stdout, env))
.collect::<Result<_, _>>()?;
Value::calculate(name.as_str(), values).map_err(Into::into)
}
_ => Err(EvalError::NotAFunction(loc.clone(), function)),
}
}
Expression::Block(_, stmts) => {
let mut result = Value::Void;
for stmt in stmts.iter() {
result = stmt.eval(stdout, env)?;
}
Ok(result)
}
Expression::Binding(_, name, value) => {
let actual_value = value.eval(stdout, env)?;
env.insert(name.clone(), actual_value.clone());
Ok(actual_value)
}
Expression::Function(_, name, arg_names, _, body) => {
let result = Value::Closure(
name.clone(),
env.clone(),
arg_names.iter().cloned().map(|(x, _)| x.clone()).collect(),
*body.clone(),
);
if let Some(name) = name {
env.insert(name.clone(), result.clone());
}
Ok(result)
} }
} }
} }
@@ -82,14 +187,17 @@ impl Expression {
#[test] #[test]
fn two_plus_three() { fn two_plus_three() {
let input = Program::parse(0, "x = 2 + 3; print x;").expect("parse works"); let input = crate::syntax::parse_string(0, "x = 2 + 3; print x;").expect("parse works");
let output = input.eval().expect("runs successfully"); let program = Program::validate(input).into_result().unwrap();
let (_, output) = program.eval().expect("runs successfully");
assert_eq!("x = 5u64\n", &output); assert_eq!("x = 5u64\n", &output);
} }
#[test] #[test]
fn lotsa_math() { fn lotsa_math() {
let input = Program::parse(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works"); let input =
let output = input.eval().expect("runs successfully"); crate::syntax::parse_string(0, "x = 2 + 3 * 10 / 5 - 1; print x;").expect("parse works");
let program = Program::validate(input).into_result().unwrap();
let (_, output) = program.eval().expect("runs successfully");
assert_eq!("x = 7u64\n", &output); assert_eq!("x = 7u64\n", &output);
} }

View File

@@ -0,0 +1,161 @@
use crate::syntax::{Expression, Name};
use std::collections::HashSet;
impl Expression {
    /// Compute the names this expression refers to but does not itself bind.
    ///
    /// A function's own name (when it has one) and its argument names are
    /// removed from the free variables of its body. Inside a block, a name
    /// introduced by an earlier expression (see [`Expression::new_bindings`])
    /// is not reported when a later expression in the same block uses it.
    pub fn free_variables(&self) -> HashSet<Name> {
        match self {
            // Literals and primitive names never mention a variable.
            Expression::Value(_, _) | Expression::Primitive(_, _) => HashSet::new(),

            Expression::Reference(name) => std::iter::once(name.clone()).collect(),

            // Wrappers around a single sub-expression just defer to it. Note
            // that the bound name of a `Binding` is not subtracted here; only
            // the right-hand side is inspected.
            Expression::FieldRef(_, inner, _)
            | Expression::Cast(_, _, inner)
            | Expression::Binding(_, _, inner) => inner.free_variables(),

            Expression::Constructor(_, _, fields) => fields
                .iter()
                .flat_map(|(_, value)| value.free_variables())
                .collect(),

            Expression::Call(_, callee, args) => {
                let mut found = callee.free_variables();
                for arg in args {
                    found.extend(arg.free_variables());
                }
                found
            }

            Expression::Block(_, exprs) => {
                let mut free = HashSet::new();
                let mut bound_so_far: HashSet<Name> = HashSet::new();

                for expr in exprs {
                    // Only names not yet bound by a previous expression in
                    // this block count as free; the current expression's own
                    // bindings take effect for the expressions after it.
                    free.extend(
                        expr.free_variables()
                            .into_iter()
                            .filter(|var| !bound_so_far.contains(var)),
                    );
                    bound_so_far.extend(expr.new_bindings());
                }

                free
            }

            Expression::Function(_, name, args, _, body) => {
                let mut remaining = body.free_variables();
                let shadowed = name.iter().chain(args.iter().map(|(param, _)| param));
                for bound in shadowed {
                    remaining.remove(bound);
                }
                remaining
            }
        }
    }

    /// Compute the names this expression introduces for whatever follows it.
    ///
    /// A binding introduces its name (plus anything its right-hand side
    /// introduces), and a named function introduces its name. A block
    /// reports nothing, and an anonymous function reports nothing. Other
    /// compound expressions report whatever their sub-expressions introduce.
    pub fn new_bindings(&self) -> HashSet<Name> {
        match self {
            Expression::Value(_, _)
            | Expression::Reference(_)
            | Expression::Primitive(_, _)
            | Expression::Block(_, _)
            | Expression::Function(_, None, _, _, _) => HashSet::new(),

            Expression::Function(_, Some(name), _, _, _) => {
                std::iter::once(name.clone()).collect()
            }

            Expression::FieldRef(_, inner, _) | Expression::Cast(_, _, inner) => {
                inner.new_bindings()
            }

            Expression::Constructor(_, _, fields) => fields
                .iter()
                .flat_map(|(_, value)| value.new_bindings())
                .collect(),

            Expression::Call(_, callee, args) => {
                let mut introduced = callee.new_bindings();
                for arg in args {
                    introduced.extend(arg.new_bindings());
                }
                introduced
            }

            Expression::Binding(_, name, value) => {
                let mut introduced = value.new_bindings();
                introduced.insert(name.clone());
                introduced
            }
        }
    }
}
#[test]
fn basic_frees_works() {
    // Expressions built purely from literals have no free variables.
    for source in ["1u64", "1u64 + 2"] {
        let frees = Expression::parse(0, source).unwrap().free_variables();
        assert_eq!(0, frees.len());
    }

    // A lone reference, or a reference inside arithmetic, is free.
    for source in ["x", "1 + x"] {
        let frees = Expression::parse(0, source).unwrap().free_variables();
        assert_eq!(1, frees.len());
        assert!(frees.contains(&Name::manufactured("x")));
    }

    // References inside constructor fields and inside blocks are collected too.
    for source in [
        "Structure{ field1: x; field2: y; }",
        "{ print x; print y }",
    ] {
        let frees = Expression::parse(0, source).unwrap().free_variables();
        assert_eq!(2, frees.len());
        for expected in ["x", "y"] {
            assert!(frees.contains(&Name::manufactured(expected)));
        }
    }
}
#[test]
fn test_around_function() {
    // An argument used in the body is bound, not free.
    let nada_frees = Expression::parse(0, "function(x) x")
        .unwrap()
        .free_variables();
    assert_eq!(0, nada_frees.len());

    // A name that is not an argument escapes as free.
    let lift_frees = Expression::parse(0, "function(x) x + y")
        .unwrap()
        .free_variables();
    assert_eq!(1, lift_frees.len());
    assert!(lift_frees.contains(&Name::manufactured("y")));

    // An enclosing function's argument binds uses in a nested function.
    let nest_frees = Expression::parse(0, "function(y) function(x) x + y")
        .unwrap()
        .free_variables();
    assert_eq!(0, nest_frees.len());

    // Every argument of a multi-argument function is bound.
    let multi_frees = Expression::parse(0, "function(x, y) x + y + z")
        .unwrap()
        .free_variables();
    assert_eq!(1, multi_frees.len());
    assert!(multi_frees.contains(&Name::manufactured("z")));
}
#[test]
fn test_is_set() {
    // `z` is used twice but must be reported only once.
    let frees = Expression::parse(0, "function(x, y) x + y + z + z")
        .unwrap()
        .free_variables();
    assert_eq!(1, frees.len());
    assert!(frees.contains(&Name::manufactured("z")));
}
#[test]
fn bindings_remove() {
    // A binding earlier in a block covers later uses in that block.
    let bound_frees = Expression::parse(0, "{ x = 4; print x }")
        .unwrap()
        .free_variables();
    assert_eq!(0, bound_frees.len());

    // A name never bound anywhere stays free in the outer block.
    let outer_y_frees = Expression::parse(0, "{ { x = 4; print x }; print y }")
        .unwrap()
        .free_variables();
    assert_eq!(1, outer_y_frees.len());
    assert!(outer_y_frees.contains(&Name::manufactured("y")));

    // A binding made inside an inner block does not leak to the outer one.
    let outer_x_frees = Expression::parse(0, "{ { x = 4; print x }; print x }")
        .unwrap()
        .free_variables();
    assert_eq!(1, outer_x_frees.len());
    assert!(outer_x_frees.contains(&Name::manufactured("x")));

    // Chained bindings introduce both names.
    let double_frees = Expression::parse(0, "{ x = y = 1; x + y }")
        .unwrap()
        .free_variables();
    assert_eq!(0, double_frees.len());
}

View File

@@ -12,6 +12,10 @@ pub struct Location {
location: Range<usize>, location: Range<usize>,
} }
/// Implemented by anything that can report a source [`Location`].
pub trait Located {
/// Returns a reference to the location associated with this item.
fn location(&self) -> &Location;
}
impl Location { impl Location {
/// Generate a new `Location` from a file index and an offset from the /// Generate a new `Location` from a file index and an offset from the
/// start of the file. /// start of the file.
@@ -85,9 +89,9 @@ impl Location {
/// the user with some guidance. That being said, you still might want to add /// the user with some guidance. That being said, you still might want to add
/// even more information to ut, using [`Diagnostic::with_labels`], /// even more information to ut, using [`Diagnostic::with_labels`],
/// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`]. /// [`Diagnostic::with_notes`], or [`Diagnostic::with_code`].
pub fn labelled_error(&self, msg: &str) -> Diagnostic<usize> { pub fn labelled_error<T: AsRef<str>>(&self, msg: T) -> Diagnostic<usize> {
Diagnostic::error().with_labels(vec![ Diagnostic::error().with_labels(vec![
Label::primary(self.file_idx, self.location.clone()).with_message(msg) Label::primary(self.file_idx, self.location.clone()).with_message(msg.as_ref())
]) ])
} }
@@ -116,4 +120,36 @@ impl Location {
}) })
} }
} }
/// Derive a single location covering every located item in `items`.
///
/// Items whose location has file index 0 and an empty `0..0` range are
/// treated as manufactured and skipped. The remaining locations are
/// combined pairwise with `merge`; once a merge fails, the failure is
/// remembered and no further merging is attempted.
///
/// Returns a manufactured location when `items` is empty, when every
/// item was skipped, or when any merge failed.
pub fn infer_from<T: Located>(items: &[T]) -> Location {
    // Outer `None`: nothing usable seen yet. `Some(None)`: a merge failed.
    let mut merged: Option<Option<Location>> = None;

    for location in items.iter().map(|item| item.location()) {
        let looks_manufactured = location.file_idx == 0
            && location.location.start == 0
            && location.location.end == 0;
        if looks_manufactured {
            continue;
        }

        merged = Some(match merged {
            None => Some(location.clone()),
            // we ran into an error somewhere
            Some(None) => None,
            Some(Some(so_far)) => so_far.merge(location),
        });
    }

    merged.flatten().unwrap_or_else(Location::manufactured)
}
} }

143
src/syntax/name.rs Normal file
View File

@@ -0,0 +1,143 @@
use crate::syntax::Location;
use internment::ArcIntern;
use std::fmt;
use std::hash::Hash;
use std::sync::atomic::{AtomicU64, Ordering};
/// The name of a thing in the source language.
///
/// In many ways, you can treat this like a string, but it's a very tricky
/// string in a couple of ways:
///
/// First, it's a string associated with a particular location in the source
/// file, and you can find out what that source location is relatively easily.
///
/// Second, it's a name that retains something of its identity across renaming,
/// so that you can keep track of what a variable's original name was, as well
/// as what its new name is if it's been renamed.
///
/// Finally, when it comes to equality tests, comparisons, and hashing, `Name`
/// uses *only* the new name, if the variable has been renamed, or the original
/// name, if it has not been renamed. It never uses the location. This allows
/// relatively fast hashing and searching for things like binding sites, as the
/// value of the binding `Name` will be equal to the bound `Name`, even though
/// they occur at different locations.
#[derive(Clone, Debug)]
pub struct Name {
// The name this value was created with; never changed by renaming.
name: ArcIntern<String>,
// The replacement name, if this value has been renamed.
rename: Option<ArcIntern<String>>,
// Where this name occurs; ignored by equality and hashing.
location: Location,
}
impl Name {
    /// Create a new name at the given location.
    ///
    /// The result is an "original" name: it has not been renamed.
    pub fn new<S: ToString>(n: S, location: Location) -> Name {
        let name = ArcIntern::new(n.to_string());
        Name {
            name,
            rename: None,
            location,
        }
    }

    /// Create a new name with no real location information.
    ///
    /// The result is an "original" name carrying a manufactured location.
    /// Prefer [`Name::new`] whenever a real location is available, because
    /// that will be more helpful to our users.
    pub fn manufactured<S: ToString>(n: S) -> Name {
        Name::new(n, Location::manufactured())
    }

    /// Create a unique name based on the original name provided.
    ///
    /// A counter value is appended to the name and the whole thing is
    /// wrapped in `<>`, which is hoped to be unique. The resulting name
    /// carries a manufactured location.
    pub fn gensym<S: ToString>(n: S) -> Name {
        Name::located_gensym(Location::manufactured(), n)
    }

    /// As with [`Name::gensym`], but tie the name to the given location.
    pub fn located_gensym<S: ToString>(location: Location, n: S) -> Name {
        // One process-wide counter feeds every generated name.
        static GENSYM_COUNTER: AtomicU64 = AtomicU64::new(0);

        let base = n.to_string();
        let serial = GENSYM_COUNTER.fetch_add(1, Ordering::SeqCst);

        Name {
            name: ArcIntern::new(format!("<{}{}>", base, serial)),
            rename: None,
            location,
        }
    }

    /// Returns the name this variable was created with.
    ///
    /// Renaming has no effect on this value.
    pub fn original_name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the name this variable currently goes by.
    ///
    /// That is the new name if the variable has been renamed, and the
    /// original name otherwise.
    pub fn current_name(&self) -> &str {
        self.current_interned().as_str()
    }

    /// Returns the current name of the variable as an interned string.
    pub fn current_interned(&self) -> &ArcIntern<String> {
        match &self.rename {
            Some(renamed) => renamed,
            None => &self.name,
        }
    }

    /// Return the location of this name.
    pub fn location(&self) -> &Location {
        &self.location
    }

    /// Rename this variable to the given value.
    ///
    /// The original name is kept and stays available through
    /// [`Name::original_name`].
    pub fn rename(&mut self, new_name: &ArcIntern<String>) {
        self.rename = Some(ArcIntern::clone(new_name));
    }

    /// Returns an owned handle to the current interned name.
    pub fn intern(&self) -> ArcIntern<String> {
        ArcIntern::clone(self.current_interned())
    }
}
/// Two names are equal when their *current* interned names are equal.
///
/// The location is never compared, and once a name has been renamed its
/// original name no longer takes part in the comparison.
impl PartialEq for Name {
fn eq(&self, other: &Self) -> bool {
// Compare only the current (possibly renamed) interned string.
self.current_interned() == other.current_interned()
}
}
impl Eq for Name {}
/// Hashes only the current interned name, the same value that the
/// `PartialEq` implementation compares, so equal names hash identically.
impl Hash for Name {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// Location and original name are deliberately left out of the hash.
self.current_interned().hash(state)
}
}
/// Displays the current name: the renamed value if there is one,
/// otherwise the original name.
impl fmt::Display for Name {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Hand the same formatter to the `str` implementation.
        fmt::Display::fmt(self.current_name(), f)
    }
}

View File

@@ -8,8 +8,9 @@
//! (Although, at some point, things can become so complicated that you might //! (Although, at some point, things can become so complicated that you might
//! eventually want to leave lalrpop behind.) //! eventually want to leave lalrpop behind.)
//! //!
use crate::syntax::{LexerError, Location}; use crate::syntax::{Location, ParserError};
use crate::syntax::ast::{Program,Statement,Expression,Value,Name}; use crate::syntax::ast::{TopLevel,Expression,Value,Type};
use crate::syntax::name::Name;
use crate::syntax::tokens::{ConstantType, Token}; use crate::syntax::tokens::{ConstantType, Token};
use internment::ArcIntern; use internment::ArcIntern;
@@ -24,23 +25,33 @@ grammar(file_idx: usize);
extern { extern {
type Location = usize; // Logos, our lexer, implements locations as type Location = usize; // Logos, our lexer, implements locations as
// offsets from the start of the file. // offsets from the start of the file.
type Error = LexerError; type Error = ParserError;
// here we redeclare all of the tokens. // here we redeclare all of the tokens.
enum Token { enum Token {
"=" => Token::Equals, "=" => Token::Equals,
":" => Token::Colon,
";" => Token::Semi, ";" => Token::Semi,
"," => Token::Comma,
"." => Token::Dot,
"(" => Token::LeftParen, "(" => Token::LeftParen,
")" => Token::RightParen, ")" => Token::RightParen,
"<" => Token::LessThan, "<" => Token::LessThan,
">" => Token::GreaterThan, ">" => Token::GreaterThan,
"_" => Token::Underscore,
"{" => Token::OpenBrace,
"}" => Token::CloseBrace,
"->" => Token::SingleArrow,
"function" => Token::Function,
"struct" => Token::Struct,
"print" => Token::Print, "print" => Token::Print,
"+" => Token::Operator('+'), "+" => Token::Operator('+'),
"-" => Token::Operator('-'), "-" => Token::Operator('-'),
"*" => Token::Operator('*'), "*" => Token::Operator('*'),
"/" => Token::Operator('/'), "/" => Token::Operator('/'),
"÷" => Token::Operator('/'),
// the previous items just match their tokens, and if you try // the previous items just match their tokens, and if you try
// to name and use "their value", you get their source location. // to name and use "their value", you get their source location.
@@ -48,62 +59,67 @@ extern {
// which is why we put their types in angle brackets. // which is why we put their types in angle brackets.
"<num>" => Token::Number((<Option<u8>>,<Option<ConstantType>>,<u64>)), "<num>" => Token::Number((<Option<u8>>,<Option<ConstantType>>,<u64>)),
"<var>" => Token::Variable(<ArcIntern<String>>), "<var>" => Token::Variable(<ArcIntern<String>>),
"<type>" => Token::TypeName(<ArcIntern<String>>),
} }
} }
pub Program: Program = { pub Program: Vec<TopLevel> = {
// a program is just a set of statements <mut rest: Program> <t:TopLevel> => {
<stmts:Statements> => Program { rest.push(t);
statements: stmts rest
}
}
Statements: Vec<Statement> = {
// a statement is either a set of statements followed by another
// statement (note, here, that you can name the result of a sub-parse
// using <name: subrule>) ...
<mut stmts:Statements> <stmt:Statement> => {
stmts.push(stmt);
stmts
}, },
=> vec![],
}
// ... or it's nothing. This may feel like an awkward way to define pub TopLevel: TopLevel = {
// lists of things -- and it is a bit awkward -- but there are actual <s:Structure> => s,
// technical reasons that you want to (a) use recursivion to define <s:Expression> ";" => TopLevel::Expression(s),
// these, and (b) use *left* recursion, specifically. That's why, in }
// this file, all of the recursive cases are to the left, like they
// are above. Argument: (Name, Option<Type>) = {
// <name_start: @L> <v:"<var>"> <name_end: @L> <t:(":" Type)?> =>
// the details of why left recursion is better is actually pretty (Name::new(v, Location::new(file_idx, name_start..name_end)), t.map(|v| v.1)),
// fiddly and in the weeds, and if you're interested you should look }
// up LALR parsers versus LL parsers; both their differences and how
// they're constructed, as they're kind of neat. OptionalComma: () = {
// => (),
// but if you're just writing grammars with lalrpop, then you should "," => (),
// just remember that you should always use left recursion, and be }
// done with it.
=> { Structure: TopLevel = {
Vec::new() <s:@L> "struct" <n: TypeName> "{" <fields: Field*> "}" <e:@L> => {
TopLevel::Structure(Location::new(file_idx, s..e), n, fields)
} }
} }
pub Statement: Statement = { Field: (Name, Option<Type>) = {
// A statement can be a variable binding. Note, here, that we use this <s:@L> <name:"<var>"> <e:@L> ":" <field_type: Type> ";" =>
// funny @L thing to get the source location before the variable, so that (Name::new(name, Location::new(file_idx, s..e)), Some(field_type)),
// we can say that this statement spans across everything. <s:@L> <name:"<var>"> <e:@L> ";" =>
<ls: @L> <v:"<var>"> <var_end: @L> "=" <e:Expression> ";" <le: @L> => (Name::new(name, Location::new(file_idx, s..e)), None),
Statement::Binding(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, ls..var_end)),
e,
),
// Alternatively, a statement can just be a print statement. }
<ls: @L> "print" <name_start: @L> <v:"<var>"> <name_end: @L> ";" <le: @L> =>
Statement::Print( Type: Type = {
Location::new(file_idx, ls..le), <name:Name> => Type::Named(name),
<t:TypeName> => Type::Named(t),
"struct" "{" <fields: TypeField*> "}" =>
Type::Struct(fields),
}
TypeField: (Name, Option<Type>) = {
<name: Name> ":" <ty: Type> ";" => (name, Some(ty)),
<name: Name> (":" "_")? ";" => (name, None),
}
Name: Name = {
<name_start: @L> <v:"<var>"> <name_end: @L> =>
Name::new(v, Location::new(file_idx, name_start..name_end)),
}
TypeName: Name = {
<name_start: @L> <v:"<type>"> <name_end: @L> =>
Name::new(v, Location::new(file_idx, name_start..name_end)), Name::new(v, Location::new(file_idx, name_start..name_end)),
),
} }
// Expressions! Expressions are a little fiddly, because we're going to // Expressions! Expressions are a little fiddly, because we're going to
@@ -129,33 +145,93 @@ pub Statement: Statement = {
// to run through a few examples. Consider thinking about how you want to // to run through a few examples. Consider thinking about how you want to
// parse something like "1 + 2 * 3", for example, versus "1 + 2 + 3" or // parse something like "1 + 2 * 3", for example, versus "1 + 2 + 3" or
// "1 * 2 + 3", and hopefully that'll help. // "1 * 2 + 3", and hopefully that'll help.
Expression: Expression = { pub Expression: Expression = {
BindingExpression,
}
BindingExpression: Expression = {
// An expression can be a variable binding. Note, here, that we use this
// funny @L thing to get the source location before the variable, so that
// we can say that this statement spans across everything.
<ls: @L> <v:"<var>"> <var_end: @L> "=" <e:BindingExpression> <le: @L> =>
Expression::Binding(
Location::new(file_idx, ls..le),
Name::new(v, Location::new(file_idx, ls..var_end)),
Box::new(e),
),
FunctionExpression,
}
FunctionExpression: Expression = {
<s:@L> "function" <opt_name:Name?> "(" <args:Comma<Argument>> ")" <ret:("->" Type)?> <exp:Expression> <e:@L> =>
Expression::Function(Location::new(file_idx, s..e), opt_name, args, ret.map(|x| x.1), Box::new(exp)),
PrintExpression,
}
PrintExpression: Expression = {
<ls: @L> "print" <pe: @L> <e: ConstructorExpression> <le: @L> =>
Expression::Call(
Location::new(file_idx, ls..le),
Box::new(
Expression::Primitive(
Location::new(file_idx, ls..pe),
Name::new("print", Location::new(file_idx, ls..pe)),
),
),
vec![e],
),
ConstructorExpression,
}
ConstructorExpression: Expression = {
<s:@L> <name:TypeName> "{" <fields:FieldSetter*> "}" <e:@L> =>
Expression::Constructor(Location::new(file_idx, s..e), name, fields),
AdditiveExpression, AdditiveExpression,
} }
FieldSetter: (Name, Expression) = {
<name:Name> ":" <expr:Expression> ";" => (name, expr),
}
// we group addition and subtraction under the heading "additive" // we group addition and subtraction under the heading "additive"
AdditiveExpression: Expression = { AdditiveExpression: Expression = {
<ls: @L> <e1:AdditiveExpression> <l: @L> "+" <e2:MultiplicativeExpression> <le: @L> => <ls: @L> <e1:AdditiveExpression> <l: @L> "+" <e2:MultiplicativeExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "+".to_string(), vec![e1, e2]), Expression::primitive(Location::new(file_idx, ls..le), "+", vec![e1, e2]),
<ls: @L> <e1:AdditiveExpression> <l: @L> "-" <e2:MultiplicativeExpression> <le: @L> => <ls: @L> <e1:AdditiveExpression> <l: @L> "-" <e2:MultiplicativeExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "-".to_string(), vec![e1, e2]), Expression::primitive(Location::new(file_idx, ls..le), "-", vec![e1, e2]),
MultiplicativeExpression, MultiplicativeExpression,
} }
// similarly, we group multiplication and division under "multiplicative" // similarly, we group multiplication and division under "multiplicative"
MultiplicativeExpression: Expression = { MultiplicativeExpression: Expression = {
<ls: @L> <e1:MultiplicativeExpression> <l: @L> "*" <e2:UnaryExpression> <le: @L> => <ls: @L> <e1:MultiplicativeExpression> <l: @L> "*" <e2:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "*".to_string(), vec![e1, e2]), Expression::primitive(Location::new(file_idx, ls..le), "*", vec![e1, e2]),
<ls: @L> <e1:MultiplicativeExpression> <l: @L> "/" <e2:UnaryExpression> <le: @L> => <ls: @L> <e1:MultiplicativeExpression> <l: @L> "/" <e2:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, ls..le), "/".to_string(), vec![e1, e2]), Expression::primitive(Location::new(file_idx, ls..le), "/", vec![e1, e2]),
UnaryExpression, UnaryExpression,
} }
UnaryExpression: Expression = { UnaryExpression: Expression = {
<l: @L> "-" <e:UnaryExpression> <le: @L> => <l: @L> "-" <e:UnaryExpression> <le: @L> =>
Expression::Primitive(Location::new(file_idx, l..le), "-".to_string(), vec![e]), Expression::primitive(Location::new(file_idx, l..le), "negate", vec![e]),
<l: @L> "<" <v:"<var>"> ">" <e:UnaryExpression> <le: @L> => <l: @L> "<" <v:"<var>"> ">" <e:UnaryExpression> <le: @L> =>
Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)), Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)),
<l: @L> "<" <v:"<type>"> ">" <e:UnaryExpression> <le: @L> =>
Expression::Cast(Location::new(file_idx, l..le), v.to_string(), Box::new(e)),
CallExpression,
}
CallExpression: Expression = {
<s: @L> <f:CallExpression> "(" <args: Comma<Expression>> ")" <e: @L> =>
Expression::Call(Location::new(file_idx, s..e), Box::new(f), args),
FieldExpression,
}
FieldExpression: Expression = {
<s: @L> <fe:FieldExpression> "." <field:Name> <e: @L> =>
Expression::FieldRef(Location::new(file_idx, s..e), Box::new(fe), field),
AtomicExpression, AtomicExpression,
} }
@@ -163,10 +239,33 @@ UnaryExpression: Expression = {
// they cannot be further divided into parts // they cannot be further divided into parts
AtomicExpression: Expression = { AtomicExpression: Expression = {
// just a variable reference // just a variable reference
<l: @L> <v:"<var>"> <end: @L> => Expression::Reference(Location::new(file_idx, l..end), v.to_string()), <l: @L> <v:"<var>"> <end: @L> => Expression::Reference(Name::new(v.to_string(), Location::new(file_idx, l..end))),
// just a number // just a number
<l: @L> <n:"<num>"> <end: @L> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)), <l: @L> <n:"<num>"> <end: @L> => Expression::Value(Location::new(file_idx, l..end), Value::Number(n.0, n.1, n.2)),
// this expression could actually be a block!
<s:@L> "{" <exprs:Expressions> ";"? "}" <e:@L> => Expression::Block(Location::new(file_idx, s..e), exprs),
<s:@L> "{" "}" <e:@L> => Expression::Block(Location::new(file_idx, s..e), vec![]),
// finally, let people parenthesize expressions and get back to a // finally, let people parenthesize expressions and get back to a
// lower precedence // lower precedence
"(" <e:Expression> ")" => e, "(" <e:Expression> ")" => e,
} }
Expressions: Vec<Expression> = {
<e:Expression> => vec![e],
<mut exps:Expressions> ";" <e:Expression> => {
exps.push(e);
exps
}
}
// Lifted from the LALRPop book, a comma-separated list of T that may or
// may not conclude with a comma.
Comma<T>: Vec<T> = {
<mut v:(<T> ",")*> <e:T?> => match e {
None => v,
Some(e) => {
v.push(e);
v
}
}
};

View File

@@ -1,90 +1,275 @@
use crate::syntax::ast::{Expression, Program, Statement, Value}; use crate::syntax::ast::{ConstantType, Expression, Program, TopLevel, Type, Value};
use pretty::{DocAllocator, DocBuilder, Pretty}; use crate::util::pretty::{derived_display, Allocator};
use pretty::{DocAllocator, DocBuilder};
use super::ConstantType; impl Program {
pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
let mut result = allocator.nil(); let mut result = allocator.nil();
for stmt in self.statements.iter() { for definition in self.structures.values() {
let mut interior = allocator.nil();
for (name, ty) in definition.fields.iter() {
let mut type_bit = allocator.nil();
if let Some(ty) = ty {
type_bit = allocator
.text(":")
.append(allocator.space())
.append(ty.pretty(allocator));
}
interior = interior
.append(name.original_name().to_string())
.append(type_bit)
.append(allocator.text(";"))
.append(allocator.hardline());
}
interior = interior.indent(9);
let start = allocator
.text("struct")
.append(allocator.space())
.append(allocator.text(definition.name.original_name().to_string()))
.append(allocator.space())
.append(allocator.text("{"));
let conclusion = allocator.text("}").append(allocator.hardline());
result = result.append(start.append(interior).append(conclusion));
}
for definition in self.functions.values() {
let mut return_type_bit = allocator.nil();
if let Some(rettype) = definition.return_type.as_ref() {
return_type_bit = allocator
.text("->")
.append(allocator.space())
.append(rettype.pretty(allocator));
}
result = result result = result
.append(allocator.text("function"))
.append(allocator.space())
.append(allocator.text(definition.name.original_name().to_string()))
.append(allocator.text("("))
.append(allocator.intersperse(
definition.arguments.iter().map(|(x, t)| {
let mut type_bit = allocator.nil();
if let Some(ty) = t {
type_bit = allocator
.text(":")
.append(allocator.space())
.append(ty.pretty(allocator));
}
allocator
.text(x.original_name().to_string())
.append(type_bit)
}),
allocator.text(","),
))
.append(allocator.text(")"))
.append(return_type_bit)
.append(allocator.softline())
.append(definition.body.pretty(allocator))
.append(allocator.hardline());
}
result.append(self.body.pretty(allocator))
}
}
impl TopLevel {
pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
match self {
TopLevel::Expression(expr) => expr
.pretty(allocator)
.append(allocator.text(";"))
.append(allocator.hardline()),
TopLevel::Structure(_, name, fields) => allocator
.text("struct")
.append(allocator.space())
.append(allocator.text(name.to_string()))
.append(allocator.space())
.append(allocator.text("{"))
.append(allocator.hardline())
.append(
allocator
.concat(fields.iter().map(|(name, ty)| {
let type_bit = if let Some(ty) = ty {
allocator
.text(":")
.append(allocator.space())
.append(ty.pretty(allocator))
} else {
allocator.nil()
};
allocator
.text(name.to_string())
.append(type_bit)
.append(allocator.text(";"))
.append(allocator.hardline())
}))
.indent(2),
)
.append(allocator.text("}"))
.append(allocator.hardline()),
}
}
}
impl Expression {
pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
match self {
Expression::Value(_, val) => val.pretty(allocator),
Expression::Constructor(_, name, fields) => allocator
.text(name.to_string())
.append(allocator.space())
.append(allocator.text("{"))
.append(allocator.hardline())
.append(
allocator
.concat(fields.iter().map(|(n, e)| {
allocator
.text(n.to_string())
.append(allocator.text(":"))
.append(allocator.space())
.append(e.pretty(allocator))
.append(allocator.text(";"))
.append(allocator.hardline())
}))
.indent(2),
)
.append(allocator.text("}")),
Expression::Reference(var) => allocator.text(var.to_string()),
Expression::FieldRef(_, val, field) => val
.pretty(allocator)
.append(allocator.text("."))
.append(allocator.text(field.to_string())),
Expression::Cast(_, t, e) => allocator
.text(t.clone())
.angles()
.append(e.pretty(allocator)),
Expression::Primitive(_, op) => allocator.text(op.original_name().to_string()),
Expression::Call(_, fun, args) => {
let mut args = args.iter().map(|x| x.pretty(allocator)).collect::<Vec<_>>();
match fun.as_ref() {
Expression::Primitive(_, name)
if ["/", "*", "+", "-"].contains(&name.current_name())
&& args.len() == 2 =>
{
let second = args.pop().unwrap();
args.pop()
.unwrap()
.append(allocator.space())
.append(allocator.text(name.current_name().to_string()))
.append(allocator.space())
.append(second)
.parens()
}
Expression::Primitive(_, name)
if ["negate"].contains(&name.current_name()) && args.len() == 1 =>
{
allocator.text("-").append(args.pop().unwrap())
}
Expression::Primitive(_, name)
if ["print"].contains(&&name.current_name()) && args.len() == 1 =>
{
allocator
.text("print")
.append(allocator.space())
.append(args.pop().unwrap())
}
_ => {
let comma_sepped_args = allocator.intersperse(args, allocator.text(","));
fun.pretty(allocator).append(comma_sepped_args.parens())
}
}
}
Expression::Block(_, stmts) => match stmts.split_last() {
None => allocator.text("()"),
Some((last, &[])) => last.pretty(allocator),
Some((last, start)) => {
let beginning = allocator.text("{").append(allocator.hardline());
let mut inner = allocator.nil();
for stmt in start.iter() {
inner = inner
.append(stmt.pretty(allocator)) .append(stmt.pretty(allocator))
.append(allocator.text(";")) .append(allocator.text(";"))
.append(allocator.hardline()); .append(allocator.hardline());
} }
result inner = inner
} .append(last.pretty(allocator))
} .append(allocator.hardline());
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement inner = inner.indent(2);
where
A: 'a, beginning.append(inner).append(allocator.text("}"))
D: ?Sized + DocAllocator<'a, A>, }
{ },
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> { Expression::Binding(_, var, expr) => allocator
match self {
Statement::Binding(_, var, expr) => allocator
.text(var.to_string()) .text(var.to_string())
.append(allocator.space()) .append(allocator.space())
.append(allocator.text("=")) .append(allocator.text("="))
.append(allocator.space()) .append(allocator.space())
.append(expr.pretty(allocator)), .append(expr.pretty(allocator)),
Statement::Print(_, var) => allocator Expression::Function(_, name, args, rettype, body) => allocator
.text("print") .text("function")
.append(allocator.space()) .append(allocator.space())
.append(allocator.text(var.to_string())), .append(
} name.as_ref()
} .map(|x| allocator.text(x.to_string()))
} .unwrap_or_else(|| allocator.nil()),
)
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression .append(
where allocator
A: 'a, .intersperse(
D: ?Sized + DocAllocator<'a, A>, args.iter().map(|(x, t)| {
{ allocator.text(x.to_string()).append(
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> { t.as_ref()
match self { .map(|t| {
Expression::Value(_, val) => val.pretty(allocator), allocator
Expression::Reference(_, var) => allocator.text(var.to_string()), .text(":")
Expression::Cast(_, t, e) => allocator
.text(t.clone())
.angles()
.append(e.pretty(allocator)),
Expression::Primitive(_, op, exprs) if exprs.len() == 1 => allocator
.text(op.to_string())
.append(exprs[0].pretty(allocator)),
Expression::Primitive(_, op, exprs) if exprs.len() == 2 => {
let left = exprs[0].pretty(allocator);
let right = exprs[1].pretty(allocator);
left.append(allocator.space())
.append(allocator.text(op.to_string()))
.append(allocator.space()) .append(allocator.space())
.append(right) .append(t.pretty(allocator))
.parens() })
} .unwrap_or_else(|| allocator.nil()),
Expression::Primitive(_, op, exprs) => { )
let call = allocator.text(op.to_string()); }),
let args = exprs.iter().map(|x| x.pretty(allocator)); allocator.text(","),
let comma_sepped_args = allocator.intersperse(args, CommaSep {}); )
call.append(comma_sepped_args.parens()) .parens(),
} )
.append(
rettype
.as_ref()
.map(|rettype| {
allocator
.space()
.append(allocator.text("->"))
.append(allocator.space())
.append(rettype.pretty(allocator))
})
.unwrap_or_else(|| allocator.nil()),
)
.append(allocator.space())
.append(body.pretty(allocator)),
} }
} }
} }
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value impl Value {
where pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
A: 'a,
D: ?Sized + DocAllocator<'a, A>,
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
match self { match self {
Value::Number(opt_base, ty, value) => { Value::Number(opt_base, ty, value) => {
let value_str = match opt_base { let value_str = match opt_base {
@@ -98,6 +283,8 @@ where
allocator.text(value_str) allocator.text(value_str)
} }
Value::Void => allocator.text("void()"),
} }
} }
} }
@@ -105,6 +292,7 @@ where
fn type_suffix(x: &Option<ConstantType>) -> &'static str { fn type_suffix(x: &Option<ConstantType>) -> &'static str {
match x { match x {
None => "", None => "",
Some(ConstantType::Void) => panic!("Should never get a void type suffix."),
Some(ConstantType::I8) => "i8", Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16", Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32", Some(ConstantType::I32) => "i32",
@@ -116,15 +304,33 @@ fn type_suffix(x: &Option<ConstantType>) -> &'static str {
} }
} }
#[derive(Clone, Copy)] impl Type {
struct CommaSep {} pub fn pretty<'a>(&self, allocator: &'a Allocator<'a>) -> DocBuilder<'a, Allocator<'a>> {
match self {
Type::Named(x) => allocator.text(x.to_string()),
Type::Struct(fields) => allocator
.text("struct")
.append(allocator.space())
.append(allocator.intersperse(
fields.iter().map(|(name, ty)| {
allocator
.text(name.to_string())
.append(allocator.text(":"))
.append(allocator.space())
.append(
ty.as_ref()
.map(|x| x.pretty(allocator))
.unwrap_or_else(|| allocator.text("_")),
)
.append(allocator.text(";"))
}),
allocator.hardline(),
).braces())
}
}
}
impl<'a, D, A> Pretty<'a, D, A> for CommaSep derived_display!(Program);
where derived_display!(TopLevel);
A: 'a, derived_display!(Expression);
D: ?Sized + DocAllocator<'a, A>, derived_display!(Value);
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
allocator.text(",").append(allocator.space())
}
}

View File

@@ -0,0 +1,51 @@
use super::{Expression, Name};
use std::collections::HashMap;
impl Expression {
    /// Replace references to the names in `map` with their mapped expressions.
    ///
    /// The expression is rewritten in place: every `Expression::Reference`
    /// whose name is a key of `map` is overwritten with a clone of the
    /// associated expression. Replacement expressions are inserted as-is and
    /// are not themselves rewritten.
    ///
    /// Names that get rebound stop being replaced:
    ///
    /// * a binding removes its name from `map` once its right-hand side has
    ///   been rewritten, so the mapped value is still substituted inside the
    ///   right-hand side but not in anything rewritten afterwards;
    /// * a function removes its own name (if it has one) from `map`;
    /// * a function's arguments, and any names bound inside its body, are
    ///   hidden only while that body is rewritten.
    ///
    /// Because of the first two points, `map` may hold fewer entries when
    /// this method returns than it did on entry.
    pub fn replace_references(&mut self, map: &mut HashMap<Name, Expression>) {
        match self {
            // Leaves that cannot contain a reference.
            Expression::Value(_, _) | Expression::Primitive(_, _) => {}
            Expression::Constructor(_, _, items) => {
                for (_, item) in items.iter_mut() {
                    item.replace_references(map);
                }
            }
            Expression::Reference(name) => match map.get(name) {
                None => {}
                Some(replacement) => *self = replacement.clone(),
            },
            Expression::FieldRef(_, subexp, _) => {
                subexp.replace_references(map);
            }
            Expression::Cast(_, _, subexp) => {
                subexp.replace_references(map);
            }
            Expression::Call(_, fun, args) => {
                fun.replace_references(map);
                for arg in args.iter_mut() {
                    arg.replace_references(map);
                }
            }
            Expression::Block(_, exprs) => {
                for expr in exprs.iter_mut() {
                    expr.replace_references(map);
                }
            }
            Expression::Binding(_, n, expr) => {
                // The right-hand side is rewritten before the name is
                // dropped, so the bound name still refers to the mapped value
                // inside its own definition.
                expr.replace_references(map);
                map.remove(n);
            }
            Expression::Function(_, mname, args, _, body) => {
                // The function's own name is removed from the caller's map,
                // so the shadowing outlives this function expression.
                if let Some(name) = mname {
                    map.remove(name);
                }
                // Arguments (and anything bound inside the body) must only
                // shadow map entries while the body is rewritten. Working on
                // a copy keeps those removals from leaking into expressions
                // that are rewritten after this function.
                let mut body_map: HashMap<Name, Expression> = map.clone();
                for (arg_name, _) in args.iter() {
                    body_map.remove(arg_name);
                }
                body.replace_references(&mut body_map);
            }
        }
    }
}

View File

@@ -1,7 +1,6 @@
use internment::ArcIntern; use internment::ArcIntern;
use logos::{Lexer, Logos}; use logos::{Lexer, Logos};
use std::fmt; use std::{fmt, str::FromStr};
use std::num::ParseIntError;
use thiserror::Error; use thiserror::Error;
/// A single token of the input stream; used to help the parsing go down /// A single token of the input stream; used to help the parsing go down
@@ -26,14 +25,29 @@ use thiserror::Error;
/// trait, you should get back the exact same token. /// trait, you should get back the exact same token.
#[derive(Logos, Clone, Debug, PartialEq, Eq)] #[derive(Logos, Clone, Debug, PartialEq, Eq)]
pub enum Token { pub enum Token {
// we're actually just going to skip whitespace, though
#[regex(r"[ \t\r\n\f]+", logos::skip)]
// this is an extremely simple version of comments, just line
// comments. More complicated /* */ comments can be harder to
// implement, and didn't seem worth it at the time.
#[regex(r"//.*", logos::skip)]
// Our first set of tokens are simple characters that we're // Our first set of tokens are simple characters that we're
// going to use to structure NGR programs. // going to use to structure NGR programs.
#[token("=")] #[token("=")]
Equals, Equals,
#[token(":")]
Colon,
#[token(";")] #[token(";")]
Semi, Semi,
#[token(",")]
Comma,
#[token(".")]
Dot,
#[token("(")] #[token("(")]
LeftParen, LeftParen,
@@ -46,6 +60,26 @@ pub enum Token {
#[token(">")] #[token(">")]
GreaterThan, GreaterThan,
#[token("_")]
Underscore,
#[token("{")]
OpenBrace,
#[token("}")]
CloseBrace,
#[token("->")]
SingleArrow,
#[token("λ")]
#[token("lambda")]
#[token("function")]
Function,
#[token("struct")]
Struct,
// Next we take of any reserved words; I always like to put // Next we take of any reserved words; I always like to put
// these before we start recognizing more complicated regular // these before we start recognizing more complicated regular
// expressions. I don't think it matters, but it works for me. // expressions. I don't think it matters, but it works for me.
@@ -54,7 +88,7 @@ pub enum Token {
// Next are the operators for NGR. We only have 4, now, but // Next are the operators for NGR. We only have 4, now, but
// we might extend these later, or even make them user-definable! // we might extend these later, or even make them user-definable!
#[regex(r"[+\-*/]", |v| v.slice().chars().next())] #[regex(r"[+\-*/÷]", |v| v.slice().chars().next())]
Operator(char), Operator(char),
/// Numbers capture both the value we read from the input, /// Numbers capture both the value we read from the input,
@@ -75,28 +109,30 @@ pub enum Token {
#[regex(r"[a-z][a-zA-Z0-9_]*", |v| ArcIntern::new(v.slice().to_string()))] #[regex(r"[a-z][a-zA-Z0-9_]*", |v| ArcIntern::new(v.slice().to_string()))]
Variable(ArcIntern<String>), Variable(ArcIntern<String>),
// the next token will be an error token // Type names; these are like variables, but must start with a capital
#[error] // letter.
// we're actually just going to skip whitespace, though #[regex(r"[A-Z][a-zA-Z0-9_]*", |v| ArcIntern::new(v.slice().to_string()))]
#[regex(r"[ \t\r\n\f]+", logos::skip)] TypeName(ArcIntern<String>),
// this is an extremely simple version of comments, just line
// comments. More complicated /* */ comments can be harder to
// implement, and didn't seem worth it at the time.
#[regex(r"//.*", logos::skip)]
/// This token represents that some core error happened in lexing;
/// possibly that something didn't match anything at all.
Error,
} }
impl fmt::Display for Token { impl fmt::Display for Token {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
Token::Equals => write!(f, "'='"), Token::Equals => write!(f, "'='"),
Token::Colon => write!(f, "':'"),
Token::Semi => write!(f, "';'"), Token::Semi => write!(f, "';'"),
Token::Comma => write!(f, "','"),
Token::Dot => write!(f, "'.'"),
Token::LeftParen => write!(f, "'('"), Token::LeftParen => write!(f, "'('"),
Token::RightParen => write!(f, "')'"), Token::RightParen => write!(f, "')'"),
Token::LessThan => write!(f, "<"), Token::LessThan => write!(f, "<"),
Token::GreaterThan => write!(f, ">"), Token::GreaterThan => write!(f, ">"),
Token::Underscore => write!(f, "_"),
Token::OpenBrace => write!(f, "{{"),
Token::CloseBrace => write!(f, "}}"),
Token::SingleArrow => write!(f, "->"),
Token::Function => write!(f, "function"),
Token::Struct => write!(f, "struct"),
Token::Print => write!(f, "'print'"), Token::Print => write!(f, "'print'"),
Token::Operator(c) => write!(f, "'{}'", c), Token::Operator(c) => write!(f, "'{}'", c),
Token::Number((None, otype, v)) => write!(f, "'{}{}'", v, display_optional_type(otype)), Token::Number((None, otype, v)) => write!(f, "'{}{}'", v, display_optional_type(otype)),
@@ -120,7 +156,7 @@ impl fmt::Display for Token {
) )
} }
Token::Variable(s) => write!(f, "'{}'", s), Token::Variable(s) => write!(f, "'{}'", s),
Token::Error => write!(f, "<error>"), Token::TypeName(s) => write!(f, "'{}'", s),
} }
} }
} }
@@ -154,11 +190,13 @@ pub enum ConstantType {
I16 = 21, I16 = 21,
I32 = 22, I32 = 22,
I64 = 23, I64 = 23,
Void = 255,
} }
impl From<ConstantType> for cranelift_codegen::ir::Type { impl From<ConstantType> for cranelift_codegen::ir::Type {
fn from(value: ConstantType) -> Self { fn from(value: ConstantType) -> Self {
match value { match value {
ConstantType::Void => cranelift_codegen::ir::types::I64,
ConstantType::I8 | ConstantType::U8 => cranelift_codegen::ir::types::I8, ConstantType::I8 | ConstantType::U8 => cranelift_codegen::ir::types::I8,
ConstantType::I16 | ConstantType::U16 => cranelift_codegen::ir::types::I16, ConstantType::I16 | ConstantType::U16 => cranelift_codegen::ir::types::I16,
ConstantType::I32 | ConstantType::U32 => cranelift_codegen::ir::types::I32, ConstantType::I32 | ConstantType::U32 => cranelift_codegen::ir::types::I32,
@@ -167,18 +205,48 @@ impl From<ConstantType> for cranelift_codegen::ir::Type {
} }
} }
impl ConstantType { pub struct StringNotConstantType();
/// Returns true if the given type is (a) numeric and (b) signed;
pub fn is_signed(&self) -> bool { impl FromStr for ConstantType {
matches!( type Err = StringNotConstantType;
self,
ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 fn from_str(s: &str) -> Result<Self, Self::Err> {
) match s {
"i8" => Ok(ConstantType::I8),
"i16" => Ok(ConstantType::I16),
"i32" => Ok(ConstantType::I32),
"i64" => Ok(ConstantType::I64),
"u8" => Ok(ConstantType::U8),
"u16" => Ok(ConstantType::U16),
"u32" => Ok(ConstantType::U32),
"u64" => Ok(ConstantType::U64),
"void" => Ok(ConstantType::Void),
_ => Err(StringNotConstantType()),
}
}
} }
impl fmt::Display for ConstantType {
    /// Write the lowercase name of the type (`i8`, `u64`, `void`, ...).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the fixed label first, then emit it with a single write.
        let label = match self {
            ConstantType::Void => "void",
            ConstantType::U8 => "u8",
            ConstantType::U16 => "u16",
            ConstantType::U32 => "u32",
            ConstantType::U64 => "u64",
            ConstantType::I8 => "i8",
            ConstantType::I16 => "i16",
            ConstantType::I32 => "i32",
            ConstantType::I64 => "i64",
        };
        f.write_str(label)
    }
}
impl ConstantType {
/// Return the set of types that can be safely casted into this type. /// Return the set of types that can be safely casted into this type.
pub fn safe_casts_to(self) -> Vec<ConstantType> { pub fn safe_casts_to(self) -> Vec<ConstantType> {
match self { match self {
ConstantType::Void => vec![ConstantType::Void],
ConstantType::I8 => vec![ConstantType::I8], ConstantType::I8 => vec![ConstantType::I8],
ConstantType::I16 => vec![ConstantType::I16, ConstantType::I8, ConstantType::U8], ConstantType::I16 => vec![ConstantType::I16, ConstantType::I8, ConstantType::U8],
ConstantType::I32 => vec![ ConstantType::I32 => vec![
@@ -226,6 +294,7 @@ impl ConstantType {
/// Return the name of the given type, as a string /// Return the name of the given type, as a string
pub fn name(&self) -> String { pub fn name(&self) -> String {
match self { match self {
ConstantType::Void => "void".to_string(),
ConstantType::I8 => "i8".to_string(), ConstantType::I8 => "i8".to_string(),
ConstantType::I16 => "i16".to_string(), ConstantType::I16 => "i16".to_string(),
ConstantType::I32 => "i32".to_string(), ConstantType::I32 => "i32".to_string(),
@@ -236,6 +305,32 @@ impl ConstantType {
ConstantType::U64 => "u64".to_string(), ConstantType::U64 => "u64".to_string(),
} }
} }
/// List every primitive whose result has this type, paired with the
/// argument types that primitive expects.
///
/// An argument type of `None` means the return type places no constraint
/// on that argument.
pub fn primitives_for(&self) -> Vec<(crate::ir::Primitive, Vec<Option<ConstantType>>)> {
    use crate::ir::Primitive::*;

    let signed = match self {
        // `void` is only produced by printing, whatever is being printed.
        ConstantType::Void => return vec![(Print, vec![None])],
        ConstantType::I8 | ConstantType::I16 | ConstantType::I32 | ConstantType::I64 => true,
        ConstantType::U8 | ConstantType::U16 | ConstantType::U32 | ConstantType::U64 => false,
    };

    let this = Some(*self);
    let binary = || vec![this, this];

    // Every numeric type gets the four binary arithmetic operators, whose
    // operands share the result type.
    let mut primitives = vec![
        (Plus, binary()),
        (Minus, binary()),
        (Times, binary()),
        (Divide, binary()),
    ];
    // Negation is listed for the signed types only.
    if signed {
        primitives.push((Negate, vec![this]));
    }
    primitives
}
} }
#[derive(Debug, Error, PartialEq)] #[derive(Debug, Error, PartialEq)]
@@ -257,6 +352,7 @@ impl TryFrom<i64> for ConstantType {
21 => Ok(ConstantType::I16), 21 => Ok(ConstantType::I16),
22 => Ok(ConstantType::I32), 22 => Ok(ConstantType::I32),
23 => Ok(ConstantType::I64), 23 => Ok(ConstantType::I64),
255 => Ok(ConstantType::Void),
_ => Err(InvalidConstantType::Value(value)), _ => Err(InvalidConstantType::Value(value)),
} }
} }
@@ -269,7 +365,7 @@ impl TryFrom<i64> for ConstantType {
fn parse_number( fn parse_number(
base: Option<u8>, base: Option<u8>,
value: &Lexer<Token>, value: &Lexer<Token>,
) -> Result<(Option<u8>, Option<ConstantType>, u64), ParseIntError> { ) -> Result<(Option<u8>, Option<ConstantType>, u64), ()> {
let (radix, strval) = match base { let (radix, strval) = match base {
None => (10, value.slice()), None => (10, value.slice()),
Some(radix) => (radix, &value.slice()[2..]), Some(radix) => (radix, &value.slice()[2..]),
@@ -295,13 +391,14 @@ fn parse_number(
(None, strval) (None, strval)
}; };
let intval = u64::from_str_radix(strval, radix as u32)?; let intval = u64::from_str_radix(strval, radix as u32).map_err(|_| ())?;
Ok((base, declared_type, intval)) Ok((base, declared_type, intval))
} }
fn display_optional_type(otype: &Option<ConstantType>) -> &'static str { fn display_optional_type(otype: &Option<ConstantType>) -> &'static str {
match otype { match otype {
None => "", None => "",
Some(ConstantType::Void) => "void",
Some(ConstantType::I8) => "i8", Some(ConstantType::I8) => "i8",
Some(ConstantType::I16) => "i16", Some(ConstantType::I16) => "i16",
Some(ConstantType::I32) => "i32", Some(ConstantType::I32) => "i32",
@@ -316,18 +413,18 @@ fn display_optional_type(otype: &Option<ConstantType>) -> &'static str {
#[test] #[test]
fn lex_numbers() { fn lex_numbers() {
let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc 12u8 0xci64// 9"); let mut lex0 = Token::lexer("12 0b1100 0o14 0d12 0xc 12u8 0xci64// 9");
assert_eq!(lex0.next(), Some(Token::Number((None, None, 12)))); assert_eq!(lex0.next(), Some(Ok(Token::Number((None, None, 12)))));
assert_eq!(lex0.next(), Some(Token::Number((Some(2), None, 12)))); assert_eq!(lex0.next(), Some(Ok(Token::Number((Some(2), None, 12)))));
assert_eq!(lex0.next(), Some(Token::Number((Some(8), None, 12)))); assert_eq!(lex0.next(), Some(Ok(Token::Number((Some(8), None, 12)))));
assert_eq!(lex0.next(), Some(Token::Number((Some(10), None, 12)))); assert_eq!(lex0.next(), Some(Ok(Token::Number((Some(10), None, 12)))));
assert_eq!(lex0.next(), Some(Token::Number((Some(16), None, 12)))); assert_eq!(lex0.next(), Some(Ok(Token::Number((Some(16), None, 12)))));
assert_eq!( assert_eq!(
lex0.next(), lex0.next(),
Some(Token::Number((None, Some(ConstantType::U8), 12))) Some(Ok(Token::Number((None, Some(ConstantType::U8), 12))))
); );
assert_eq!( assert_eq!(
lex0.next(), lex0.next(),
Some(Token::Number((Some(16), Some(ConstantType::I64), 12))) Some(Ok(Token::Number((Some(16), Some(ConstantType::I64), 12))))
); );
assert_eq!(lex0.next(), None); assert_eq!(lex0.next(), None);
} }
@@ -335,46 +432,52 @@ fn lex_numbers() {
#[test] #[test]
fn lex_symbols() { fn lex_symbols() {
let mut lex0 = Token::lexer("x + \t y * \n z // rest"); let mut lex0 = Token::lexer("x + \t y * \n z // rest");
assert_eq!(lex0.next(), Some(Token::var("x"))); assert_eq!(lex0.next(), Some(Ok(Token::var("x"))));
assert_eq!(lex0.next(), Some(Token::Operator('+'))); assert_eq!(lex0.next(), Some(Ok(Token::Operator('+'))));
assert_eq!(lex0.next(), Some(Token::var("y"))); assert_eq!(lex0.next(), Some(Ok(Token::var("y"))));
assert_eq!(lex0.next(), Some(Token::Operator('*'))); assert_eq!(lex0.next(), Some(Ok(Token::Operator('*'))));
assert_eq!(lex0.next(), Some(Token::var("z"))); assert_eq!(lex0.next(), Some(Ok(Token::var("z"))));
assert_eq!(lex0.next(), None); assert_eq!(lex0.next(), None);
} }
#[test] #[test]
fn lexer_spans() { fn lexer_spans() {
let mut lex0 = Token::lexer("y = x + 1//foo").spanned(); let mut lex0 = Token::lexer("y = x + 1//foo").spanned();
assert_eq!(lex0.next(), Some((Token::var("y"), 0..1))); assert_eq!(lex0.next(), Some((Ok(Token::var("y")), 0..1)));
assert_eq!(lex0.next(), Some((Token::Equals, 2..3))); assert_eq!(lex0.next(), Some((Ok(Token::Equals), 2..3)));
assert_eq!(lex0.next(), Some((Token::var("x"), 4..5))); assert_eq!(lex0.next(), Some((Ok(Token::var("x")), 4..5)));
assert_eq!(lex0.next(), Some((Token::Operator('+'), 6..7))); assert_eq!(lex0.next(), Some((Ok(Token::Operator('+')), 6..7)));
assert_eq!(lex0.next(), Some((Token::Number((None, None, 1)), 8..9))); assert_eq!(
lex0.next(),
Some((Ok(Token::Number((None, None, 1))), 8..9))
);
assert_eq!(lex0.next(), None); assert_eq!(lex0.next(), None);
} }
#[test] #[test]
fn further_spans() { fn further_spans() {
let mut lex0 = Token::lexer("x = 2i64 + 2i64;\ny = -x;\nprint y;").spanned(); let mut lex0 = Token::lexer("x = 2i64 + 2i64;\ny = -x;\nprint y;").spanned();
assert_eq!(lex0.next(), Some((Token::var("x"), 0..1))); assert_eq!(lex0.next(), Some((Ok(Token::var("x")), 0..1)));
assert_eq!(lex0.next(), Some((Token::Equals, 2..3))); assert_eq!(lex0.next(), Some((Ok(Token::Equals), 2..3)));
assert_eq!( assert_eq!(
lex0.next(), lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 4..8)) Some((Ok(Token::Number((None, Some(ConstantType::I64), 2))), 4..8))
); );
assert_eq!(lex0.next(), Some((Token::Operator('+'), 9..10))); assert_eq!(lex0.next(), Some((Ok(Token::Operator('+')), 9..10)));
assert_eq!( assert_eq!(
lex0.next(), lex0.next(),
Some((Token::Number((None, Some(ConstantType::I64), 2)), 11..15)) Some((
Ok(Token::Number((None, Some(ConstantType::I64), 2))),
11..15
))
); );
assert_eq!(lex0.next(), Some((Token::Semi, 15..16))); assert_eq!(lex0.next(), Some((Ok(Token::Semi), 15..16)));
assert_eq!(lex0.next(), Some((Token::var("y"), 17..18))); assert_eq!(lex0.next(), Some((Ok(Token::var("y")), 17..18)));
assert_eq!(lex0.next(), Some((Token::Equals, 19..20))); assert_eq!(lex0.next(), Some((Ok(Token::Equals), 19..20)));
assert_eq!(lex0.next(), Some((Token::Operator('-'), 21..22))); assert_eq!(lex0.next(), Some((Ok(Token::Operator('-')), 21..22)));
assert_eq!(lex0.next(), Some((Token::var("x"), 22..23))); assert_eq!(lex0.next(), Some((Ok(Token::var("x")), 22..23)));
assert_eq!(lex0.next(), Some((Token::Semi, 23..24))); assert_eq!(lex0.next(), Some((Ok(Token::Semi), 23..24)));
assert_eq!(lex0.next(), Some((Token::Print, 25..30))); assert_eq!(lex0.next(), Some((Ok(Token::Print), 25..30)));
assert_eq!(lex0.next(), Some((Token::var("y"), 31..32))); assert_eq!(lex0.next(), Some((Ok(Token::var("y")), 31..32)));
assert_eq!(lex0.next(), Some((Token::Semi, 32..33))); assert_eq!(lex0.next(), Some((Ok(Token::Semi), 32..33)));
} }

View File

@@ -1,9 +1,12 @@
use crate::{ use crate::eval::PrimitiveType;
eval::PrimitiveType, use crate::syntax::{Expression, Location, Program, StructureDefinition, TopLevel};
syntax::{Expression, Location, Program, Statement}, use crate::util::scoped_map::ScopedMap;
}; use crate::util::warning_result::WarningResult;
use codespan_reporting::diagnostic::Diagnostic; use codespan_reporting::diagnostic::Diagnostic;
use std::{collections::HashMap, str::FromStr}; use std::collections::HashMap;
use std::str::FromStr;
use super::{FunctionDefinition, Name, Type};
/// An error we found while validating the input program. /// An error we found while validating the input program.
/// ///
@@ -12,6 +15,7 @@ use std::{collections::HashMap, str::FromStr};
/// that we're not going to be able to work through. As with most /// that we're not going to be able to work through. As with most
/// of these errors, we recommend converting this to a [`Diagnostic`] /// of these errors, we recommend converting this to a [`Diagnostic`]
/// and using [`codespan_reporting`] to present them to the user. /// and using [`codespan_reporting`] to present them to the user.
#[derive(Debug)]
pub enum Error { pub enum Error {
UnboundVariable(Location, String), UnboundVariable(Location, String),
UnknownType(Location, String), UnknownType(Location, String),
@@ -64,9 +68,9 @@ impl Program {
/// This checks for things like references to variables that don't exist, for /// This checks for things like references to variables that don't exist, for
/// example, and generates warnings for things that are inadvisable but not /// example, and generates warnings for things that are inadvisable but not
/// actually a problem. /// actually a problem.
pub fn validate(&self) -> (Vec<Error>, Vec<Warning>) { pub fn validate(raw_syntax: Vec<TopLevel>) -> WarningResult<Program, Warning, Error> {
let mut bound_variables = HashMap::new(); let mut bound_variables = ScopedMap::new();
self.validate_with_bindings(&mut bound_variables) Self::validate_with_bindings(raw_syntax, &mut bound_variables)
} }
/// Validate that the program makes semantic sense, not just syntactic sense. /// Validate that the program makes semantic sense, not just syntactic sense.
@@ -75,98 +79,254 @@ impl Program {
/// example, and generates warnings for things that are inadvisable but not /// example, and generates warnings for things that are inadvisable but not
/// actually a problem. /// actually a problem.
pub fn validate_with_bindings( pub fn validate_with_bindings(
&self, raw_syntax: Vec<TopLevel>,
bound_variables: &mut HashMap<String, Location>, bound_variables: &mut ScopedMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) { ) -> WarningResult<Program, Warning, Error> {
let mut errors = vec![]; let mut functions = HashMap::new();
let mut warnings = vec![]; let mut structures = HashMap::new();
let mut result = WarningResult::ok(vec![]);
let location = Location::infer_from(&raw_syntax);
for stmt in self.statements.iter() { for stmt in raw_syntax.into_iter() {
let (mut new_errors, mut new_warnings) = stmt.validate(bound_variables); match stmt {
errors.append(&mut new_errors); TopLevel::Expression(expr) => {
warnings.append(&mut new_warnings); let expr_result =
expr.validate(bound_variables, &mut structures, &mut functions);
result = result.merge_with(expr_result, |mut previous, current| {
previous.push(current);
Ok(previous)
});
} }
(errors, warnings) TopLevel::Structure(loc, name, fields) => {
let definition =
StructureDefinition::new(loc, name.clone(), fields.into_iter().collect());
structures.insert(name, definition);
}
} }
} }
impl Statement { result.map(move |exprs| Program {
/// Validate that the statement makes semantic sense, not just syntactic sense. functions,
/// structures,
/// This checks for things like references to variables that don't exist, for body: Expression::Block(location, exprs),
/// example, and generates warnings for things that are inadvisable but not })
/// actually a problem. Since statements appear in a broader context, you'll }
/// need to provide the set of variables that are bound where this statement }
/// occurs. We use a `HashMap` to map these bound locations to the locations
/// where their bound, because these locations are handy when generating errors impl Expression {
/// and warnings.
fn validate( fn validate(
&self, self,
bound_variables: &mut HashMap<String, Location>, variable_map: &mut ScopedMap<String, Location>,
) -> (Vec<Error>, Vec<Warning>) { structure_map: &mut HashMap<Name, StructureDefinition>,
let mut errors = vec![]; function_map: &mut HashMap<Name, FunctionDefinition>,
let mut warnings = vec![]; ) -> WarningResult<Expression, Warning, Error> {
match self { match self {
Statement::Binding(loc, var, val) => { Expression::Value(_, _) => WarningResult::ok(self),
Expression::Constructor(location, name, fields) => {
let mut result = WarningResult::ok(vec![]);
for (name, expr) in fields.into_iter() {
let expr_result = expr.validate(variable_map, structure_map, function_map);
result = result.merge_with(expr_result, move |mut fields, new_expr| {
fields.push((name, new_expr));
Ok(fields)
});
}
result.map(move |fields| Expression::Constructor(location, name, fields))
}
Expression::Reference(ref var)
if variable_map.contains_key(&var.original_name().to_string()) =>
{
WarningResult::ok(self)
}
Expression::Reference(var) => WarningResult::err(Error::UnboundVariable(
var.location().clone(),
var.original_name().to_string(),
)),
Expression::FieldRef(location, exp, field) => exp
.validate(variable_map, structure_map, function_map)
.map(|x| Expression::FieldRef(location, Box::new(x), field)),
Expression::Cast(location, t, expr) => {
let mut expr_result = expr.validate(variable_map, structure_map, function_map);
if PrimitiveType::from_str(&t).is_err() {
expr_result.add_error(Error::UnknownType(location.clone(), t.clone()));
}
expr_result.map(|e| Expression::Cast(location, t, Box::new(e)))
}
// FIXME: Check for valid primitives here!!
Expression::Primitive(_, _) => WarningResult::ok(self),
Expression::Call(loc, func, args) => {
let mut result = func
.validate(variable_map, structure_map, function_map)
.map(|x| (x, vec![]));
for arg in args.into_iter() {
let expr_result = arg.validate(variable_map, structure_map, function_map);
result =
result.merge_with(expr_result, |(func, mut previous_args), new_arg| {
previous_args.push(new_arg);
Ok((func, previous_args))
});
}
result.map(|(func, args)| Expression::Call(loc, Box::new(func), args))
}
Expression::Block(loc, stmts) => {
let mut result = WarningResult::ok(vec![]);
for stmt in stmts.into_iter() {
let stmt_result = stmt.validate(variable_map, structure_map, function_map);
result = result.merge_with(stmt_result, |mut stmts, stmt| {
stmts.push(stmt);
Ok(stmts)
});
}
result.map(|stmts| Expression::Block(loc, stmts))
}
Expression::Binding(loc, var, val) => {
// we're going to make the decision that a variable is not bound in the right // we're going to make the decision that a variable is not bound in the right
// hand side of its binding, which makes a lot of things easier. So we'll just // hand side of its binding, which makes a lot of things easier. So we'll just
// immediately check the expression, and go from there. // immediately check the expression, and go from there.
let (mut exp_errors, mut exp_warnings) = val.validate(bound_variables); let mut result = val.validate(variable_map, structure_map, function_map);
errors.append(&mut exp_errors); if let Some(original_binding_site) =
warnings.append(&mut exp_warnings); variable_map.get(&var.original_name().to_string())
if let Some(original_binding_site) = bound_variables.get(&var.name) { {
warnings.push(Warning::ShadowedVariable( result.add_warning(Warning::ShadowedVariable(
original_binding_site.clone(), original_binding_site.clone(),
loc.clone(), loc.clone(),
var.to_string(), var.to_string(),
)); ));
} else { } else {
bound_variables.insert(var.to_string(), loc.clone()); variable_map.insert(var.to_string(), loc.clone());
}
result.map(|val| Expression::Binding(loc, var, Box::new(val)))
}
Expression::Function(loc, name, mut arguments, return_type, body) => {
let mut result = WarningResult::ok(());
// first we should check for shadowing
for new_name in name.iter().chain(arguments.iter().map(|x| &x.0)) {
if let Some(original_site) = variable_map.get(new_name.original_name()) {
result.add_warning(Warning::ShadowedVariable(
original_site.clone(),
loc.clone(),
new_name.original_name().to_string(),
));
} }
} }
Statement::Print(_, var) if bound_variables.contains_key(&var.name) => {} // the function name is now available in our current scope, if the function was given one
Statement::Print(loc, var) => { if let Some(name) = &name {
errors.push(Error::UnboundVariable(loc.clone(), var.to_string())) variable_map.insert(name.original_name().to_string(), name.location().clone());
}
} }
(errors, warnings) // the arguments are available in a new scope, which we will use to validate the function
} // body
variable_map.new_scope();
for (arg, _) in arguments.iter() {
variable_map.insert(arg.original_name().to_string(), arg.location().clone());
} }
impl Expression { let body_result = body.validate(variable_map, structure_map, function_map);
fn validate(&self, variable_map: &HashMap<String, Location>) -> (Vec<Error>, Vec<Warning>) { variable_map.release_scope();
match self {
Expression::Value(_, _) => (vec![], vec![]), body_result.merge_with(result, move |mut body, _| {
Expression::Reference(_, var) if variable_map.contains_key(var) => (vec![], vec![]), // figure out what, if anything, needs to be in the closure for this function.
Expression::Reference(loc, var) => ( let mut free_variables = body.free_variables();
vec![Error::UnboundVariable(loc.clone(), var.clone())], for (n, _) in arguments.iter() {
vec![], free_variables.remove(n);
}
// generate a new name for the closure type we're about to create
let closure_type_name = Name::located_gensym(
loc.clone(),
name.as_ref().map(Name::original_name).unwrap_or("closure_"),
);
// ... and then create a structure type that has all of the free variables
// in it
let closure_type = StructureDefinition::new(
loc.clone(),
closure_type_name.clone(),
free_variables.iter().map(|x| (x.clone(), None)).collect(),
);
// this will become the first argument of the function, so name it and add
// it to the argument list.
let closure_arg = Name::gensym("__closure_arg");
arguments.insert(
0,
(
closure_arg.clone(),
Some(Type::Named(closure_type_name.clone())),
), ),
Expression::Cast(location, t, expr) => { );
let (mut errs, warns) = expr.validate(variable_map); // Now make a map from the old free variable names to references into
// our closure argument
let rebinds = free_variables
.into_iter()
.map(|n| {
(
n.clone(),
Expression::FieldRef(
n.location().clone(),
Box::new(Expression::Reference(closure_arg.clone())),
n,
),
)
})
.collect::<Vec<(Name, Expression)>>();
let mut rebind_map = rebinds.iter().cloned().collect();
// and replace all the references in the function with this map
body.replace_references(&mut rebind_map);
// OK! This function definitely needs a name; if the user didn't give
// it one, we'll do so.
let function_name =
name.unwrap_or_else(|| Name::located_gensym(loc.clone(), "function"));
// And finally, we can make the function definition and insert it into our global
// list along with the new closure type.
let function = FunctionDefinition::new(
function_name.clone(),
arguments.clone(),
return_type.clone(),
body,
);
if PrimitiveType::from_str(t).is_err() { structure_map.insert(closure_type_name.clone(), closure_type);
errs.push(Error::UnknownType(location.clone(), t.clone())) function_map.insert(function_name.clone(), function);
}
(errs, warns) // And the result of this function is a call to a primitive that generates
} // the closure value in some sort of reasonable way.
Expression::Primitive(_, _, args) => { Ok(Expression::Call(
let mut errors = vec![]; Location::manufactured(),
let mut warnings = vec![]; Box::new(Expression::Primitive(
Location::manufactured(),
for expr in args.iter() { Name::new("<closure>", Location::manufactured()),
let (mut err, mut warn) = expr.validate(variable_map); )),
errors.append(&mut err); vec![
warnings.append(&mut warn); Expression::Reference(function_name),
} Expression::Constructor(
Location::manufactured(),
(errors, warnings) closure_type_name,
rebinds,
),
],
))
})
} }
} }
} }
@@ -174,16 +334,19 @@ impl Expression {
#[test] #[test]
fn cast_checks_are_reasonable() { fn cast_checks_are_reasonable() {
let good_stmt = Statement::parse(0, "x = <u16>4u8;").expect("valid test case"); let mut variable_map = ScopedMap::new();
let (good_errs, good_warns) = good_stmt.validate(&mut HashMap::new()); let mut structure_map = HashMap::new();
let mut function_map = HashMap::new();
assert!(good_errs.is_empty()); let good_stmt = Expression::parse(0, "x = <u16>4u8;").expect("valid test case");
assert!(good_warns.is_empty()); let result_good = good_stmt.validate(&mut variable_map, &mut structure_map, &mut function_map);
let bad_stmt = Statement::parse(0, "x = <apple>4u8;").expect("valid test case"); assert!(result_good.is_ok());
let (bad_errs, bad_warns) = bad_stmt.validate(&mut HashMap::new()); assert!(result_good.warnings().is_empty());
assert!(bad_warns.is_empty()); let bad_stmt = Expression::parse(0, "x = <apple>4u8;").expect("valid test case");
assert_eq!(bad_errs.len(), 1); let result_err = bad_stmt.validate(&mut variable_map, &mut structure_map, &mut function_map);
assert!(matches!(bad_errs[0], Error::UnknownType(_, ref x) if x == "apple"));
assert!(result_err.is_err());
assert!(result_err.warnings().is_empty());
} }

View File

@@ -10,21 +10,32 @@
//! all the constraints we've generated. If that's successful, in the final phase, we //! all the constraints we've generated. If that's successful, in the final phase, we
//! do the final conversion to the IR AST, filling in any type information we've learned //! do the final conversion to the IR AST, filling in any type information we've learned
//! along the way. //! along the way.
mod ast; mod constraint;
mod convert; mod convert;
mod error;
mod finalize; mod finalize;
mod result;
mod solve; mod solve;
mod warning;
use self::convert::convert_program; use self::constraint::Constraint;
use self::finalize::finalize_program; use self::error::TypeInferenceError;
use self::solve::solve_constraints; pub use self::result::TypeInferenceResult;
pub use self::solve::{TypeInferenceError, TypeInferenceResult, TypeInferenceWarning}; use self::warning::TypeInferenceWarning;
use crate::ir::ast as ir; use crate::ir::ast as ir;
use crate::syntax; use crate::syntax;
#[cfg(test)] use crate::syntax::Name;
use crate::syntax::arbitrary::GenerationEnvironment; use std::collections::HashMap;
#[cfg(test)]
use proptest::prelude::Arbitrary; struct InferenceEngine {
constraints: Vec<Constraint>,
type_definitions: HashMap<Name, ir::TypeOrVar>,
variable_types: HashMap<Name, ir::TypeOrVar>,
functions: HashMap<Name, ir::FunctionDefinition<ir::TypeOrVar>>,
body: ir::Expression<ir::TypeOrVar>,
errors: Vec<TypeInferenceError>,
warnings: Vec<TypeInferenceWarning>,
}
impl syntax::Program { impl syntax::Program {
/// Infer the types for the syntactic AST, returning either a type-checked program in /// Infer the types for the syntactic AST, returning either a type-checked program in
@@ -32,21 +43,53 @@ impl syntax::Program {
/// ///
/// You really should have made sure that this program was validated before running /// You really should have made sure that this program was validated before running
/// this method, otherwise you may experience panics during operation. /// this method, otherwise you may experience panics during operation.
pub fn type_infer(self) -> TypeInferenceResult<ir::Program> { pub fn type_infer(self) -> TypeInferenceResult<ir::Program<ir::Type>> {
let mut constraint_db = vec![]; let mut engine = InferenceEngine::from(self);
let program = convert_program(self, &mut constraint_db); engine.solve_constraints();
let inference_result = solve_constraints(constraint_db);
inference_result.map(|resolutions| finalize_program(program, &resolutions)) if engine.errors.is_empty() {
let resolutions = std::mem::take(&mut engine.constraints)
.into_iter()
.map(|constraint| match constraint {
Constraint::Equivalent(_, ir::TypeOrVar::Variable(_, name), result) => {
match result.try_into() {
Err(e) => panic!("Ended up with complex type {}", e),
Ok(v) => (name, v),
}
}
_ => panic!("Had something that wasn't an equivalence left at the end!"),
})
.collect();
let warnings = std::mem::take(&mut engine.warnings);
TypeInferenceResult::Success {
result: engine.finalize_program(resolutions),
warnings,
}
} else {
TypeInferenceResult::Failure {
errors: engine.errors,
warnings: engine.warnings,
}
}
} }
} }
proptest::proptest! { proptest::proptest! {
#[test] #[test]
fn translation_maintains_semantics(input in syntax::Program::arbitrary_with(GenerationEnvironment::new(false))) { fn translation_maintains_semantics(input in syntax::arbitrary::ProgramGenerator::default()) {
let syntax_result = input.eval(); let input_program = syntax::Program::validate(input).into_result().expect("can validate random program");
let ir = input.type_infer().expect("arbitrary should generate type-safe programs"); let syntax_result = input_program.eval().map(|(x,o)| (x.strip(), o));
let ir_result = ir.eval(); let ir = input_program.type_infer().expect("arbitrary should generate type-safe programs");
proptest::prop_assert_eq!(syntax_result, ir_result); let ir_evaluator = crate::ir::Evaluator::<crate::ir::Type>::default();
let ir_result = ir_evaluator.eval(ir).map(|(x, o)| (x.strip(), o));
match (syntax_result, ir_result) {
(Err(e1), Err(e2)) => proptest::prop_assert_eq!(e1, e2),
(Ok((v1, o1)), Ok((v2, o2))) => {
proptest::prop_assert_eq!(v1, v2);
proptest::prop_assert_eq!(o1, o2);
}
_ => proptest::prop_assert!(false),
}
} }
} }

View File

@@ -1,336 +0,0 @@
pub use crate::ir::ast::Primitive;
/// This is largely a copy of `ir/ast`, with a couple of extensions that we're going
/// to want to use while we're doing type inference, but don't want to keep around
/// afterwards. These are:
///
/// * A notion of a type variable
/// * An unknown numeric constant form
///
use crate::{
eval::PrimitiveType,
syntax::{self, ConstantType, Location},
};
use internment::ArcIntern;
use pretty::{DocAllocator, Pretty};
use std::fmt;
use std::sync::atomic::AtomicUsize;
/// Variables in this AST are named by interned strings.
///
/// Comparing two of these is expected to be cheap enough for our purposes, since it
/// should end up being pretty much a comparison of the pointers to the strings.
type Variable = ArcIntern<String>;
/// The representation of a program within our IR. For now, this is exactly one file.
///
/// In addition, for the moment there's not really much of interest to hold here besides
/// the list of statements read from the file. Order is important. In the future, you
/// could imagine caching analysis information in this structure.
///
/// `Program` implements [`Pretty`] (via `&Program`), which should be used to print the
/// structure whenever possible, especially if you value your or your user's time.
///
/// NOTE(review): earlier documentation also promised an [`Arbitrary`] implementation
/// that generates syntactically valid programs (which may still hit runtime issues
/// like over- or underflow); no such implementation is visible next to this
/// definition -- confirm before relying on it.
#[derive(Debug)]
pub struct Program {
// For now, a program is just a vector of statements. In the future, we'll probably
// extend this to include a bunch of other information, but for now: just a list.
pub(crate) statements: Vec<Statement>,
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Program
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    /// Render the program as its statements in order, each one followed by a
    /// `;` and a hard line break.
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        // Start from the empty document and keep appending to its end; there may
        // be a better way than repeated appends, but this works.
        self.statements
            .iter()
            .fold(allocator.nil(), |document, statement| {
                document
                    .append(statement.pretty(allocator))
                    .append(allocator.text(";"))
                    .append(allocator.hardline())
            })
    }
}
/// The representation of a statement in the language.
///
/// For now, this is either a binding site (`x = 4`) or a print statement
/// (`print x`). Someday, though, more!
///
/// As with `Program`, this type implements [`Pretty`], which should
/// be used to display the structure whenever possible. It does not
/// implement [`Arbitrary`], though, mostly because it's slightly
/// complicated to do so.
///
#[derive(Debug)]
pub enum Statement {
/// `variable = expression`: the location of the statement, the variable being
/// bound, the type associated with the binding, and the bound expression.
Binding(Location, Variable, Type, Expression),
/// `print variable`: the location of the statement, the type associated with
/// the printed variable, and the variable itself.
Print(Location, Type, Variable),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Statement
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    /// Render a binding as `name = expression` and a print as `print name`.
    /// Location and type annotations are not shown.
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        match self {
            Statement::Binding(_, var, _, expr) => {
                let target = allocator
                    .text(var.as_ref().to_string())
                    .append(allocator.space());
                let assignment = target
                    .append(allocator.text("="))
                    .append(allocator.space());
                assignment.append(expr.pretty(allocator))
            }
            Statement::Print(_, _, var) => {
                let keyword = allocator.text("print").append(allocator.space());
                keyword.append(allocator.text(var.as_ref().to_string()))
            }
        }
    }
}
/// The representation of an expression.
///
/// Note that expressions, like everything else in this syntax tree,
/// support [`Pretty`], and it's strongly encouraged that you use
/// that trait/module when printing these structures.
///
/// Also, expressions at this point in the compiler are explicitly
/// defined so that they are *not* recursive. By this point, if an
/// expression requires some other data (like, for example, invoking
/// a primitive), any subexpressions have been bound to variables so
/// that the referenced data will always either be a constant or a
/// variable reference.
#[derive(Debug, PartialEq)]
pub enum Expression {
/// A bare constant or variable reference.
Atomic(ValueOrRef),
/// A cast of the given value or reference; the `Type` is the type of the
/// cast expression as a whole.
Cast(Location, Type, ValueOrRef),
/// A primitive operation applied to already-simplified arguments; the `Type`
/// is the type of the whole expression.
Primitive(Location, Type, Primitive, Vec<ValueOrRef>),
}
impl Expression {
    /// Borrow the type stored on this expression, whether it was inferred or
    /// just recently computed.
    pub fn type_of(&self) -> &Type {
        match self {
            Expression::Atomic(ValueOrRef::Ref(_, ty, _) | ValueOrRef::Value(_, ty, _))
            | Expression::Cast(_, ty, _)
            | Expression::Primitive(_, ty, _, _) => ty,
        }
    }

    /// Borrow the source location stored on this expression.
    pub fn location(&self) -> &Location {
        match self {
            Expression::Atomic(ValueOrRef::Ref(loc, _, _) | ValueOrRef::Value(loc, _, _))
            | Expression::Cast(loc, _, _)
            | Expression::Primitive(loc, _, _, _) => loc,
        }
    }
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Expression
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    /// Render the expression in source-like form:
    ///
    /// * atoms print as themselves,
    /// * casts print as `<type>operand`,
    /// * one-argument primitives print as `op` immediately followed by the operand,
    /// * two-argument primitives print as `(left op right)`,
    /// * any other arity prints a loud `!!...!!` marker instead.
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        match self {
            Expression::Atomic(atom) => atom.pretty(allocator),
            Expression::Cast(_, target, operand) => {
                let annotation = allocator
                    .text("<")
                    .append(target.pretty(allocator))
                    .append(allocator.text(">"));
                annotation.append(operand.pretty(allocator))
            }
            Expression::Primitive(_, _, op, operands) => match operands.as_slice() {
                [only] => op.pretty(allocator).append(only.pretty(allocator)),
                [lhs, rhs] => {
                    let left = lhs.pretty(allocator);
                    let right = rhs.pretty(allocator);

                    left.append(allocator.space())
                        .append(op.pretty(allocator))
                        .append(allocator.space())
                        .append(right)
                        .parens()
                }
                _ => allocator.text(format!(
                    "!!{:?} with {} arguments!!",
                    op,
                    operands.len()
                )),
            },
        }
    }
}
/// An expression that is always either a value or a reference.
///
/// This is the type used to guarantee that we don't nest expressions
/// at this level. Instead, expressions that take arguments take one
/// of these, which can only be a constant or a reference.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrRef {
/// A constant, with its location and type.
Value(Location, Type, Value),
/// A reference to a variable by (interned) name, with its location and type.
Ref(Location, Type, ArcIntern<String>),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b ValueOrRef
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    /// Render a constant using the constant's own pretty-printer, and a
    /// reference as the bare variable name.
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        match self {
            ValueOrRef::Ref(_, _, name) => allocator.text(name.as_ref().to_string()),
            ValueOrRef::Value(_, _, constant) => constant.pretty(allocator),
        }
    }
}
impl From<ValueOrRef> for Expression {
fn from(value: ValueOrRef) -> Self {
Expression::Atomic(value)
}
}
/// A constant in the IR.
///
/// The optional argument in numeric types is the base that was used by the
/// user to input the number. By retaining it, we can ensure that if we need
/// to print the number back out, we can do so in the form that the user
/// entered it.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
/// A numeric constant whose type has not been determined yet; the raw
/// value is kept as a `u64`.
Unknown(Option<u8>, u64),
I8(Option<u8>, i8),
I16(Option<u8>, i16),
I32(Option<u8>, i32),
I64(Option<u8>, i64),
U8(Option<u8>, u8),
U16(Option<u8>, u16),
U32(Option<u8>, u32),
U64(Option<u8>, u64),
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Value
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    /// Render the constant by delegating to the pretty-printer for
    /// `syntax::Value::Number`, passing along the base the user originally wrote
    /// (if any) and the constant's type.
    ///
    /// Signed values are printed as an optional leading `-` followed by their
    /// magnitude; the `-` is emitted only for values that are actually negative.
    /// Constants of still-unknown type hold a raw `u64` and are printed unsigned.
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        let pretty_internal = |opt_base: &Option<u8>, x, t| {
            syntax::Value::Number(*opt_base, Some(t), x).pretty(allocator)
        };

        let pretty_internal_signed = |opt_base, x: i64, t| {
            // `unsigned_abs` is well-defined even for `i64::MIN`, unlike `abs`.
            let magnitude = pretty_internal(opt_base, x.unsigned_abs(), t);

            if x < 0 {
                allocator.text("-").append(magnitude)
            } else {
                magnitude
            }
        };

        match self {
            // The payload is an unsigned 64-bit number, so it must not take the
            // signed path: casting to `i64` would misprint values above `i64::MAX`.
            Value::Unknown(opt_base, value) => {
                pretty_internal(opt_base, *value, ConstantType::U64)
            }
            Value::I8(opt_base, value) => {
                pretty_internal_signed(opt_base, *value as i64, ConstantType::I8)
            }
            Value::I16(opt_base, value) => {
                pretty_internal_signed(opt_base, *value as i64, ConstantType::I16)
            }
            Value::I32(opt_base, value) => {
                pretty_internal_signed(opt_base, *value as i64, ConstantType::I32)
            }
            Value::I64(opt_base, value) => {
                pretty_internal_signed(opt_base, *value, ConstantType::I64)
            }
            Value::U8(opt_base, value) => {
                pretty_internal(opt_base, *value as u64, ConstantType::U8)
            }
            Value::U16(opt_base, value) => {
                pretty_internal(opt_base, *value as u64, ConstantType::U16)
            }
            Value::U32(opt_base, value) => {
                pretty_internal(opt_base, *value as u64, ConstantType::U32)
            }
            Value::U64(opt_base, value) => pretty_internal(opt_base, *value, ConstantType::U64),
        }
    }
}
/// A type as seen during type inference: either a not-yet-resolved type
/// variable or a primitive type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Type {
/// A type variable, identified by an interned name, with an associated location.
Variable(Location, ArcIntern<String>),
/// A primitive type.
Primitive(PrimitiveType),
}
impl Type {
    /// Returns `true` for every type except a type variable.
    pub fn is_concrete(&self) -> bool {
        match self {
            Type::Variable(_, _) => false,
            Type::Primitive(_) => true,
        }
    }
}
impl<'a, 'b, D, A> Pretty<'a, D, A> for &'b Type
where
    A: 'a,
    D: ?Sized + DocAllocator<'a, A>,
{
    /// Render a type variable as its name and a primitive type via its
    /// `Display` output.
    fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, A> {
        let rendered = match self {
            Type::Variable(_, name) => name.to_string(),
            Type::Primitive(primitive) => primitive.to_string(),
        };

        allocator.text(rendered)
    }
}
impl fmt::Display for Type {
/// Write a type variable as its name; for a primitive type, defer to the
/// primitive type's own `fmt`.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
// Type variables are shown by name only; the location is not printed.
Type::Variable(_, x) => write!(f, "{}", x),
// Hand the formatter straight to the primitive type.
Type::Primitive(pt) => pt.fmt(f),
}
}
}
/// Produce a fresh name derived from `name`.
///
/// The result has the form `<name:N>`, where `N` comes from a counter that is
/// bumped on every call. The intent is that the name is unique for the whole
/// execution: the counter never repeats a value, and the angle brackets and
/// colon are characters that would not be valid in user-written input.
pub fn gensym(name: &str) -> ArcIntern<String> {
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let serial = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    ArcIntern::new(format!("<{}:{}>", name, serial))
}
/// Produce a fresh type: a brand-new type variable.
///
/// The variable is named `t<N>`, where `N` comes from a counter (separate from
/// the one used by `gensym`) that is bumped on every call. The intent is that
/// the name is unique for the whole execution: the counter never repeats a
/// value, and the angle brackets would not be valid in user-written input.
/// The variable is given a manufactured location.
pub fn gentype() -> Type {
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let serial = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    let variable_name = ArcIntern::new(format!("t<{}>", serial));
    Type::Variable(Location::manufactured(), variable_name)
}

View File

@@ -0,0 +1,78 @@
use crate::ir::TypeOrVar;
use crate::syntax::{Location, Name};
use std::fmt;
/// A type inference constraint that we're going to need to solve.
///
/// Every variant starts with a source [`Location`] (presumably the place the
/// constraint originated, for diagnostics -- confirm against the solver).
#[derive(Debug)]
pub enum Constraint {
/// The given type must be printable using the `print` built-in
Printable(Location, TypeOrVar),
/// The provided numeric value fits in the given constant type
FitsInNumType(Location, TypeOrVar, u64),
/// The given type (first) can be casted to the target type (second) safely
CanCastTo(Location, TypeOrVar, TypeOrVar),
/// The given type has the given field in it, and the type of that field
/// is as given.
TypeHasField(Location, TypeOrVar, Name, TypeOrVar),
/// The given type must be some numeric type, but this is not a constant
/// value, so don't try to default it if we can't figure it out
NumericType(Location, TypeOrVar),
/// The given type is attached to a constant and must be some numeric type.
/// If we can't figure it out, we should warn the user and then just use a
/// default.
ConstantNumericType(Location, TypeOrVar),
/// The two types should be equivalent
Equivalent(Location, TypeOrVar, TypeOrVar),
/// The given type can be resolved to something
IsSomething(Location, TypeOrVar),
/// The given type can be negated
IsSigned(Location, TypeOrVar),
/// Checks to see if the given named type is equivalent to the provided one.
NamedTypeIs(Location, Name, TypeOrVar),
}
impl fmt::Display for Constraint {
    /// Write the constraint as an upper-case tag followed by its operands.
    /// The location is never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Constraints over a single type.
            Constraint::Printable(_, subject) => write!(f, "PRINTABLE {}", subject),
            Constraint::NumericType(_, subject) => write!(f, "NUMERIC {}", subject),
            Constraint::ConstantNumericType(_, subject) => {
                write!(f, "CONST_NUMERIC {}", subject)
            }
            Constraint::IsSomething(_, subject) => write!(f, "SOMETHING {}", subject),
            Constraint::IsSigned(_, subject) => write!(f, "SIGNED {}", subject),

            // Constraints relating a type to a number, a name, or another type.
            Constraint::FitsInNumType(_, subject, number) => {
                write!(f, "FITS_IN {} {}", number, subject)
            }
            Constraint::CanCastTo(_, source, target) => {
                write!(f, "CAST {} -> {}", source, target)
            }
            Constraint::TypeHasField(_, record, field, field_type) => {
                write!(f, "FIELD {}.{} -> {}", record, field, field_type)
            }
            Constraint::Equivalent(_, left, right) => {
                write!(f, "EQUIVALENT {} => {}", left, right)
            }
            Constraint::NamedTypeIs(_, type_name, subject) => {
                write!(f, "TYPE_EQUIV {} == {}", type_name, subject)
            }
        }
    }
}
impl Constraint {
    /// Replace all instances of the name (anywhere! including on the left hand side of
    /// equivalences!) with the given type.
    ///
    /// Every type held by the constraint is visited, even when an earlier one in the
    /// same constraint has already been changed.
    ///
    /// Returns whether or not anything was changed in the constraint.
    pub fn replace(&mut self, name: &Name, replace_with: &TypeOrVar) -> bool {
        match self {
            // Constraints over a single type. The `Name` stored in `NamedTypeIs` is
            // deliberately left unbound: binding it as `name` would shadow the
            // parameter and substitute for the wrong name.
            Constraint::Printable(_, ty)
            | Constraint::FitsInNumType(_, ty, _)
            | Constraint::ConstantNumericType(_, ty)
            | Constraint::IsSigned(_, ty)
            | Constraint::IsSomething(_, ty)
            | Constraint::NumericType(_, ty)
            | Constraint::NamedTypeIs(_, _, ty) => ty.replace(name, replace_with),

            // Constraints over two types. Both substitutions must run before the
            // results are combined: `a.replace(..) || b.replace(..)` would skip the
            // second type whenever the first one changed.
            Constraint::CanCastTo(_, ty1, ty2)
            | Constraint::TypeHasField(_, ty1, _, ty2)
            | Constraint::Equivalent(_, ty1, ty2) => {
                let first_changed = ty1.replace(name, replace_with);
                let second_changed = ty2.replace(name, replace_with);
                first_changed || second_changed
            }
        }
    }
}

View File

@@ -1,92 +1,109 @@
use super::ast as ir; use super::constraint::Constraint;
use super::ast::Type; use super::InferenceEngine;
use crate::eval::PrimitiveType; use crate::eval::PrimitiveType;
use crate::syntax::{self, ConstantType}; use crate::ir::{self, Fields};
use crate::type_infer::solve::Constraint; use crate::syntax::Name;
use crate::syntax::{self, ConstantType, Location};
use crate::util::scoped_map::ScopedMap;
use internment::ArcIntern; use internment::ArcIntern;
use std::collections::HashMap; use std::collections::HashMap;
use std::str::FromStr; use std::str::FromStr;
/// This function takes a syntactic program and converts it into the IR version of the impl From<syntax::Program> for InferenceEngine {
/// program, with appropriate type variables introduced and their constraints added to fn from(value: syntax::Program) -> Self {
/// the given database. let syntax::Program {
/// functions,
/// If the input function has been validated (which it should be), then this should run structures,
/// into no error conditions. However, if you failed to validate the input, then this body,
/// function can panic. } = value;
pub fn convert_program( let mut result = InferenceEngine {
mut program: syntax::Program, constraints: Vec::new(),
constraint_db: &mut Vec<Constraint>, type_definitions: HashMap::new(),
) -> ir::Program { variable_types: HashMap::new(),
let mut statements = Vec::new(); functions: HashMap::new(),
let mut renames = HashMap::new(); body: ir::Expression::Block(Location::manufactured(), ir::TypeOrVar::new(), vec![]),
let mut bindings = HashMap::new(); errors: vec![],
warnings: vec![],
};
let mut renames = ScopedMap::new();
for stmt in program.statements.drain(..) { // first let's transfer all the type information over into our new
statements.append(&mut convert_statement( // data structures
stmt, for (_, structure) in structures.into_iter() {
constraint_db, let mut fields = Fields::default();
&mut renames,
&mut bindings, for (name, optty) in structure.fields {
)); match optty {
None => {
let newty = ir::TypeOrVar::new_located(name.location().clone());
fields.insert(name, newty);
} }
ir::Program { statements } Some(t) => {
let existing_ty = result.convert_type(t);
fields.insert(name, existing_ty);
}
}
} }
/// This function takes a syntactic statements and converts it into a series of result
/// IR statements, adding type variables and constraints as necessary. .type_definitions
/// .insert(structure.name.clone(), ir::TypeOrVar::Structure(fields));
/// We generate a series of statements because we're going to flatten all
/// incoming expressions so that they are no longer recursive. This will
/// generate a bunch of new bindings for all the subexpressions, which we
/// return as a bundle.
///
/// See the safety warning on [`convert_program`]! This function assumes that
/// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean.
fn convert_statement(
statement: syntax::Statement,
constraint_db: &mut Vec<Constraint>,
renames: &mut HashMap<ArcIntern<String>, ArcIntern<String>>,
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> Vec<ir::Statement> {
match statement {
syntax::Statement::Print(loc, name) => {
let iname = ArcIntern::new(name.to_string());
let final_name = renames
.get(&iname)
.map(Clone::clone)
.unwrap_or_else(|| iname.clone());
let varty = bindings
.get(&final_name)
.expect("print variable defined before use")
.clone();
constraint_db.push(Constraint::Printable(loc.clone(), varty.clone()));
vec![ir::Statement::Print(loc, varty, iname)]
} }
syntax::Statement::Binding(loc, name, expr) => { // then transfer all the functions over to the new system
let (mut prereqs, expr, ty) = for (_, function) in functions.into_iter() {
convert_expression(expr, constraint_db, renames, bindings); // convert the arguments into the new type scheme. if given, use the ones
let iname = ArcIntern::new(name.to_string()); // given, otherwise generate a new type variable for us to solve for.
let final_name = if bindings.contains_key(&iname) { let mut arguments = vec![];
let new_name = ir::gensym(iname.as_str()); for (name, ty) in function.arguments.into_iter() {
renames.insert(iname, new_name.clone()); match ty {
new_name None => {
let inferred_type = ir::TypeOrVar::new_located(name.location().clone());
arguments.push((name, inferred_type));
}
Some(t) => {
arguments.push((name, result.convert_type(t)));
}
}
}
// similarly, use the provided return type if given, otherwise generate
// a new type variable to use.
let return_type = if let Some(t) = function.return_type {
result.convert_type(t)
} else { } else {
iname ir::TypeOrVar::new_located(function.name.location().clone())
}; };
bindings.insert(final_name.clone(), ty.clone()); let (body, body_type) = result.convert_expression(function.body, &mut renames);
prereqs.push(ir::Statement::Binding(loc, final_name, ty, expr)); result.constraints.push(Constraint::Equivalent(
prereqs function.name.location().clone(),
return_type.clone(),
body_type,
));
let new_function = ir::FunctionDefinition {
name: function.name,
arguments,
return_type,
body,
};
result
.functions
.insert(new_function.name.clone(), new_function);
} }
// finally we can transfer the body over
result.body = result.convert_expression(body, &mut renames).0;
result
} }
} }
impl InferenceEngine {
/// This function takes a syntactic expression and converts it into a series /// This function takes a syntactic expression and converts it into a series
/// of IR statements, adding type variables and constraints as necessary. /// of IR statements, adding type variables and constraints as necessary.
/// ///
@@ -99,280 +116,399 @@ fn convert_statement(
/// you have run [`Statement::validate`], and will trigger panics in error /// you have run [`Statement::validate`], and will trigger panics in error
/// conditions if you have run that and had it come back clean. /// conditions if you have run that and had it come back clean.
fn convert_expression( fn convert_expression(
&mut self,
expression: syntax::Expression, expression: syntax::Expression,
constraint_db: &mut Vec<Constraint>, renames: &mut ScopedMap<Name, ArcIntern<String>>,
renames: &HashMap<ArcIntern<String>, ArcIntern<String>>, ) -> (ir::Expression<ir::TypeOrVar>, ir::TypeOrVar) {
bindings: &mut HashMap<ArcIntern<String>, Type>,
) -> (Vec<ir::Statement>, ir::Expression, Type) {
match expression { match expression {
// converting values is mostly tedious, because there's so many cases
// involved
syntax::Expression::Value(loc, val) => match val { syntax::Expression::Value(loc, val) => match val {
syntax::Value::Void => (
ir::Expression::Atomic(ir::ValueOrRef::Value(
loc,
ir::TypeOrVar::Primitive(PrimitiveType::Void),
ir::Value::Void,
)),
ir::TypeOrVar::Primitive(PrimitiveType::Void),
),
syntax::Value::Number(base, mctype, value) => { syntax::Value::Number(base, mctype, value) => {
let (newval, newtype) = match mctype { let (newval, newtype) = match mctype {
None => { None => {
let newtype = ir::gentype(); let newtype = ir::TypeOrVar::new();
let newval = ir::Value::Unknown(base, value); let newval = ir::Value::U64(base, value);
constraint_db.push(Constraint::ConstantNumericType( self.constraints.push(Constraint::ConstantNumericType(
loc.clone(), loc.clone(),
newtype.clone(), newtype.clone(),
)); ));
(newval, newtype) (newval, newtype)
} }
Some(ConstantType::Void) => (
ir::Value::Void,
ir::TypeOrVar::Primitive(PrimitiveType::Void),
),
Some(ConstantType::U8) => ( Some(ConstantType::U8) => (
ir::Value::U8(base, value as u8), ir::Value::U8(base, value as u8),
ir::Type::Primitive(PrimitiveType::U8), ir::TypeOrVar::Primitive(PrimitiveType::U8),
), ),
Some(ConstantType::U16) => ( Some(ConstantType::U16) => (
ir::Value::U16(base, value as u16), ir::Value::U16(base, value as u16),
ir::Type::Primitive(PrimitiveType::U16), ir::TypeOrVar::Primitive(PrimitiveType::U16),
), ),
Some(ConstantType::U32) => ( Some(ConstantType::U32) => (
ir::Value::U32(base, value as u32), ir::Value::U32(base, value as u32),
ir::Type::Primitive(PrimitiveType::U32), ir::TypeOrVar::Primitive(PrimitiveType::U32),
), ),
Some(ConstantType::U64) => ( Some(ConstantType::U64) => (
ir::Value::U64(base, value), ir::Value::U64(base, value),
ir::Type::Primitive(PrimitiveType::U64), ir::TypeOrVar::Primitive(PrimitiveType::U64),
), ),
Some(ConstantType::I8) => ( Some(ConstantType::I8) => (
ir::Value::I8(base, value as i8), ir::Value::I8(base, value as i8),
ir::Type::Primitive(PrimitiveType::I8), ir::TypeOrVar::Primitive(PrimitiveType::I8),
), ),
Some(ConstantType::I16) => ( Some(ConstantType::I16) => (
ir::Value::I16(base, value as i16), ir::Value::I16(base, value as i16),
ir::Type::Primitive(PrimitiveType::I16), ir::TypeOrVar::Primitive(PrimitiveType::I16),
), ),
Some(ConstantType::I32) => ( Some(ConstantType::I32) => (
ir::Value::I32(base, value as i32), ir::Value::I32(base, value as i32),
ir::Type::Primitive(PrimitiveType::I32), ir::TypeOrVar::Primitive(PrimitiveType::I32),
), ),
Some(ConstantType::I64) => ( Some(ConstantType::I64) => (
ir::Value::I64(base, value as i64), ir::Value::I64(base, value as i64),
ir::Type::Primitive(PrimitiveType::I64), ir::TypeOrVar::Primitive(PrimitiveType::I64),
), ),
}; };
constraint_db.push(Constraint::FitsInNumType( self.constraints.push(Constraint::FitsInNumType(
loc.clone(), loc.clone(),
newtype.clone(), newtype.clone(),
value, value,
)); ));
( (
vec![],
ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)), ir::Expression::Atomic(ir::ValueOrRef::Value(loc, newtype.clone(), newval)),
newtype, newtype,
) )
} }
}, },
syntax::Expression::Reference(loc, name) => { syntax::Expression::Constructor(loc, name, fields) => {
let iname = ArcIntern::new(name); let mut result_fields = HashMap::new();
let final_name = renames.get(&iname).cloned().unwrap_or(iname); let mut type_fields = ir::Fields::default();
let rtype = bindings let mut prereqs = vec![];
.get(&final_name)
for (name, syntax_expr) in fields.into_iter() {
let (field_expr, field_type) = self.convert_expression(syntax_expr, renames);
type_fields.insert(name.clone(), field_type);
let (prereq, value) = simplify_expr(field_expr);
result_fields.insert(name.clone(), value);
merge_prereq(&mut prereqs, prereq);
}
let result_type = ir::TypeOrVar::Structure(type_fields);
self.constraints.push(Constraint::NamedTypeIs(
loc.clone(),
name.clone(),
result_type.clone(),
));
let expression =
ir::Expression::Construct(loc, result_type.clone(), name, result_fields);
(expression, result_type)
}
syntax::Expression::Reference(mut name) => {
if let Some(rename) = renames.get(&name) {
name.rename(rename);
}
let result_type = self
.variable_types
.get(&name)
.cloned() .cloned()
.expect("variable bound before use"); .expect("variable bound before use");
let refexp =
ir::Expression::Atomic(ir::ValueOrRef::Ref(loc, rtype.clone(), final_name));
(vec![], refexp, rtype) let expression = ir::Expression::Atomic(ir::ValueOrRef::Ref(
name.location().clone(),
result_type.clone(),
name.clone(),
));
(expression, result_type)
}
syntax::Expression::FieldRef(loc, expr, field) => {
let (expr, expr_type) = self.convert_expression(*expr, renames);
let (prereqs, val_or_ref) = simplify_expr(expr);
let result_type = ir::TypeOrVar::new();
let result = ir::Expression::FieldRef(
loc.clone(),
result_type.clone(),
expr_type.clone(),
val_or_ref,
field.clone(),
);
self.constraints.push(Constraint::TypeHasField(
loc,
expr_type.clone(),
field,
result_type.clone(),
));
(finalize_expression(prereqs, result), result_type)
} }
syntax::Expression::Cast(loc, target, expr) => { syntax::Expression::Cast(loc, target, expr) => {
let (mut stmts, nexpr, etype) = let (expr, expr_type) = self.convert_expression(*expr, renames);
convert_expression(*expr, constraint_db, renames, bindings); let (prereqs, val_or_ref) = simplify_expr(expr);
let val_or_ref = simplify_expr(nexpr, &mut stmts); let target_type: ir::TypeOrVar = PrimitiveType::from_str(&target)
let target_prim_type = PrimitiveType::from_str(&target).expect("valid type for cast"); .expect("valid type for cast")
let target_type = Type::Primitive(target_prim_type); .into();
let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref); let res = ir::Expression::Cast(loc.clone(), target_type.clone(), val_or_ref);
constraint_db.push(Constraint::CanCastTo(loc, etype, target_type.clone())); self.constraints.push(Constraint::CanCastTo(
loc,
(stmts, res, target_type) expr_type.clone(),
} target_type.clone(),
syntax::Expression::Primitive(loc, fun, mut args) => {
let primop = ir::Primitive::from_str(&fun).expect("valid primitive");
let mut stmts = vec![];
let mut nargs = vec![];
let mut atypes = vec![];
let ret_type = ir::gentype();
for arg in args.drain(..) {
let (mut astmts, aexp, atype) =
convert_expression(arg, constraint_db, renames, bindings);
stmts.append(&mut astmts);
nargs.push(simplify_expr(aexp, &mut stmts));
atypes.push(atype);
}
constraint_db.push(Constraint::ProperPrimitiveArgs(
loc.clone(),
primop,
atypes.clone(),
ret_type.clone(),
)); ));
( (finalize_expression(prereqs, res), target_type)
stmts, }
ir::Expression::Primitive(loc, ret_type.clone(), primop, nargs),
ret_type, syntax::Expression::Primitive(loc, name) => {
) let primop = ir::Primitive::from_str(name.current_name()).expect("valid primitive");
match primop {
ir::Primitive::Plus | ir::Primitive::Times | ir::Primitive::Divide => {
let numeric_type = ir::TypeOrVar::new_located(loc.clone());
self.constraints
.push(Constraint::NumericType(loc.clone(), numeric_type.clone()));
let funtype = ir::TypeOrVar::Function(
vec![numeric_type.clone(), numeric_type.clone()],
Box::new(numeric_type.clone()),
);
let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop);
(ir::Expression::Atomic(result_value), funtype)
}
ir::Primitive::Minus => {
let numeric_type = ir::TypeOrVar::new_located(loc.clone());
self.constraints
.push(Constraint::NumericType(loc.clone(), numeric_type.clone()));
let funtype = ir::TypeOrVar::Function(
vec![numeric_type.clone(), numeric_type.clone()],
Box::new(numeric_type.clone()),
);
let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop);
(ir::Expression::Atomic(result_value), funtype)
}
ir::Primitive::Print => {
let arg_type = ir::TypeOrVar::new_located(loc.clone());
self.constraints
.push(Constraint::Printable(loc.clone(), arg_type.clone()));
let funtype = ir::TypeOrVar::Function(
vec![arg_type],
Box::new(ir::TypeOrVar::Primitive(PrimitiveType::Void)),
);
let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop);
(ir::Expression::Atomic(result_value), funtype)
}
ir::Primitive::Negate => {
let arg_type = ir::TypeOrVar::new_located(loc.clone());
self.constraints
.push(Constraint::NumericType(loc.clone(), arg_type.clone()));
self.constraints
.push(Constraint::IsSigned(loc.clone(), arg_type.clone()));
let funtype =
ir::TypeOrVar::Function(vec![arg_type.clone()], Box::new(arg_type));
let result_value = ir::ValueOrRef::Primitive(loc, funtype.clone(), primop);
(ir::Expression::Atomic(result_value), funtype)
} }
} }
} }
fn simplify_expr(expr: ir::Expression, stmts: &mut Vec<ir::Statement>) -> ir::ValueOrRef { syntax::Expression::Call(loc, fun, args) => {
let return_type = ir::TypeOrVar::new();
let arg_types = args
.iter()
.map(|_| ir::TypeOrVar::new())
.collect::<Vec<_>>();
let (fun, fun_type) = self.convert_expression(*fun, renames);
let target_fun_type =
ir::TypeOrVar::Function(arg_types.clone(), Box::new(return_type.clone()));
self.constraints.push(Constraint::Equivalent(
loc.clone(),
fun_type,
target_fun_type,
));
let mut prereqs = vec![];
let (fun_prereqs, fun) = simplify_expr(fun);
merge_prereq(&mut prereqs, fun_prereqs);
let new_args = args
.into_iter()
.zip(arg_types)
.map(|(arg, target_type)| {
let (arg, arg_type) = self.convert_expression(arg, renames);
let location = arg.location().clone();
let (arg_prereq, new_valref) = simplify_expr(arg);
merge_prereq(&mut prereqs, arg_prereq);
self.constraints.push(Constraint::Equivalent(
location,
arg_type,
target_type,
));
new_valref
})
.collect();
let last_call =
ir::Expression::Call(loc.clone(), return_type.clone(), fun, new_args);
(finalize_expressions(prereqs, last_call), return_type)
}
syntax::Expression::Block(loc, stmts) => {
let mut result_type = ir::TypeOrVar::Primitive(PrimitiveType::Void);
let mut exprs = vec![];
for xpr in stmts.into_iter() {
let (expr, expr_type) = self.convert_expression(xpr, renames);
result_type = expr_type;
exprs.push(expr);
}
(
ir::Expression::Block(loc, result_type.clone(), exprs),
result_type,
)
}
syntax::Expression::Binding(loc, name, expr) => {
let (expr, expr_type) = self.convert_expression(*expr, renames);
let final_name = self.finalize_name(renames, name);
self.variable_types
.insert(final_name.clone(), expr_type.clone());
let result_expr =
ir::Expression::Bind(loc, final_name, expr_type.clone(), Box::new(expr));
(result_expr, expr_type)
}
syntax::Expression::Function(_, _, _, _, _) => {
panic!(
"Function expressions should not survive validation to get to type checking!"
);
}
}
}
/// Convert a surface-syntax type into the inference engine's `TypeOrVar`.
///
/// * A named type whose name parses as a `PrimitiveType` becomes that
///   primitive directly.
/// * Any other named type becomes a fresh, located type variable, and a
///   `NamedTypeIs` constraint is recorded tying the name to that variable.
/// * A structure type is converted field by field; a field without a
///   declared type receives a fresh type variable.
fn convert_type(&mut self, ty: syntax::Type) -> ir::TypeOrVar {
    match ty {
        syntax::Type::Named(type_name) => {
            // Primitive names need no constraint; hand them back right away.
            if let Ok(primitive) = PrimitiveType::from_str(type_name.current_name()) {
                return ir::TypeOrVar::Primitive(primitive);
            }

            let location = type_name.location().clone();
            let placeholder = ir::TypeOrVar::new_located(location.clone());
            self.constraints.push(Constraint::NamedTypeIs(
                location,
                type_name,
                placeholder.clone(),
            ));
            placeholder
        }

        syntax::Type::Struct(fields) => {
            let mut converted = ir::Fields::default();
            for (field_name, declared) in fields.into_iter() {
                let field_type = match declared {
                    Some(inner) => self.convert_type(inner),
                    None => ir::TypeOrVar::new(),
                };
                converted.insert(field_name, field_type);
            }
            ir::TypeOrVar::Structure(converted)
        }
    }
}
/// Pick the name that a new binding should actually use.
///
/// If `name` already has an entry in `self.variable_types`, a fresh gensym
/// derived from the original name is generated, the mapping from the old
/// name to the fresh one is recorded in `renames`, and the renamed `name`
/// is returned. Otherwise `name` is returned unchanged.
fn finalize_name(
    &mut self,
    renames: &mut ScopedMap<Name, ArcIntern<String>>,
    mut name: syntax::Name,
) -> syntax::Name {
    if self.variable_types.contains_key(&name) {
        let fresh = Name::gensym(name.original_name()).intern();
        // Record the rename under the *old* name before mutating it.
        renames.insert(name.clone(), fresh.clone());
        name.rename(&fresh);
    }
    name
}
}
fn simplify_expr(
expr: ir::Expression<ir::TypeOrVar>,
) -> (
Option<ir::Expression<ir::TypeOrVar>>,
ir::ValueOrRef<ir::TypeOrVar>,
) {
match expr { match expr {
ir::Expression::Atomic(v_or_ref) => v_or_ref, ir::Expression::Atomic(v_or_ref) => (None, v_or_ref),
expr => { expr => {
let etype = expr.type_of().clone(); let etype = expr.type_of().clone();
let loc = expr.location().clone(); let loc = expr.location().clone();
let nname = ir::gensym("g"); let nname = Name::located_gensym(loc.clone(), "g");
let nbinding = ir::Statement::Binding(loc.clone(), nname.clone(), etype.clone(), expr); let nbinding =
ir::Expression::Bind(loc.clone(), nname.clone(), etype.clone(), Box::new(expr));
stmts.push(nbinding); (Some(nbinding), ir::ValueOrRef::Ref(loc, etype, nname))
ir::ValueOrRef::Ref(loc, etype, nname)
} }
} }
} }
#[cfg(test)] fn finalize_expression(
mod tests { prereq: Option<ir::Expression<ir::TypeOrVar>>,
use super::*; actual: ir::Expression<ir::TypeOrVar>,
use crate::syntax::Location; ) -> ir::Expression<ir::TypeOrVar> {
if let Some(prereq) = prereq {
fn one() -> syntax::Expression { ir::Expression::Block(
syntax::Expression::Value( prereq.location().clone(),
Location::manufactured(), actual.type_of().clone(),
syntax::Value::Number(None, None, 1), vec![prereq, actual],
) )
} else {
actual
}
} }
fn vec_contains<T, F: Fn(&T) -> bool>(x: &[T], f: F) -> bool { fn finalize_expressions(
for x in x.iter() { mut prereqs: Vec<ir::Expression<ir::TypeOrVar>>,
if f(x) { actual: ir::Expression<ir::TypeOrVar>,
return true; ) -> ir::Expression<ir::TypeOrVar> {
if prereqs.is_empty() {
actual
} else {
let return_type = actual.type_of();
let loc = actual.location().clone();
prereqs.push(actual);
ir::Expression::Block(loc, return_type, prereqs)
} }
} }
false
}
fn infer_expression( fn merge_prereq<T>(left: &mut Vec<T>, prereq: Option<T>) {
x: syntax::Expression, if let Some(item) = prereq {
) -> (ir::Expression, Vec<ir::Statement>, Vec<Constraint>, Type) { left.push(item)
let mut constraints = Vec::new();
let renames = HashMap::new();
let mut bindings = HashMap::new();
let (stmts, expr, ty) = convert_expression(x, &mut constraints, &renames, &mut bindings);
(expr, stmts, constraints, ty)
}
fn infer_statement(x: syntax::Statement) -> (Vec<ir::Statement>, Vec<Constraint>) {
let mut constraints = Vec::new();
let mut renames = HashMap::new();
let mut bindings = HashMap::new();
let res = convert_statement(x, &mut constraints, &mut renames, &mut bindings);
(res, constraints)
}
#[test]
fn constant_one() {
let (expr, stmts, constraints, ty) = infer_expression(one());
assert!(stmts.is_empty());
assert!(matches!(
expr,
ir::Expression::Atomic(ir::ValueOrRef::Value(_, _, ir::Value::Unknown(None, 1)))
));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == &ty)
));
}
#[test]
fn one_plus_one() {
let opo = syntax::Expression::Primitive(
Location::manufactured(),
"+".to_string(),
vec![one(), one()],
);
let (expr, stmts, constraints, ty) = infer_expression(opo);
assert!(stmts.is_empty());
assert!(matches!(expr, ir::Expression::Primitive(_, t, ir::Primitive::Plus, _) if t == ty));
assert!(vec_contains(&constraints, |x| matches!(
x,
Constraint::FitsInNumType(_, _, 1)
)));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t != &ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ProperPrimitiveArgs(_, ir::Primitive::Plus, args, ret) if args.len() == 2 && ret == &ty)
));
}
#[test]
fn one_plus_one_plus_one() {
let stmt = syntax::Statement::parse(1, "x = 1 + 1 + 1;").expect("basic parse");
let (stmts, constraints) = infer_statement(stmt);
assert_eq!(stmts.len(), 2);
let ir::Statement::Binding(
_args,
name1,
temp_ty1,
ir::Expression::Primitive(_, primty1, ir::Primitive::Plus, primargs1),
) = stmts.get(0).expect("item two")
else {
panic!("Failed to match first statement");
};
let ir::Statement::Binding(
_args,
name2,
temp_ty2,
ir::Expression::Primitive(_, primty2, ir::Primitive::Plus, primargs2),
) = stmts.get(1).expect("item two")
else {
panic!("Failed to match second statement");
};
let &[ir::ValueOrRef::Value(_, ref left1ty, _), ir::ValueOrRef::Value(_, ref right1ty, _)] =
&primargs1[..]
else {
panic!("Failed to match first arguments");
};
let &[ir::ValueOrRef::Ref(_, _, ref left2name), ir::ValueOrRef::Value(_, ref right2ty, _)] =
&primargs2[..]
else {
panic!("Failed to match first arguments");
};
assert_ne!(name1, name2);
assert_ne!(temp_ty1, temp_ty2);
assert_ne!(primty1, primty2);
assert_eq!(name1, left2name);
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == left1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right1ty)
));
assert!(vec_contains(
&constraints,
|x| matches!(x, Constraint::ConstantNumericType(_, t) if t == right2ty)
));
for (i, s) in stmts.iter().enumerate() {
println!("{}: {:?}", i, s);
}
for (i, c) in constraints.iter().enumerate() {
println!("{}: {:?}", i, c);
}
} }
} }

145
src/type_infer/error.rs Normal file
View File

@@ -0,0 +1,145 @@
use super::constraint::Constraint;
use crate::eval::PrimitiveType;
use crate::ir::{Primitive, TypeOrVar};
use crate::syntax::{Location, Name};
use codespan_reporting::diagnostic::Diagnostic;
/// The various kinds of errors that can occur while doing type inference.
///
/// Every variant carries the source `Location` it should be reported at
/// (directly, or inside the unsolved `Constraint`), and can be turned into a
/// `Diagnostic<usize>` for display.
pub enum TypeInferenceError {
/// The user provided a constant that is too large for its inferred type.
/// Fields: location, the inferred type, and the offending constant.
ConstantTooLarge(Location, PrimitiveType, u64),
/// Somehow we're trying to use a non-number as a number
NotANumber(Location, PrimitiveType),
/// The two types needed to be equivalent, but weren't.
/// Reported as "expected" (first type) versus "received" (second type).
NotEquivalent(Location, TypeOrVar, TypeOrVar),
/// We cannot safely cast the first type to the second type.
CannotSafelyCast(Location, PrimitiveType, PrimitiveType),
/// The primitive invocation provided the wrong number of arguments.
/// Fields: location, the primitive, the lower and upper bounds on the
/// accepted argument count, and the number of arguments observed.
WrongPrimitiveArity(Location, Primitive, usize, usize, usize),
/// We cannot cast between the two types, for any number of reasons
CannotCast(Location, TypeOrVar, TypeOrVar),
/// We cannot turn a number into a function.
/// Fields: location, the function type, and the constant (if known).
CannotMakeNumberAFunction(Location, TypeOrVar, Option<u64>),
/// We cannot turn a number into a Structure.
/// Fields: location, the structure type, and the constant (if known).
CannotMakeNumberAStructure(Location, TypeOrVar, Option<u64>),
/// We had a constraint we just couldn't solve.
CouldNotSolve(Constraint),
/// Functions are not printable.
FunctionsAreNotPrintable(Location),
/// The given type isn't signed, and can't be negated
IsNotSigned(Location, TypeOrVar),
/// The given type doesn't have the given field.
NoFieldForType(Location, Name, TypeOrVar),
/// There is no type with the given name.
UnknownTypeName(Location, Name),
}
impl From<TypeInferenceError> for Diagnostic<usize> {
fn from(value: TypeInferenceError) -> Self {
match value {
TypeInferenceError::ConstantTooLarge(loc, primty, value) => loc
.labelled_error("constant too large for type")
.with_message(format!(
"Type {} has a max value of {}, which is smaller than {}",
primty,
primty.max_value().expect("constant type has max value"),
value
)),
TypeInferenceError::NotANumber(loc, primty) => loc
.labelled_error("not a numeric type")
.with_message(format!(
"For some reason, we're trying to use {} as a numeric type",
primty,
)),
TypeInferenceError::NotEquivalent(loc, ty1, ty2) => loc
.labelled_error("type inference error")
.with_message(format!("Expected type {}, received type {}", ty1, ty2)),
TypeInferenceError::CannotSafelyCast(loc, ty1, ty2) => loc
.labelled_error("unsafe type cast")
.with_message(format!("Cannot safely cast {} to {}", ty1, ty2)),
TypeInferenceError::WrongPrimitiveArity(loc, prim, lower, upper, observed) => loc
.labelled_error("wrong number of arguments")
.with_message(format!(
"expected {} for {}, received {}",
if lower == upper && lower > 1 {
format!("{} arguments", lower)
} else if lower == upper {
format!("{} argument", lower)
} else {
format!("{}-{} arguments", lower, upper)
},
prim,
observed
)),
TypeInferenceError::CannotCast(loc, t1, t2) => loc
.labelled_error("cannot cast between types")
.with_message(format!(
"tried to cast from {} to {}",
t1, t2,
)),
TypeInferenceError::CannotMakeNumberAFunction(loc, t, val) => loc
.labelled_error(if let Some(val) = val {
format!("cannot turn {} into a function", val)
} else {
"cannot use a constant as a function type".to_string()
})
.with_message(format!("function type was {}", t)),
TypeInferenceError::CannotMakeNumberAStructure(loc, t, val) => loc
.labelled_error(if let Some(val) = val {
format!("cannot turn {} into a function", val)
} else {
"cannot use a constant as a function type".to_string()
})
.with_message(format!("function type was {}", t)),
TypeInferenceError::FunctionsAreNotPrintable(loc) => loc
.labelled_error("cannot print function values"),
TypeInferenceError::IsNotSigned(loc, pt) => loc
.labelled_error(format!("type {} is not signed", pt))
.with_message("and so it cannot be negated"),
TypeInferenceError::NoFieldForType(loc, field, t) => loc
.labelled_error(format!("no field {} available for type {}", field, t)),
TypeInferenceError::UnknownTypeName(loc , name) => loc
.labelled_error(format!("unknown type named {}", name)),
TypeInferenceError::CouldNotSolve(Constraint::CanCastTo(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if it was safe to cast from {} to {}",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::TypeHasField(loc, a, field, _)) => {
loc.labelled_error("internal error")
.with_message(format!("fould not determine if type {} has field {}", a, field))
}
TypeInferenceError::CouldNotSolve(Constraint::Equivalent(loc, a, b)) => {
loc.labelled_error("internal error").with_message(format!(
"could not determine if {} and {} were equivalent",
a, b
))
}
TypeInferenceError::CouldNotSolve(Constraint::FitsInNumType(loc, ty, val)) => {
loc.labelled_error("internal error").with_message(format!(
"Could not determine if {} could fit in {}",
val, ty
))
}
TypeInferenceError::CouldNotSolve(Constraint::NumericType(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if {} was a numeric type", ty)),
TypeInferenceError::CouldNotSolve(Constraint::ConstantNumericType(loc, ty)) =>
panic!("What? Constants should always eventually be solved, even by default; {:?} and type {:?}", loc, ty),
TypeInferenceError::CouldNotSolve(Constraint::Printable(loc, ty)) => loc
.labelled_error("internal error")
.with_message(format!("Could not determine if type {} was printable", ty)),
TypeInferenceError::CouldNotSolve(Constraint::IsSomething(loc, _)) => {
loc.labelled_error("could not infer type")
.with_message("Could not find *any* type information; is this an unused function argument?")
}
TypeInferenceError::CouldNotSolve(Constraint::IsSigned(loc, t)) => loc
.labelled_error("internal error")
.with_message(format!("could not infer that type {} was signed", t)),
TypeInferenceError::CouldNotSolve(Constraint::NamedTypeIs(loc, name, ty)) => loc
.labelled_error("internal error")
.with_message(format!("could not infer that the name {} refers to {}", name, ty)),
}
}
}

View File

@@ -1,185 +1,238 @@
use super::{ast as input, solve::TypeResolutions}; use crate::eval::PrimitiveType;
use crate::{eval::PrimitiveType, ir as output}; use crate::ir::{Expression, FunctionDefinition, Program, Type, TypeOrVar, Value, ValueOrRef};
use crate::syntax::Name;
use std::collections::HashMap;
pub fn finalize_program( pub type TypeResolutions = HashMap<Name, Type>;
mut program: input::Program,
resolutions: &TypeResolutions, impl super::InferenceEngine {
) -> output::Program { pub fn finalize_program(self, resolutions: TypeResolutions) -> Program<Type> {
output::Program { // we can't do this in place without some type nonsense, so we're going to
statements: program // create a brand new set of program arguments and then construct the new
.statements // `Program` from them.
.drain(..) let mut functions = HashMap::new();
.map(|x| finalize_statement(x, resolutions)) let mut type_definitions = HashMap::new();
.collect(),
} // this is handy for debugging
for (name, ty) in resolutions.iter() {
tracing::debug!(name = %name, resolved_type = %ty, "resolved type variable");
} }
fn finalize_statement( // copy over the type definitions
statement: input::Statement, for (name, def) in self.type_definitions.into_iter() {
resolutions: &TypeResolutions, type_definitions.insert(name, finalize_type(def, &resolutions));
) -> output::Statement { }
match statement {
input::Statement::Binding(loc, var, ty, expr) => output::Statement::Binding( // now copy over the functions
loc, for (name, function_def) in self.functions.into_iter() {
var, assert_eq!(name, function_def.name);
finalize_type(ty, resolutions),
finalize_expression(expr, resolutions), let body = finalize_expression(function_def.body, &resolutions);
), let arguments = function_def
input::Statement::Print(loc, ty, var) => { .arguments
output::Statement::Print(loc, finalize_type(ty, resolutions), var) .into_iter()
.map(|(name, t)| (name, finalize_type(t, &resolutions)))
.collect();
functions.insert(
name,
FunctionDefinition {
name: function_def.name,
arguments,
return_type: body.type_of(),
body,
},
);
}
// and now we can finally compute the new body
let body = finalize_expression(self.body, &resolutions);
Program {
functions,
type_definitions,
body,
} }
} }
} }
fn finalize_expression( fn finalize_expression(
expression: input::Expression, expression: Expression<TypeOrVar>,
resolutions: &TypeResolutions, resolutions: &TypeResolutions,
) -> output::Expression { ) -> Expression<Type> {
match expression { match expression {
input::Expression::Atomic(val_or_ref) => { Expression::Atomic(val_or_ref) => {
output::Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions)) Expression::Atomic(finalize_val_or_ref(val_or_ref, resolutions))
} }
input::Expression::Cast(loc, target, val_or_ref) => output::Expression::Cast(
Expression::Cast(loc, target, val_or_ref) => Expression::Cast(
loc, loc,
finalize_type(target, resolutions), finalize_type(target, resolutions),
finalize_val_or_ref(val_or_ref, resolutions), finalize_val_or_ref(val_or_ref, resolutions),
), ),
input::Expression::Primitive(loc, ty, prim, mut args) => output::Expression::Primitive(
Expression::Construct(loc, ty, name, fields) => Expression::Construct(
loc, loc,
finalize_type(ty, resolutions), finalize_type(ty, resolutions),
prim, name,
args.drain(..) fields
.into_iter()
.map(|(k, v)| (k, finalize_val_or_ref(v, resolutions)))
.collect(),
),
Expression::FieldRef(loc, ty, struct_type, valref, field) => Expression::FieldRef(
loc,
finalize_type(ty, resolutions),
finalize_type(struct_type, resolutions),
finalize_val_or_ref(valref, resolutions),
field,
),
Expression::Block(loc, ty, exprs) => {
let mut final_exprs = Vec::with_capacity(exprs.len());
for expr in exprs {
let newexpr = finalize_expression(expr, resolutions);
if let Expression::Block(_, _, mut subexprs) = newexpr {
final_exprs.append(&mut subexprs);
} else {
final_exprs.push(newexpr);
}
}
Expression::Block(loc, finalize_type(ty, resolutions), final_exprs)
}
Expression::Call(loc, ty, fun, args) => Expression::Call(
loc,
finalize_type(ty, resolutions),
finalize_val_or_ref(fun, resolutions),
args.into_iter()
.map(|x| finalize_val_or_ref(x, resolutions)) .map(|x| finalize_val_or_ref(x, resolutions))
.collect(), .collect(),
), ),
Expression::Bind(loc, var, ty, subexp) => Expression::Bind(
loc,
var,
finalize_type(ty, resolutions),
Box::new(finalize_expression(*subexp, resolutions)),
),
} }
} }
fn finalize_type(ty: input::Type, resolutions: &TypeResolutions) -> output::Type { fn finalize_type(ty: TypeOrVar, resolutions: &TypeResolutions) -> Type {
match ty { match ty {
input::Type::Primitive(x) => output::Type::Primitive(x), TypeOrVar::Primitive(x) => Type::Primitive(x),
input::Type::Variable(_, tvar) => match resolutions.get(&tvar) { TypeOrVar::Variable(_, tvar) => match resolutions.get(&tvar) {
None => panic!("Did not resolve type for type variable {}", tvar), None => panic!("Did not resolve type for type variable {}", tvar),
Some(pt) => output::Type::Primitive(*pt), Some(pt) => {
tracing::trace!(type_variable = %tvar, final_type = %pt, "finalizing variable type");
pt.clone()
}
}, },
TypeOrVar::Function(mut args, ret) => Type::Function(
args.drain(..)
.map(|x| finalize_type(x, resolutions))
.collect(),
Box::new(finalize_type(*ret, resolutions)),
),
TypeOrVar::Structure(fields) => {
Type::Structure(fields.map(|subtype| finalize_type(subtype, resolutions)))
}
} }
} }
fn finalize_val_or_ref( fn finalize_val_or_ref(
valref: input::ValueOrRef, valref: ValueOrRef<TypeOrVar>,
resolutions: &TypeResolutions, resolutions: &TypeResolutions,
) -> output::ValueOrRef { ) -> ValueOrRef<Type> {
match valref { match valref {
input::ValueOrRef::Ref(loc, ty, var) => { ValueOrRef::Ref(loc, ty, var) => ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var),
output::ValueOrRef::Ref(loc, finalize_type(ty, resolutions), var) ValueOrRef::Primitive(loc, ty, prim) => {
ValueOrRef::Primitive(loc, finalize_type(ty, resolutions), prim)
} }
input::ValueOrRef::Value(loc, ty, val) => { ValueOrRef::Value(loc, ty, val) => {
let new_type = finalize_type(ty, resolutions); let new_type = finalize_type(ty, resolutions);
match val { match val {
input::Value::Unknown(base, value) => match new_type { // U64 is essentially "unknown" for us, so we use the inferred type
output::Type::Primitive(PrimitiveType::U8) => output::ValueOrRef::Value( Value::U64(base, value) => match new_type {
loc, Type::Function(_, _) => {
new_type, panic!("Somehow inferred that a constant was a function")
output::Value::U8(base, value as u8), }
), Type::Structure(_) => {
output::Type::Primitive(PrimitiveType::U16) => output::ValueOrRef::Value( panic!("Somehow inferred that a constant was a structure")
loc, }
new_type, Type::Primitive(PrimitiveType::Void) => {
output::Value::U16(base, value as u16), panic!("Somehow inferred that a constant was void")
), }
output::Type::Primitive(PrimitiveType::U32) => output::ValueOrRef::Value( Type::Primitive(PrimitiveType::U8) => {
loc, ValueOrRef::Value(loc, new_type, Value::U8(base, value as u8))
new_type, }
output::Value::U32(base, value as u32), Type::Primitive(PrimitiveType::U16) => {
), ValueOrRef::Value(loc, new_type, Value::U16(base, value as u16))
output::Type::Primitive(PrimitiveType::U64) => { }
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value)) Type::Primitive(PrimitiveType::U32) => {
ValueOrRef::Value(loc, new_type, Value::U32(base, value as u32))
}
Type::Primitive(PrimitiveType::U64) => {
ValueOrRef::Value(loc, new_type, Value::U64(base, value))
}
Type::Primitive(PrimitiveType::I8) => {
ValueOrRef::Value(loc, new_type, Value::I8(base, value as i8))
}
Type::Primitive(PrimitiveType::I16) => {
ValueOrRef::Value(loc, new_type, Value::I16(base, value as i16))
}
Type::Primitive(PrimitiveType::I32) => {
ValueOrRef::Value(loc, new_type, Value::I32(base, value as i32))
}
Type::Primitive(PrimitiveType::I64) => {
ValueOrRef::Value(loc, new_type, Value::I64(base, value as i64))
} }
output::Type::Primitive(PrimitiveType::I8) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I8(base, value as i8),
),
output::Type::Primitive(PrimitiveType::I16) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I16(base, value as i16),
),
output::Type::Primitive(PrimitiveType::I32) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I32(base, value as i32),
),
output::Type::Primitive(PrimitiveType::I64) => output::ValueOrRef::Value(
loc,
new_type,
output::Value::I64(base, value as i64),
),
}, },
input::Value::U8(base, value) => { Value::U8(base, value) => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::U8)));
new_type, ValueOrRef::Value(loc, new_type, Value::U8(base, value))
output::Type::Primitive(PrimitiveType::U8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U8(base, value))
} }
input::Value::U16(base, value) => { Value::U16(base, value) => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::U16)));
new_type, ValueOrRef::Value(loc, new_type, Value::U16(base, value))
output::Type::Primitive(PrimitiveType::U16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U16(base, value))
} }
input::Value::U32(base, value) => { Value::U32(base, value) => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::U32)));
new_type, ValueOrRef::Value(loc, new_type, Value::U32(base, value))
output::Type::Primitive(PrimitiveType::U32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U32(base, value))
} }
input::Value::U64(base, value) => { Value::I8(base, value) => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::I8)));
new_type, ValueOrRef::Value(loc, new_type, Value::I8(base, value))
output::Type::Primitive(PrimitiveType::U64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::U64(base, value))
} }
input::Value::I8(base, value) => { Value::I16(base, value) => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::I16)));
new_type, ValueOrRef::Value(loc, new_type, Value::I16(base, value))
output::Type::Primitive(PrimitiveType::I8)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I8(base, value))
} }
input::Value::I16(base, value) => { Value::I32(base, value) => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::I32)));
new_type, ValueOrRef::Value(loc, new_type, Value::I32(base, value))
output::Type::Primitive(PrimitiveType::I16)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I16(base, value))
} }
input::Value::I32(base, value) => { Value::I64(base, value) => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::I64)));
new_type, ValueOrRef::Value(loc, new_type, Value::I64(base, value))
output::Type::Primitive(PrimitiveType::I32)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I32(base, value))
} }
input::Value::I64(base, value) => { Value::Void => {
assert!(matches!( assert!(matches!(new_type, Type::Primitive(PrimitiveType::Void)));
new_type, ValueOrRef::Value(loc, new_type, Value::Void)
output::Type::Primitive(PrimitiveType::I64)
));
output::ValueOrRef::Value(loc, new_type, output::Value::I64(base, value))
} }
} }
} }

50
src/type_infer/result.rs Normal file
View File

@@ -0,0 +1,50 @@
use super::error::TypeInferenceError;
use super::warning::TypeInferenceWarning;
/// The results of type inference; like [`Result`], but with a bit more information.
///
/// This result is parameterized, because sometimes it's handy to return slightly
/// different things; there's a [`TypeInferenceResult::map`] function for performing
/// those sorts of conversions.
///
/// The type parameter is called `R` (rather than `Result`) so that it does not
/// shadow the standard library's `Result` inside this definition, and so that it
/// matches the parameter name used by the inherent `impl`.
pub enum TypeInferenceResult<R> {
    /// Inference succeeded, producing `result`, possibly with some warnings.
    Success {
        result: R,
        warnings: Vec<TypeInferenceWarning>,
    },
    /// Inference failed with the given errors; warnings gathered along the
    /// way are kept so that they can still be reported.
    Failure {
        errors: Vec<TypeInferenceError>,
        warnings: Vec<TypeInferenceWarning>,
    },
}
impl<R> TypeInferenceResult<R> {
    /// Transform the payload of a successful inference with `f`.
    ///
    /// Warnings are carried over untouched, and a failure passes through
    /// unchanged (with `f` never being called). This is the moral equivalent
    /// of [`Result::map`], but for type inference results.
    pub fn map<U, F>(self, f: F) -> TypeInferenceResult<U>
    where
        F: FnOnce(R) -> U,
    {
        match self {
            Self::Failure { errors, warnings } => {
                TypeInferenceResult::Failure { errors, warnings }
            }
            Self::Success { result, warnings } => {
                let result = f(result);
                TypeInferenceResult::Success { result, warnings }
            }
        }
    }

    /// Return the final result, discarding any warnings.
    ///
    /// # Panics
    ///
    /// Panics, including `msg` in the panic message, if inference failed.
    pub fn expect(self, msg: &str) -> R {
        let Self::Success { result, .. } = self else {
            panic!("tried to get value from failed type inference: {}", msg)
        };
        result
    }
}

File diff suppressed because it is too large Load Diff

21
src/type_infer/warning.rs Normal file
View File

@@ -0,0 +1,21 @@
use crate::ir::TypeOrVar;
use crate::syntax::Location;
use codespan_reporting::diagnostic::Diagnostic;
/// Warnings that we might want to tell the user about.
///
/// These are fine, probably, but could indicate some behavior the user might not
/// expect, and so they might want to do something about them.
pub enum TypeInferenceWarning {
/// A type was unknown and was defaulted to the given type; the location is
/// where the warning gets reported.
DefaultedTo(Location, TypeOrVar),
}
/// Render a type inference warning as a warning-level diagnostic, with a
/// primary label at the warning's location.
impl From<TypeInferenceWarning> for Diagnostic<usize> {
    fn from(value: TypeInferenceWarning) -> Self {
        match value {
            TypeInferenceWarning::DefaultedTo(loc, ty) => {
                let label = loc.primary_label().with_message("unknown type");
                let message = format!("Defaulted unknown type to {}", ty);
                Diagnostic::warning()
                    .with_labels(vec![label])
                    .with_message(message)
            }
        }
    }
}

4
src/util.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod pretty;
pub mod scoped_map;
pub mod warning_result;
pub mod weighted_map;

35
src/util/pretty.rs Normal file
View File

@@ -0,0 +1,35 @@
use pretty::Arena;
/// The `pretty` arena type used for pretty-printing in this crate, with `()`
/// as its second type argument.
pub type Allocator<'a> = Arena<'a, ()>;
/// Build the pretty-printer document for a function type.
///
/// Takes three identifiers: the allocator, a collection of argument types
/// (anything with `.iter()` whose items have a `.pretty(allocator)` method),
/// and the return type (also with `.pretty(allocator)`). The arguments are
/// separated by `,`, wrapped in parentheses, and followed by ` -> ` and the
/// return type, e.g. `(a,b) -> r`.
macro_rules! pretty_function_type {
($allocator: ident, $args: ident, $rettype: ident) => {
$allocator
.intersperse(
$args.iter().map(|x| x.pretty($allocator)),
$allocator.text(","),
)
.parens()
.append($allocator.space())
.append($allocator.text("->"))
.append($allocator.space())
.append($rettype.pretty($allocator))
};
}
/// Implement `std::fmt::Display` for a type in terms of its `pretty` method.
///
/// The generated `fmt` creates a fresh `pretty::Arena`, asks the value for
/// its document via `self.pretty(&arena)`, and renders that document into
/// the formatter.
macro_rules! derived_display {
($type: ty) => {
impl std::fmt::Display for $type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let arena = pretty::Arena::new();
let doc = self.pretty(&arena);
// 732 is the width handed to `render_fmt`.
// NOTE(review): the reason for this particular value is not visible here.
doc.render_fmt(732, f)
}
}
};
}
// this is a dumb Rust trick to export the functions to the rest
// of the crate, but not globally.
pub(crate) use derived_display;
pub(crate) use pretty_function_type;

134
src/util/scoped_map.rs Normal file
View File

@@ -0,0 +1,134 @@
use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::Hash;
/// A version of [`std::collections::HashMap`] with a built-in notion of scope.
///
/// Internally this is a stack of hash maps. Lookups search from the most
/// recently created scope back to the oldest, so inner bindings shadow outer
/// ones; inserts always go into the most recently created scope.
#[derive(Clone)]
pub struct ScopedMap<K: Eq + Hash, V> {
    scopes: Vec<HashMap<K, V>>,
}

impl<K: Eq + Hash, V> Default for ScopedMap<K, V> {
    fn default() -> Self {
        ScopedMap::new()
    }
}

impl<K: Eq + Hash, V> ScopedMap<K, V> {
    /// Generate a new scoped map.
    ///
    /// In addition to generating the map structure, this method also generates
    /// an initial scope for use by the caller.
    pub fn new() -> ScopedMap<K, V> {
        ScopedMap {
            scopes: vec![HashMap::new()],
        }
    }

    /// Get a value from the scoped map.
    ///
    /// The innermost (most recently created) scope that binds the key wins.
    /// The key may be any borrowed form of `K`, as with [`HashMap::get`].
    pub fn get<Q>(&self, k: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.scopes.iter().rev().find_map(|scope| scope.get(k))
    }

    /// Returns true if the map contains the given key in any scope.
    ///
    /// Like [`ScopedMap::get`], the key may be any borrowed form of `K`.
    pub fn contains_key<Q>(&self, k: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.scopes.iter().any(|scope| scope.contains_key(k))
    }

    /// Insert a value into the current binding scope.
    ///
    /// If this variable is bound in the current scope, then its value will be
    /// overridden. If it's bound in a previous scope, however, that value will
    /// be shadowed, so that its value will be preserved if/when the current
    /// scope is popped.
    ///
    /// # Panics
    ///
    /// Panics if every scope, including the initial one, has been released.
    pub fn insert(&mut self, k: K, v: V) {
        self.scopes
            .last_mut()
            .expect("tried to insert into ScopedMap with no scopes")
            .insert(k, v);
    }

    /// Create a new scope.
    ///
    /// Modifications to this scope will shadow all previous scopes without
    /// modifying them. Consider the following examples:
    ///
    /// ```
    /// use ngr::util::scoped_map::ScopedMap;
    ///
    /// let mut example1 = ScopedMap::new();
    /// example1.insert(1, true);
    /// example1.insert(1, false);
    /// assert_eq!(Some(&false), example1.get(&1));
    /// let mut example2 = ScopedMap::new();
    /// example2.insert(1, true);
    /// example2.new_scope();
    /// example2.insert(1, false);
    /// assert_eq!(Some(&false), example2.get(&1));
    /// example2.release_scope().expect("scope releases");
    /// assert_eq!(Some(&true), example2.get(&1));
    /// ```
    pub fn new_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }

    /// Pop the current scope, returning its bindings and reverting to
    /// whatever was bound in the previous scope.
    ///
    /// `None` is returned only when there is no scope left to pop. Note that
    /// this method will happily pop the initial scope created by
    /// [`ScopedMap::new`]; after that, [`ScopedMap::insert`] panics until
    /// [`ScopedMap::new_scope`] is called again.
    pub fn release_scope(&mut self) -> Option<HashMap<K, V>> {
        self.scopes.pop()
    }

    /// Create a new scoped map by mapping over the values of this one.
    ///
    /// The scope structure and the keys are preserved. The function may carry
    /// mutable state (`FnMut`); the order in which values are visited within a
    /// scope is unspecified.
    pub fn map_values<F, W>(self, mut f: F) -> ScopedMap<K, W>
    where
        F: FnMut(V) -> W,
    {
        let scopes = self
            .scopes
            .into_iter()
            .map(|scope| {
                scope
                    .into_iter()
                    .map(|(k, v)| (k, f(v)))
                    .collect::<HashMap<K, W>>()
            })
            .collect();
        ScopedMap { scopes }
    }

    /// Returns true if this map is completely empty, at every level of
    /// scope.
    pub fn is_empty(&self) -> bool {
        self.scopes.iter().all(|scope| scope.is_empty())
    }
}

impl<K: Clone + Eq + Hash, V: Clone> ScopedMap<K, V> {
    /// Returns the set of all variables bound at this time, with shadowed
    /// variables hidden.
    pub fn bindings(&self) -> HashMap<K, V> {
        let mut result = HashMap::new();
        // Walk from the innermost scope outwards, so the first binding seen
        // for a key is the visible one and outer (shadowed) ones are skipped.
        for scope in self.scopes.iter().rev() {
            for (key, value) in scope.iter() {
                if !result.contains_key(key) {
                    result.insert(key.clone(), value.clone());
                }
            }
        }
        result
    }
}

155
src/util/warning_result.rs Normal file
View File

@@ -0,0 +1,155 @@
use codespan_reporting::diagnostic::Diagnostic;
/// This type is like `Result`, except that both the Ok case and Err case
/// can also include warning data.
///
/// Unfortunately, this type cannot be used with `?`, and so should probably
/// only be used when it's really, really handy to be able to carry this
/// sort of information. But when it is handy (for example, in type checking
/// and in early validation), it's really useful.
///
/// Type parameters: `T` is the success value, `W` is the warning type and
/// `E` is the error type.
pub enum WarningResult<T, W, E> {
/// Success: the computed value, plus any warnings gathered on the way.
Ok(T, Vec<W>),
/// Failure: the errors that were reported, plus any warnings gathered on
/// the way.
Err(Vec<E>, Vec<W>),
}
impl<T, W, E> WarningResult<T, W, E> {
pub fn ok(value: T) -> Self {
WarningResult::Ok(value, vec![])
}
pub fn err(error: E) -> Self {
WarningResult::Err(vec![error], vec![])
}
pub fn is_ok(&self) -> bool {
matches!(self, WarningResult::Ok(_, _))
}
pub fn is_err(&self) -> bool {
matches!(self, WarningResult::Err(_, _))
}
pub fn warnings(&self) -> &[W] {
match self {
WarningResult::Ok(_, warns) => warns.as_slice(),
WarningResult::Err(_, warns) => warns.as_slice(),
}
}
pub fn into_result(self) -> Option<T> {
match self {
WarningResult::Ok(v, _) => Some(v),
WarningResult::Err(_, _) => None,
}
}
pub fn into_errors(self) -> Option<Vec<E>> {
match self {
WarningResult::Ok(_, _) => None,
WarningResult::Err(errs, _) => Some(errs),
}
}
pub fn add_warning(&mut self, warning: W) {
match self {
WarningResult::Ok(_, warns) => warns.push(warning),
WarningResult::Err(_, warns) => warns.push(warning),
}
}
pub fn add_error(&mut self, error: E) {
match self {
WarningResult::Ok(_, warns) => {
*self = WarningResult::Err(vec![error], std::mem::take(warns))
}
WarningResult::Err(errs, _) => errs.push(error),
}
}
pub fn modify<F>(&mut self, f: F)
where
F: FnOnce(&mut T),
{
if let WarningResult::Ok(v, _) = self {
f(v);
}
}
pub fn map<F, R>(self, f: F) -> WarningResult<R, W, E>
where
F: FnOnce(T) -> R,
{
match self {
WarningResult::Ok(v, ws) => WarningResult::Ok(f(v), ws),
WarningResult::Err(e, ws) => WarningResult::Err(e, ws),
}
}
/// Merges two results together using the given function to combine `Ok`
/// results into a single value.
///
pub fn merge_with<F, O>(mut self, other: WarningResult<O, W, E>, f: F) -> WarningResult<T, W, E>
where
F: FnOnce(T, O) -> Result<T, E>,
{
match self {
WarningResult::Err(ref mut errors1, ref mut warns1) => match other {
WarningResult::Err(mut errors2, mut warns2) => {
errors1.append(&mut errors2);
warns1.append(&mut warns2);
self
}
WarningResult::Ok(_, mut ws) => {
warns1.append(&mut ws);
self
}
},
WarningResult::Ok(value1, mut warns1) => match other {
WarningResult::Err(errors, mut warns2) => {
warns2.append(&mut warns1);
WarningResult::Err(errors, warns2)
}
WarningResult::Ok(value2, mut warns2) => {
warns1.append(&mut warns2);
match f(value1, value2) {
Ok(final_value) => WarningResult::Ok(final_value, warns1),
Err(e) => WarningResult::Err(vec![e], warns1),
}
}
},
}
}
}
impl<T, W, E> WarningResult<T, W, E>
where
    W: Into<Diagnostic<usize>>,
    E: Into<Diagnostic<usize>>,
{
    /// Drains every diagnostic out of this result and returns them as an
    /// iterator: all errors first (if any), followed by all warnings.
    ///
    /// The errors and warnings are moved out, so afterwards both lists are
    /// empty and a second call yields nothing. The variant itself does not
    /// change: an `Err` stays an `Err`, just with no errors recorded.
    pub fn diagnostics(&mut self) -> impl Iterator<Item = Diagnostic<usize>> {
        // Pull both lists out up front so that a single iterator type can be
        // built for either variant; an `Ok` simply contributes no errors.
        let (errors, warnings) = match self {
            WarningResult::Err(errors, warnings) => {
                (std::mem::take(errors), std::mem::take(warnings))
            }
            WarningResult::Ok(_, warnings) => (Vec::new(), std::mem::take(warnings)),
        };
        errors
            .into_iter()
            .map(Into::into)
            .chain(warnings.into_iter().map(Into::into))
    }
}

21
src/util/weighted_map.rs Normal file
View File

@@ -0,0 +1,21 @@
use rand::distributions::{Distribution, WeightedIndex};
/// A list of items that can be sampled at random through `rand`'s
/// `WeightedIndex`, each draw returning a clone of the chosen item.
pub struct WeightedMap<T: Clone> {
/// Index built from the weights given to `new`; sampling it yields a
/// position into `items`.
index: WeightedIndex<usize>,
/// The items, cloned from the slice given to `new`, in the same order as
/// their weights.
items: Vec<T>,
}
impl<T: Clone> WeightedMap<T> {
    /// Builds a weighted map from `(weight, item)` pairs.
    ///
    /// The weights are fed to `WeightedIndex::new` in slice order and the
    /// items are cloned in the same order, so the position sampled from the
    /// index always identifies the item that was paired with that weight.
    ///
    /// # Panics
    ///
    /// Panics if `WeightedIndex::new` rejects the weights. According to the
    /// `rand` documentation this happens when `map` is empty or when every
    /// weight is zero.
    pub fn new(map: &[(usize, T)]) -> Self {
        let index = WeightedIndex::new(map.iter().map(|(weight, _)| *weight))
            .expect("WeightedMap needs at least one entry and a non-zero total weight");
        let items = map.iter().map(|(_, item)| item.clone()).collect();
        WeightedMap { index, items }
    }
}
impl<T: Clone> Distribution<T> for WeightedMap<T> {
    /// Samples a position from the weighted index and returns a clone of the
    /// item stored at that position.
    fn sample<R: rand::prelude::Rng + ?Sized>(&self, rng: &mut R) -> T {
        let position = self.index.sample(rng);
        // `index` and `items` are built from the same slice in `new`, so the
        // sampled position is expected to be in range.
        self.items.get(position).cloned().unwrap()
    }
}