From f5965fdb44e448a51006625fa2f2b4766a94cecd Mon Sep 17 00:00:00 2001 From: Scott Richmond Date: Thu, 26 Dec 2024 23:46:06 -0500 Subject: [PATCH] compile guards in match forms --- src/compiler.rs | 110 +++++++++++++++++++++++++++++++----------------- src/main.rs | 6 +-- 2 files changed, 74 insertions(+), 42 deletions(-) diff --git a/src/compiler.rs b/src/compiler.rs index 72518d5..9558855 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -237,6 +237,15 @@ impl Compiler { self.spans.push(self.span); } + fn emit_byte(&mut self, byte: usize) { + self.chunk.bytecode.push(byte as u8); + self.spans.push(self.span); + } + + fn len(&self) -> usize { + self.chunk.bytecode.len() + } + fn bind(&mut self, name: &'static str) { self.bindings.push(Binding { name, @@ -303,15 +312,15 @@ impl Compiler { } If(cond, then, r#else) => { self.visit(cond); - let jif_idx = self.chunk.bytecode.len(); + let jif_idx = self.len(); self.emit_op(Op::JumpIfFalse); - self.chunk.bytecode.push(0xff); + self.emit_byte(0xff); self.visit(then); - let jump_idx = self.chunk.bytecode.len(); + let jump_idx = self.len(); self.emit_op(Op::Jump); - self.chunk.bytecode.push(0xff); + self.emit_byte(0xff); self.visit(r#else); - let end_idx = self.chunk.bytecode.len(); + let end_idx = self.len(); let jif_offset = jump_idx - jif_idx; let jump_offset = end_idx - jump_idx - 2; self.chunk.bytecode[jif_idx + 1] = jif_offset as u8; @@ -332,7 +341,7 @@ impl Compiler { let biter = self.bindings.iter().enumerate().rev(); for (i, binding) in biter { if binding.name == *name { - self.chunk.bytecode.push(i as u8); + self.emit_byte(i); break; } } @@ -384,14 +393,14 @@ impl Compiler { self.visit(member); } self.emit_op(Op::PushTuple); - self.chunk.bytecode.push(members.len() as u8); + self.emit_byte(members.len()); } List(members) => { for member in members { self.visit(member); } self.emit_op(Op::PushList); - self.chunk.bytecode.push(members.len() as u8); + self.emit_byte(members.len()); } LBox(name, expr) => { self.visit(expr); @@ -403,7 +412,7 @@ impl Compiler { self.visit(pair); } self.emit_op(Op::PushDict); - self.chunk.bytecode.push(pairs.len() as u8); + self.emit_byte(pairs.len()); } Pair(key, value) => { let existing_kw = self.chunk.keywords.iter().position(|kw| kw == key); @@ -445,18 +454,17 @@ impl Compiler { while let Some((WhenClause(cond, body), _)) = clauses.next() { self.visit(cond.as_ref()); self.emit_op(Op::JumpIfFalse); - let jif_jump_idx = self.chunk.bytecode.len(); - self.chunk.bytecode.push(0xff); + let jif_jump_idx = self.len(); + self.emit_byte(0xff); self.visit(body); self.emit_op(Op::Jump); - jump_idxes.push(self.chunk.bytecode.len()); - self.chunk.bytecode.push(0xff); - self.chunk.bytecode[jif_jump_idx] = - self.chunk.bytecode.len() as u8 - jif_jump_idx as u8 - 1; + jump_idxes.push(self.len()); + self.emit_byte(0xff); + self.chunk.bytecode[jif_jump_idx] = self.len() as u8 - jif_jump_idx as u8 - 1; } self.emit_op(Op::PanicNoWhen); for idx in jump_idxes { - self.chunk.bytecode[idx] = self.chunk.bytecode.len() as u8 - idx as u8 + 1; + self.chunk.bytecode[idx] = self.len() as u8 - idx as u8 + 1; } } WhenClause(..) => unreachable!(), @@ -465,33 +473,57 @@ impl Compiler { let mut jump_idxes = vec![]; let mut clauses = clauses.iter(); // TODO: add guard checking - while let Some((MatchClause(pattern, _, body), _)) = clauses.next() { + while let Some((MatchClause(pattern, guard, body), _)) = clauses.next() { self.scope_depth += 1; self.visit(pattern); self.emit_op(Op::JumpIfNoMatch); - let jnm_jump_idx = self.chunk.bytecode.len(); - self.chunk.bytecode.push(0xff); - self.visit(body); - self.emit_op(Op::Store); - self.scope_depth -= 1; - while let Some(binding) = self.bindings.last() { - if binding.depth > self.scope_depth { - self.emit_op(Op::Pop); - self.bindings.pop(); - } else { - break; + let jnm_jump_idx = self.len(); + self.emit_byte(0xff); + if let Some(expr) = guard.as_ref() { + self.visit(expr); + self.emit_op(Op::JumpIfFalse); + let jif_idx = self.len(); + self.emit_byte(0xff); + self.visit(body); + self.emit_op(Op::Store); + self.scope_depth -= 1; + while let Some(binding) = self.bindings.last() { + if binding.depth > self.scope_depth { + self.emit_op(Op::Pop); + self.bindings.pop(); + } else { + break; + } } + self.emit_op(Op::Jump); + jump_idxes.push(self.len()); + self.emit_byte(0xff); + self.chunk.bytecode[jnm_jump_idx] = + self.len() as u8 - jnm_jump_idx as u8 - 1; + self.chunk.bytecode[jif_idx] = self.len() as u8 - jif_idx as u8 - 1; + } else { + self.visit(body); + self.emit_op(Op::Store); + self.scope_depth -= 1; + while let Some(binding) = self.bindings.last() { + if binding.depth > self.scope_depth { + self.emit_op(Op::Pop); + self.bindings.pop(); + } else { + break; + } + } + self.emit_op(Op::Jump); + jump_idxes.push(self.len()); + self.emit_byte(0xff); + self.chunk.bytecode[jnm_jump_idx] = + self.len() as u8 - jnm_jump_idx as u8 - 1; } - self.emit_op(Op::Jump); - jump_idxes.push(self.chunk.bytecode.len()); - self.chunk.bytecode.push(0xff); - self.chunk.bytecode[jnm_jump_idx] = - self.chunk.bytecode.len() as u8 - jnm_jump_idx as u8 - 1; } self.emit_op(Op::PanicNoMatch); self.emit_op(Op::Load); for idx in jump_idxes { - self.chunk.bytecode[idx] = self.chunk.bytecode.len() as u8 - idx as u8 + 2; + self.chunk.bytecode[idx] = self.len() as u8 - idx as u8; } } MatchClause(..) => unreachable!(), @@ -534,21 +566,21 @@ impl Compiler { self.emit_op(Op::Truncate); // skip the decrement the first time self.emit_op(Op::Jump); - self.chunk.bytecode.push(1); + self.emit_byte(1); // begin repeat self.emit_op(Op::Decrement); - let repeat_begin = self.chunk.bytecode.len(); + let repeat_begin = self.len(); self.emit_op(Op::Duplicate); self.emit_op(Op::JumpIfZero); - self.chunk.bytecode.push(0xff); + self.emit_byte(0xff); // compile the body self.visit(body); // pop whatever value the body returns self.emit_op(Op::Pop); self.emit_op(Op::JumpBack); // set jump points - let repeat_end = self.chunk.bytecode.len(); - self.chunk.bytecode.push((repeat_end - repeat_begin) as u8); + let repeat_end = self.len(); + self.emit_byte(repeat_end - repeat_begin); self.chunk.bytecode[repeat_begin + 2] = (repeat_end - repeat_begin - 2) as u8; // pop the counter self.emit_op(Op::Pop); diff --git a/src/main.rs b/src/main.rs index b763e90..ca23ed5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -68,9 +68,9 @@ pub fn run(src: &'static str) { pub fn main() { let src = " -let foo = 4 -repeat foo { - :foo +match :foo with { + :foo if true -> :oops + :foo if true -> :yay! } "; run(src);