From 807b2e6ce0a174526ffffc0939c7d62d639b14e5 Mon Sep 17 00:00:00 2001 From: Scott Richmond Date: Thu, 19 Jun 2025 20:26:16 -0400 Subject: [PATCH] fix upvalue resolution --- may_2025_thoughts.md | 2 +- src/compiler.rs | 68 ++++++++++++++++++++++++++++++++------------ src/main.rs | 13 ++++----- src/value.rs | 3 +- src/vm.rs | 4 ++- 5 files changed, 62 insertions(+), 28 deletions(-) diff --git a/may_2025_thoughts.md b/may_2025_thoughts.md index 4443075..31d2116 100644 --- a/may_2025_thoughts.md +++ b/may_2025_thoughts.md @@ -242,7 +242,7 @@ println!("{a} // {b_high}/{b_low} // {c}"); To reiterate the punch list that *I would have needed for Computer Class 1*: * [x] jump instructions need 16 bits of operand - Whew, that took longer than I expected -* [ ] splatterns +* [x] splatterns - [ ] validator should ensure splatterns are the longest patterns in a form * [ ] improve validator - [ ] Tuples may not be longer than n members diff --git a/src/compiler.rs b/src/compiler.rs index dcfad8b..6f444b0 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -76,6 +76,7 @@ pub enum Op { Stringify, Call, + TailCall, Return, Partial, @@ -218,6 +219,7 @@ impl std::fmt::Display for Op { Call => "call", Return => "return", Partial => "partial", + TailCall => "tail_call", SetUpvalue => "set_upvalue", GetUpvalue => "get_upvalue", @@ -275,7 +277,8 @@ impl Chunk { PushBinding | MatchTuple | MatchSplattedTuple | LoadSplattedTuple | MatchList | MatchSplattedList | LoadSplattedList | MatchDict | MatchSplattedDict | DropDictEntry | LoadDictValue | PushTuple | PushBox | MatchDepth | PopN | StoreAt - | Call | SetUpvalue | GetUpvalue | Partial | MatchString | PushStringMatches => { + | Call | SetUpvalue | GetUpvalue | Partial | MatchString | PushStringMatches + | TailCall => { let next = self.bytecode[*i + 1]; println!("{i:04}: {:16} {next:03}", op.to_string()); *i += 1; @@ -357,6 +360,7 @@ pub struct Compiler<'a> { pub enclosing: Option<&'a Compiler<'a>>, pub upvalues: Vec, loop_info: Vec, + tail_pos: bool, } fn is_binding(expr: &Spanned) -> bool { @@ -401,6 +405,7 @@ impl<'a> Compiler<'a> { upvalues: vec![], src, name, + tail_pos: false, } } @@ -533,7 +538,7 @@ impl<'a> Compiler<'a> { } None => match self.resolve_upvalue(name) { Some(position) => { - println!("resolved upvalue: {name}"); + println!("resolved upvalue: {name} at {position}"); self.emit_op(Op::GetUpvalue); self.emit_byte(position); self.stack_depth += 1; @@ -542,7 +547,7 @@ impl<'a> Compiler<'a> { println!("setting upvalue: {name}"); let upvalue = self.get_upvalue(name); self.emit_op(Op::GetUpvalue); - self.emit_byte(upvalue.stack_pos); + self.emit_byte(self.upvalues.len()); self.upvalues.push(upvalue); dbg!(&self.upvalues); self.stack_depth += 1; @@ -551,16 +556,6 @@ impl<'a> Compiler<'a> { } } - // fn resolve_binding(&self, name: &'static str) -> usize { - // if let Some(pos) = self.resolve_local(name) { - // return pos; - // } - // match self.enclosing { - // Some(compiler) => compiler.resolve_binding(name), - // None => unreachable!(), - // } - // } - fn pop(&mut self) { self.emit_op(Op::Pop); self.stack_depth -= 1; @@ -617,6 +612,8 @@ impl<'a> Compiler<'a> { } Keyword(s) => self.emit_constant(Value::Keyword(s)), Block(lines) => { + let tail_pos = self.tail_pos; + self.tail_pos = false; // increase the scope self.scope_depth += 1; // stash the stack depth @@ -650,6 +647,7 @@ impl<'a> Compiler<'a> { } // otherwise, just evaluate it and leave the value on the stack _ => { + self.tail_pos = tail_pos; self.visit(last_expr); } } @@ -675,9 +673,12 @@ impl<'a> Compiler<'a> { self.emit_op(Op::Load); } If(cond, then, r#else) => { + let tail_pos = self.tail_pos; + self.tail_pos = false; self.visit(cond); let jif_idx = self.stub_jump(Op::JumpIfFalse); self.stack_depth -= 1; + self.tail_pos = tail_pos; self.visit(then); let jump_idx = self.stub_jump(Op::Jump); self.visit(r#else); @@ -1015,7 +1016,10 @@ impl<'a> Compiler<'a> { self.emit_constant(Value::Keyword(key)); self.visit(value); } + // TODO: thread tail position through this Synthetic(first, second, rest) => { + let tail_pos = self.tail_pos; + self.tail_pos = false; match (&first.0, &second.0) { (Word(_), Keyword(_)) => { self.visit(first); @@ -1090,7 +1094,15 @@ impl<'a> Compiler<'a> { self.visit(arg); } self.resolve_binding(fn_name); - self.emit_op(Op::Call); + // if we're in tail position AND there aren't any rest args, this should be a tail call (I think) + if rest.is_empty() { + self.tail_pos = tail_pos; + } + if self.tail_pos { + self.emit_op(Op::TailCall); + } else { + self.emit_op(Op::Call); + } self.emit_byte(arity); self.stack_depth -= arity; } @@ -1099,6 +1111,7 @@ impl<'a> Compiler<'a> { } _ => unreachable!(), } + // the last term in rest should be in tail position if we are in tail position for (term, _) in rest { match term { Keyword(str) => { @@ -1123,12 +1136,15 @@ impl<'a> Compiler<'a> { } } When(clauses) => { + let tail_pos = self.tail_pos; let mut jump_idxes = vec![]; let mut clauses = clauses.iter(); while let Some((WhenClause(cond, body), _)) = clauses.next() { + self.tail_pos = false; self.visit(cond.as_ref()); let jif_jump_idx = self.stub_jump(Op::JumpIfFalse); self.stack_depth -= 1; + self.tail_pos = tail_pos; self.visit(body); self.stack_depth -= 1; jump_idxes.push(self.stub_jump(Op::Jump)); @@ -1142,11 +1158,14 @@ impl<'a> Compiler<'a> { } WhenClause(..) => unreachable!(), Match(scrutinee, clauses) => { + let tail_pos = self.tail_pos; + self.tail_pos = false; self.visit(scrutinee.as_ref()); let stack_depth = self.stack_depth; let mut jump_idxes = vec![]; let mut clauses = clauses.iter(); while let Some((MatchClause(pattern, guard, body), _)) = clauses.next() { + self.tail_pos = false; let mut no_match_jumps = vec![]; self.scope_depth += 1; self.match_depth = 0; @@ -1159,6 +1178,7 @@ impl<'a> Compiler<'a> { no_match_jumps.push(self.stub_jump(Op::JumpIfFalse)); self.stack_depth -= 1; } + self.tail_pos = tail_pos; self.visit(body); self.emit_op(Op::Store); self.scope_depth -= 1; @@ -1222,6 +1242,7 @@ impl<'a> Compiler<'a> { compilers.get_mut(&arity).unwrap() } }; + compiler.tail_pos = false; compiler.stack_depth += arity as usize; compiler.scope_depth += 1; @@ -1244,12 +1265,8 @@ impl<'a> Compiler<'a> { for idx in tup_jump_idxes { compiler.patch_jump(idx, compiler.len() - idx - 3); } - // compiler.pop_n(arity as usize); compiler.emit_op(Op::PopN); compiler.emit_byte(arity as usize); - // for _ in 0..arity { - // compiler.emit_op(Op::Pop); - // } compiler.patch_jump(jump_idx, compiler.len() - jump_idx - 3); let mut no_match_jumps = vec![]; no_match_jumps.push(compiler.stub_jump(Op::JumpIfNoMatch)); @@ -1260,6 +1277,7 @@ impl<'a> Compiler<'a> { no_match_jumps.push(compiler.stub_jump(Op::JumpIfFalse)); compiler.stack_depth -= 1; } + compiler.tail_pos = true; compiler.visit(clause_body); compiler.emit_op(Op::Store); compiler.scope_depth -= 1; @@ -1354,6 +1372,8 @@ impl<'a> Compiler<'a> { self.emit_constant(Value::Nil); } Loop(value, clauses) => { + let tail_pos = self.tail_pos; + self.tail_pos = false; //algo: //first, put the values on the stack let (Ast::Tuple(members), _) = value.as_ref() else { @@ -1374,6 +1394,7 @@ impl<'a> Compiler<'a> { let mut clauses = clauses.iter(); let mut jump_idxes = vec![]; while let Some((Ast::MatchClause(pattern, guard, body), _)) = clauses.next() { + self.tail_pos = false; self.emit_op(Op::ResetMatch); self.scope_depth += 1; let (Ast::TuplePattern(members), _) = pattern.as_ref() else { @@ -1395,6 +1416,7 @@ impl<'a> Compiler<'a> { jnm_idxes.push(self.stub_jump(Op::JumpIfFalse)); self.stack_depth -= 1; } + self.tail_pos = tail_pos; self.visit(body); self.emit_op(Op::Store); self.scope_depth -= 1; @@ -1462,7 +1484,12 @@ impl<'a> Compiler<'a> { Do(terms) => { let mut terms = terms.iter(); let first = terms.next().unwrap(); + let mut terms = terms.rev(); + let last = terms.next().unwrap(); + let terms = terms.rev(); // put the first value on the stack + let tail_pos = self.tail_pos; + self.tail_pos = false; self.visit(first); for term in terms { self.visit(term); @@ -1470,6 +1497,11 @@ impl<'a> Compiler<'a> { self.emit_byte(1); self.stack_depth -= 1; } + self.tail_pos = tail_pos; + self.visit(last); + self.emit_op(Op::Call); + self.emit_byte(1); + self.stack_depth -= 1; } Placeholder => { self.emit_op(Op::Nothing); diff --git a/src/main.rs b/src/main.rs index f7bb8bc..ee4a7c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -75,15 +75,14 @@ pub fn run(src: &'static str) { pub fn main() { env::set_var("RUST_BACKTRACE", "1"); let src = " -fn arity { - () -> 0 - (_) -> 1 - (_, _) -> 2 - (_, _, _) -> 3 - (_, _, _, ...args) -> (:a_lot, args) +fn nope () -> nil +fn foo () -> { + nope () + :foo } +fn bar () -> foo () -arity (1, 2, 3, 4, 5, 6) +bar () "; run(src); } diff --git a/src/value.rs b/src/value.rs index 77ac905..6a21336 100644 --- a/src/value.rs +++ b/src/value.rs @@ -26,6 +26,7 @@ impl LFn { match self { LFn::Declared { .. } => unreachable!(), LFn::Defined { closed, .. } => { + println!("closing over in {}: {value}", self.name()); closed.borrow_mut().push(value); } } @@ -87,7 +88,7 @@ impl LFn { } => { let chunk_pos = arities.iter().position(|a| arity == *a); match chunk_pos { - Some(pos) => &chunks[pos as usize], + Some(pos) => &chunks[pos], None => &chunks[*splat as usize], } } diff --git a/src/vm.rs b/src/vm.rs index e25bd96..a7b4fc0 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -854,7 +854,7 @@ impl Vm { }; self.push(Value::Partial(Rc::new(partial))); } - Call => { + Call | TailCall => { let arity = self.chunk().bytecode[self.ip + 1]; self.ip += 2; @@ -961,6 +961,8 @@ impl Vm { GetUpvalue => { let idx = self.chunk().bytecode[self.ip + 1]; self.ip += 2; + println!("getting upvalue {idx}"); + dbg!(&self.frame.function); if let Value::Fn(ref inner) = self.frame.function { self.push(inner.as_ref().upvalue(idx)); } else {