compile guards in match forms

This commit is contained in:
Scott Richmond 2024-12-26 23:46:06 -05:00
parent cfe0b83192
commit f5965fdb44
2 changed files with 74 additions and 42 deletions

View File

@ -237,6 +237,15 @@ impl Compiler {
self.spans.push(self.span); self.spans.push(self.span);
} }
fn emit_byte(&mut self, byte: usize) {
self.chunk.bytecode.push(byte as u8);
self.spans.push(self.span);
}
fn len(&self) -> usize {
self.chunk.bytecode.len()
}
fn bind(&mut self, name: &'static str) { fn bind(&mut self, name: &'static str) {
self.bindings.push(Binding { self.bindings.push(Binding {
name, name,
@ -303,15 +312,15 @@ impl Compiler {
} }
If(cond, then, r#else) => { If(cond, then, r#else) => {
self.visit(cond); self.visit(cond);
let jif_idx = self.chunk.bytecode.len(); let jif_idx = self.len();
self.emit_op(Op::JumpIfFalse); self.emit_op(Op::JumpIfFalse);
self.chunk.bytecode.push(0xff); self.emit_byte(0xff);
self.visit(then); self.visit(then);
let jump_idx = self.chunk.bytecode.len(); let jump_idx = self.len();
self.emit_op(Op::Jump); self.emit_op(Op::Jump);
self.chunk.bytecode.push(0xff); self.emit_byte(0xff);
self.visit(r#else); self.visit(r#else);
let end_idx = self.chunk.bytecode.len(); let end_idx = self.len();
let jif_offset = jump_idx - jif_idx; let jif_offset = jump_idx - jif_idx;
let jump_offset = end_idx - jump_idx - 2; let jump_offset = end_idx - jump_idx - 2;
self.chunk.bytecode[jif_idx + 1] = jif_offset as u8; self.chunk.bytecode[jif_idx + 1] = jif_offset as u8;
@ -332,7 +341,7 @@ impl Compiler {
let biter = self.bindings.iter().enumerate().rev(); let biter = self.bindings.iter().enumerate().rev();
for (i, binding) in biter { for (i, binding) in biter {
if binding.name == *name { if binding.name == *name {
self.chunk.bytecode.push(i as u8); self.emit_byte(i);
break; break;
} }
} }
@ -384,14 +393,14 @@ impl Compiler {
self.visit(member); self.visit(member);
} }
self.emit_op(Op::PushTuple); self.emit_op(Op::PushTuple);
self.chunk.bytecode.push(members.len() as u8); self.emit_byte(members.len());
} }
List(members) => { List(members) => {
for member in members { for member in members {
self.visit(member); self.visit(member);
} }
self.emit_op(Op::PushList); self.emit_op(Op::PushList);
self.chunk.bytecode.push(members.len() as u8); self.emit_byte(members.len());
} }
LBox(name, expr) => { LBox(name, expr) => {
self.visit(expr); self.visit(expr);
@ -403,7 +412,7 @@ impl Compiler {
self.visit(pair); self.visit(pair);
} }
self.emit_op(Op::PushDict); self.emit_op(Op::PushDict);
self.chunk.bytecode.push(pairs.len() as u8); self.emit_byte(pairs.len());
} }
Pair(key, value) => { Pair(key, value) => {
let existing_kw = self.chunk.keywords.iter().position(|kw| kw == key); let existing_kw = self.chunk.keywords.iter().position(|kw| kw == key);
@ -445,18 +454,17 @@ impl Compiler {
while let Some((WhenClause(cond, body), _)) = clauses.next() { while let Some((WhenClause(cond, body), _)) = clauses.next() {
self.visit(cond.as_ref()); self.visit(cond.as_ref());
self.emit_op(Op::JumpIfFalse); self.emit_op(Op::JumpIfFalse);
let jif_jump_idx = self.chunk.bytecode.len(); let jif_jump_idx = self.len();
self.chunk.bytecode.push(0xff); self.emit_byte(0xff);
self.visit(body); self.visit(body);
self.emit_op(Op::Jump); self.emit_op(Op::Jump);
jump_idxes.push(self.chunk.bytecode.len()); jump_idxes.push(self.len());
self.chunk.bytecode.push(0xff); self.emit_byte(0xff);
self.chunk.bytecode[jif_jump_idx] = self.chunk.bytecode[jif_jump_idx] = self.len() as u8 - jif_jump_idx as u8 - 1;
self.chunk.bytecode.len() as u8 - jif_jump_idx as u8 - 1;
} }
self.emit_op(Op::PanicNoWhen); self.emit_op(Op::PanicNoWhen);
for idx in jump_idxes { for idx in jump_idxes {
self.chunk.bytecode[idx] = self.chunk.bytecode.len() as u8 - idx as u8 + 1; self.chunk.bytecode[idx] = self.len() as u8 - idx as u8 + 1;
} }
} }
WhenClause(..) => unreachable!(), WhenClause(..) => unreachable!(),
@ -465,33 +473,57 @@ impl Compiler {
let mut jump_idxes = vec![]; let mut jump_idxes = vec![];
let mut clauses = clauses.iter(); let mut clauses = clauses.iter();
// TODO: add guard checking // TODO: add guard checking
while let Some((MatchClause(pattern, _, body), _)) = clauses.next() { while let Some((MatchClause(pattern, guard, body), _)) = clauses.next() {
self.scope_depth += 1; self.scope_depth += 1;
self.visit(pattern); self.visit(pattern);
self.emit_op(Op::JumpIfNoMatch); self.emit_op(Op::JumpIfNoMatch);
let jnm_jump_idx = self.chunk.bytecode.len(); let jnm_jump_idx = self.len();
self.chunk.bytecode.push(0xff); self.emit_byte(0xff);
self.visit(body); if let Some(expr) = guard.as_ref() {
self.emit_op(Op::Store); self.visit(expr);
self.scope_depth -= 1; self.emit_op(Op::JumpIfFalse);
while let Some(binding) = self.bindings.last() { let jif_idx = self.len();
if binding.depth > self.scope_depth { self.emit_byte(0xff);
self.emit_op(Op::Pop); self.visit(body);
self.bindings.pop(); self.emit_op(Op::Store);
} else { self.scope_depth -= 1;
break; while let Some(binding) = self.bindings.last() {
if binding.depth > self.scope_depth {
self.emit_op(Op::Pop);
self.bindings.pop();
} else {
break;
}
} }
self.emit_op(Op::Jump);
jump_idxes.push(self.len());
self.emit_byte(0xff);
self.chunk.bytecode[jnm_jump_idx] =
self.len() as u8 - jnm_jump_idx as u8 - 1;
self.chunk.bytecode[jif_idx] = self.len() as u8 - jif_idx as u8 - 1;
} else {
self.visit(body);
self.emit_op(Op::Store);
self.scope_depth -= 1;
while let Some(binding) = self.bindings.last() {
if binding.depth > self.scope_depth {
self.emit_op(Op::Pop);
self.bindings.pop();
} else {
break;
}
}
self.emit_op(Op::Jump);
jump_idxes.push(self.len());
self.emit_byte(0xff);
self.chunk.bytecode[jnm_jump_idx] =
self.len() as u8 - jnm_jump_idx as u8 - 1;
} }
self.emit_op(Op::Jump);
jump_idxes.push(self.chunk.bytecode.len());
self.chunk.bytecode.push(0xff);
self.chunk.bytecode[jnm_jump_idx] =
self.chunk.bytecode.len() as u8 - jnm_jump_idx as u8 - 1;
} }
self.emit_op(Op::PanicNoMatch); self.emit_op(Op::PanicNoMatch);
self.emit_op(Op::Load); self.emit_op(Op::Load);
for idx in jump_idxes { for idx in jump_idxes {
self.chunk.bytecode[idx] = self.chunk.bytecode.len() as u8 - idx as u8 + 2; self.chunk.bytecode[idx] = self.len() as u8 - idx as u8;
} }
} }
MatchClause(..) => unreachable!(), MatchClause(..) => unreachable!(),
@ -534,21 +566,21 @@ impl Compiler {
self.emit_op(Op::Truncate); self.emit_op(Op::Truncate);
// skip the decrement the first time // skip the decrement the first time
self.emit_op(Op::Jump); self.emit_op(Op::Jump);
self.chunk.bytecode.push(1); self.emit_byte(1);
// begin repeat // begin repeat
self.emit_op(Op::Decrement); self.emit_op(Op::Decrement);
let repeat_begin = self.chunk.bytecode.len(); let repeat_begin = self.len();
self.emit_op(Op::Duplicate); self.emit_op(Op::Duplicate);
self.emit_op(Op::JumpIfZero); self.emit_op(Op::JumpIfZero);
self.chunk.bytecode.push(0xff); self.emit_byte(0xff);
// compile the body // compile the body
self.visit(body); self.visit(body);
// pop whatever value the body returns // pop whatever value the body returns
self.emit_op(Op::Pop); self.emit_op(Op::Pop);
self.emit_op(Op::JumpBack); self.emit_op(Op::JumpBack);
// set jump points // set jump points
let repeat_end = self.chunk.bytecode.len(); let repeat_end = self.len();
self.chunk.bytecode.push((repeat_end - repeat_begin) as u8); self.emit_byte(repeat_end - repeat_begin);
self.chunk.bytecode[repeat_begin + 2] = (repeat_end - repeat_begin - 2) as u8; self.chunk.bytecode[repeat_begin + 2] = (repeat_end - repeat_begin - 2) as u8;
// pop the counter // pop the counter
self.emit_op(Op::Pop); self.emit_op(Op::Pop);

View File

@ -68,9 +68,9 @@ pub fn run(src: &'static str) {
pub fn main() { pub fn main() {
let src = " let src = "
let foo = 4 match :foo with {
repeat foo { :foo if true -> :oops
:foo :foo if true -> :yay!
} }
"; ";
run(src); run(src);