diff --git a/src/compiler.rs b/src/compiler.rs index 649b59a..dcfad8b 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -876,7 +876,7 @@ impl<'a> Compiler<'a> { pairs_len -= 1; } - let mut match_depth = self.match_depth; + let match_depth = self.match_depth; self.match_depth = 0; for pair in pairs.iter().take(pairs_len) { let (PairPattern(key, pattern), _) = pair else { @@ -922,9 +922,7 @@ impl<'a> Compiler<'a> { self.patch_jump(before_load_dict_idx, self.len() - before_load_dict_idx - 3); self.patch_jump(jump_idx, self.len() - jump_idx - 3); } - Splattern(patt) => { - self.visit(patt); - } + Splattern(patt) => self.visit(patt), InterpolatedPattern(parts, _) => { println!("An interpolated pattern of {} parts", parts.len()); let mut pattern = "".to_string(); @@ -1181,12 +1179,8 @@ impl<'a> Compiler<'a> { for idx in jump_idxes { self.patch_jump(idx, self.len() - idx - 3); - // self.chunk.bytecode[idx] = (self.len() - idx) as u8 - 1; } self.pop_n(self.stack_depth - stack_depth); - // while self.stack_depth > stack_depth { - // self.pop(); - // } self.emit_op(Op::Load); self.stack_depth += 1; } @@ -1198,10 +1192,12 @@ impl<'a> Compiler<'a> { unreachable!() }; - let mut compilers: HashMap = HashMap::new(); + let mut compilers: HashMap = HashMap::new(); let mut upvalues = vec![]; + let mut has_splat = false; + for clause in fn_body { let MatchClause(pattern, guard, clause_body) = &clause.0 else { unreachable!() @@ -1210,7 +1206,11 @@ impl<'a> Compiler<'a> { unreachable!() }; - let arity = pattern.len(); + if matches!(pattern.last(), Some((Splattern(_), _))) { + has_splat = true; + }; + + let arity = pattern.len() as u8; let compiler = match compilers.get_mut(&arity) { Some(compiler) => compiler, @@ -1223,9 +1223,9 @@ impl<'a> Compiler<'a> { } }; - compiler.stack_depth += arity; + compiler.stack_depth += arity as usize; compiler.scope_depth += 1; - compiler.match_depth = arity; + compiler.match_depth = arity as usize; std::mem::swap(&mut upvalues, &mut compiler.upvalues); @@ -1244,9 +1244,12 @@ impl<'a> Compiler<'a> { for idx in tup_jump_idxes { compiler.patch_jump(idx, compiler.len() - idx - 3); } - for _ in 0..arity { - compiler.emit_op(Op::Pop); - } + // compiler.pop_n(arity as usize); + compiler.emit_op(Op::PopN); + compiler.emit_byte(arity as usize); + // for _ in 0..arity { + // compiler.emit_op(Op::Pop); + // } compiler.patch_jump(jump_idx, compiler.len() - jump_idx - 3); let mut no_match_jumps = vec![]; no_match_jumps.push(compiler.stub_jump(Op::JumpIfNoMatch)); @@ -1278,22 +1281,36 @@ impl<'a> Compiler<'a> { std::mem::swap(&mut compiler.upvalues, &mut upvalues); } - let mut the_chunks = vec![]; + let mut compilers = compilers.into_iter().collect::>(); + compilers.sort_by(|(a, _), (b, _)| a.cmp(b)); - for (arity, mut compiler) in compilers.into_iter() { + let mut arities = vec![]; + let mut chunks = vec![]; + + for (arity, mut compiler) in compilers { compiler.emit_op(Op::PanicNoMatch); let chunk = compiler.chunk; if crate::DEBUG_COMPILE { println!("=== function chuncktion: {name}/{arity} ==="); chunk.dissasemble(); } - the_chunks.push((arity as u8, chunk)); + + arities.push(arity); + chunks.push(chunk); } + let splat = if has_splat { + arities.iter().fold(0, |max, curr| max.max(*curr)) + } else { + 0 + }; + let lfn = crate::value::LFn::Defined { name, doc: *doc, - chunks: the_chunks, + arities, + chunks, + splat, closed: RefCell::new(vec![]), }; @@ -1370,30 +1387,7 @@ impl<'a> Compiler<'a> { self.emit_byte(self.match_depth); self.visit(member); jnm_idxes.push(self.stub_jump(Op::JumpIfNoMatch)); - // self.emit_op(Op::JumpIfNoMatch); - // tup_jump_idxes.push(self.len()); - // self.emit_byte(0xff); } - // let jump_idx = self.stub_jump(Op::Jump); - // self.emit_op(Op::Jump); - // let jump_idx = self.len(); - // self.emit_byte(0xff); - // for idx in jnm_idxes { - // self.patch_jump(idx, self.len() - idx - 3); - // // self.chunk.bytecode[idx] = (self.len() - idx) as u8 - 2; - // } - // self.emit_op(Op::PopN); - // self.emit_byte(arity); - // for _ in 0..arity { - // self.emit_op(Op::Pop); - // } - // self.patch_jump(jump_idx, self.len() - jump_idx - 3); - // self.chunk.bytecode[jump_idx] = (self.len() - jump_idx) as u8 - 1; - // let mut no_match_jumps = vec![]; - // no_match_jumps.push(self.stub_jump(Op::JumpIfNoMatch)); - // self.emit_op(Op::JumpIfNoMatch); - // let jnm_idx = self.len(); - // self.emit_byte(0xff); if guard.is_some() { let guard_expr: &'static Spanned = Box::leak(Box::new(guard.clone().unwrap())); @@ -1414,26 +1408,17 @@ impl<'a> Compiler<'a> { while self.stack_depth > stack_depth { self.pop(); } - // self.stack_depth -= arity; jump_idxes.push(self.stub_jump(Op::Jump)); - // self.emit_op(Op::Jump); - // jump_idxes.push(self.len()); - // self.emit_byte(0xff); for idx in jnm_idxes { self.patch_jump(idx, self.len() - idx - 3); } - // self.chunk.bytecode[jnm_idx] = (self.len() - jnm_idx) as u8; self.scope_depth -= 1; } self.emit_op(Op::PanicNoMatch); for idx in jump_idxes { self.patch_jump(idx, self.len() - idx - 3); - // self.chunk.bytecode[idx] = (self.len() - idx) as u8 - 1; } - // self.emit_op(Op::PopN); - // self.emit_byte(arity); self.stack_depth -= arity; - // println!("Op::Load at end of loop at byte {}", self.len()); self.emit_op(Op::Load); self.stack_depth += 1; self.leave_loop(); @@ -1448,8 +1433,6 @@ impl<'a> Compiler<'a> { self.emit_byte(self.loop_root()); self.emit_op(Op::Load); self.jump(Op::JumpBack, self.len() - self.loop_idx()); - // self.emit_op(Op::JumpBack); - // self.emit_byte(self.len() - self.loop_idx()); } Panic(msg) => { self.visit(msg); diff --git a/src/main.rs b/src/main.rs index 0cb82bd..f7bb8bc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -75,8 +75,15 @@ pub fn run(src: &'static str) { pub fn main() { env::set_var("RUST_BACKTRACE", "1"); let src = " -let (#{a, ...}, d, (e)) = (#{:a 1, :b 2}, :foo, (:bar)) -(a, d, e) +fn arity { + () -> 0 + (_) -> 1 + (_, _) -> 2 + (_, _, _) -> 3 + (_, _, _, ...args) -> (:a_lot, args) +} + +arity (1, 2, 3, 4, 5, 6) "; run(src); } diff --git a/src/value.rs b/src/value.rs index b260fa7..77ac905 100644 --- a/src/value.rs +++ b/src/value.rs @@ -14,7 +14,9 @@ pub enum LFn { Defined { name: &'static str, doc: Option<&'static str>, - chunks: Vec<(u8, Chunk)>, + arities: Vec, + chunks: Vec, + splat: u8, closed: RefCell>, }, } @@ -47,7 +49,23 @@ impl LFn { pub fn accepts(&self, arity: u8) -> bool { match self { - LFn::Defined { chunks, .. } => chunks.iter().any(|(a, _)| *a == arity), + LFn::Defined { arities, splat, .. } => { + if arities.contains(&arity) { + return true; + } + if *splat == 0 { + return false; + } + let max_arity = arities.iter().fold(0, |a, b| a.max(*b)); + arity > max_arity + } + LFn::Declared { .. } => unreachable!(), + } + } + + pub fn splat_arity(&self) -> u8 { + match self { + LFn::Defined { splat, .. } => *splat, LFn::Declared { .. } => unreachable!(), } } @@ -61,7 +79,18 @@ impl LFn { pub fn chunk(&self, arity: u8) -> &Chunk { match self { LFn::Declared { .. } => unreachable!(), - LFn::Defined { chunks, .. } => &chunks.iter().find(|(a, _)| *a == arity).unwrap().1, + LFn::Defined { + arities, + splat, + chunks, + .. + } => { + let chunk_pos = arities.iter().position(|a| arity == *a); + match chunk_pos { + Some(pos) => &chunks[pos as usize], + None => &chunks[*splat as usize], + } + } } } diff --git a/src/vm.rs b/src/vm.rs index 8691b48..e25bd96 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -2,11 +2,9 @@ use crate::base::BaseFn; use crate::compiler::{Chunk, Op}; use crate::parser::Ast; use crate::spans::Spanned; -use crate::value::{LFn, Partial, Value}; -use chumsky::prelude::SimpleSpan; +use crate::value::{LFn, Value}; use imbl::{HashMap, Vector}; use num_traits::FromPrimitive; -use std::cell::OnceCell; use std::cell::RefCell; use std::fmt; use std::mem::swap; @@ -97,7 +95,9 @@ impl Vm { let lfn = LFn::Defined { name: "user script", doc: None, - chunks: vec![(0, chunk)], + chunks: vec![chunk], + arities: vec![0], + splat: 0, closed: RefCell::new(vec![]), }; let base_fn = Value::Fn(Rc::new(lfn)); @@ -261,7 +261,7 @@ impl Vm { let cond = self.pop(); match cond { Value::Number(x) if x <= 0.0 => { - self.ip += jump_len as usize + 3; + self.ip += jump_len + 3; } Value::Number(..) => { self.ip += 3; @@ -868,6 +868,15 @@ impl Vm { val.show() )); } + let splat_arity = val.as_fn().splat_arity(); + if splat_arity > 0 && arity >= splat_arity { + let splatted_args = self.stack.split_off( + self.stack.len() - (arity - splat_arity) as usize - 1, + ); + let gathered_args = Vector::from(splatted_args); + self.push(Value::List(Box::new(gathered_args))); + } + let arity = splat_arity.min(arity); let mut frame = CallFrame { function: val, arity,