diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index a1836ea9dfb3a4..6999348822deca 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -1,5 +1,3 @@
-use std::mem::take;
-
 use crate::asm::{CodeBlock, Label};
 use crate::asm::arm64::*;
 use crate::codegen::split_patch_point;
@@ -390,7 +388,7 @@ impl Assembler {
         }
 
         let mut asm_local = Assembler::new_with_asm(&self);
-        let live_ranges = take(&mut self.live_ranges);
+        let live_ranges = self.compute_live_ranges();
        let mut iterator = self.instruction_iterator();
         let asm = &mut asm_local;
 
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index b2ec95a9d4a84c..300763d15d588d 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -15,6 +15,7 @@ use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_opcode, side_exit_coun
 use crate::virtualmem::CodePtr;
 use crate::asm::{CodeBlock, Label};
 use crate::state::rb_zjit_record_exit_stack;
+use crate::bitset::BitSet;
 
 /// LIR Block ID. Unique ID for each block, and also defined in LIR so
 /// we can differentiate it from HIR block ids.
@@ -1670,7 +1671,8 @@ impl Assembler
     // one assembler to a new one.
     pub fn new_block_from_old_block(&mut self, old_block: &BasicBlock) -> BlockId {
         let bb_id = BlockId(self.basic_blocks.len());
-        let lir_bb = BasicBlock::new(bb_id, old_block.hir_block_id, old_block.is_entry, old_block.rpo_index);
+        let mut lir_bb = BasicBlock::new(bb_id, old_block.hir_block_id, old_block.is_entry, old_block.rpo_index);
+        lir_bb.parameters = old_block.parameters.clone();
         self.basic_blocks.push(lir_bb);
         bb_id
     }
@@ -1776,9 +1778,10 @@ impl Assembler
         // Helper to process branch arguments and return the label target
         let mut process_edge = |edge: &BranchEdge| -> Label {
             if !edge.args.is_empty() {
+                let params = &self.basic_blocks[edge.target.0].parameters;
                 insns.push(Insn::ParallelMov {
                     moves: edge.args.iter().enumerate()
-                        .map(|(idx, &arg)| (Assembler::param_opnd(idx), arg))
+                        .map(|(idx, &arg)| (params[idx], arg))
                         .collect()
                 });
             }
@@ -1839,6 +1842,215 @@ impl Assembler
         };
     }
 
+    /// Compute the live range of every VReg with a backward liveness dataflow
+    /// analysis over basic blocks, so a VReg defined in one block and used in
+    /// another stays live across the blocks in between.
+    pub(super) fn compute_live_ranges(&self) -> LiveRanges {
+        fn record_use(opnd: Opnd, defs: &BitSet<VRegId>, uses: &mut BitSet<VRegId>) {
+            match opnd {
+                Opnd::VReg { idx, .. } => {
+                    if !defs.get(idx) {
+                        uses.insert(idx);
+                    }
+                }
+                Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => {
+                    if !defs.get(idx) {
+                        uses.insert(idx);
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        fn update_start(range: &mut LiveRange, pos: usize) {
+            if range.start.is_none() || range.start.unwrap() > pos {
+                range.start = Some(pos);
+            }
+        }
+
+        fn update_end(range: &mut LiveRange, pos: usize) {
+            if range.end.is_none() || range.end.unwrap() < pos {
+                range.end = Some(pos);
+            }
+        }
+
+        let num_vregs = self.live_ranges.len();
+        let num_blocks = self.basic_blocks.len();
+        let mut block_uses: Vec<BitSet<VRegId>> = (0..num_blocks)
+            .map(|_| BitSet::with_capacity(num_vregs))
+            .collect();
+        let mut block_defs: Vec<BitSet<VRegId>> = (0..num_blocks)
+            .map(|_| BitSet::with_capacity(num_vregs))
+            .collect();
+
+        // Compute each block's defs and upward-exposed uses (uses not preceded by a def).
+        for block in &self.basic_blocks {
+            let block_id = block.id.0;
+            let mut defs = BitSet::with_capacity(num_vregs);
+            let mut uses = BitSet::with_capacity(num_vregs);
+
+            for param in &block.parameters {
+                if let Opnd::VReg { idx, .. } = param {
+                    defs.insert(*idx);
+                }
+            }
+
+            for insn in &block.insns {
+                match insn {
+                    Insn::ParallelMov { moves } => {
+                        for (dst, src) in moves {
+                            record_use(*src, &defs, &mut uses);
+                            match dst {
+                                Opnd::VReg { idx, .. } => {
+                                    defs.insert(*idx);
+                                }
+                                _ => record_use(*dst, &defs, &mut uses),
+                            }
+                        }
+                    }
+                    _ => {
+                        for opnd in insn.opnd_iter() {
+                            record_use(*opnd, &defs, &mut uses);
+                        }
+                        if let Some(Opnd::VReg { idx, .. }) = insn.out_opnd() {
+                            defs.insert(*idx);
+                        }
+                    }
+                }
+            }
+
+            block_uses[block_id] = uses;
+            block_defs[block_id] = defs;
+        }
+
+        let mut live_in: Vec<BitSet<VRegId>> = (0..num_blocks)
+            .map(|_| BitSet::with_capacity(num_vregs))
+            .collect();
+        let mut live_out: Vec<BitSet<VRegId>> = (0..num_blocks)
+            .map(|_| BitSet::with_capacity(num_vregs))
+            .collect();
+
+        // Solve live_in/live_out with the standard backward dataflow equations:
+        //   live_out(B) = union of live_in(S) for each successor S of B
+        //   live_in(B) = (live_out(B) - defs(B)) | uses(B)
+        let sorted_blocks = self.sorted_blocks();
+        loop {
+            let mut changed = false;
+            for block in sorted_blocks.iter().rev() {
+                let block_id = block.id.0;
+                let mut new_out = BitSet::with_capacity(num_vregs);
+
+                for insn in &block.insns {
+                    if !insn.is_terminator() {
+                        continue;
+                    }
+                    if let Some(Target::Block(edge)) = insn.target() {
+                        new_out.union_with(&live_in[edge.target.0]);
+                    }
+                }
+
+                if new_out != live_out[block_id] {
+                    live_out[block_id] = new_out.clone();
+                    changed = true;
+                }
+
+                let mut new_in = new_out;
+                new_in.subtract_with(&block_defs[block_id]);
+                new_in.union_with(&block_uses[block_id]);
+
+                if new_in != live_in[block_id] {
+                    live_in[block_id] = new_in;
+                    changed = true;
+                }
+            }
+
+            if !changed {
+                break;
+            }
+        }
+
+        // Map each block to a range of global instruction indices, in sorted-block order.
+        let mut block_start = vec![0usize; num_blocks];
+        let mut block_end = vec![0usize; num_blocks];
+        let mut idx = 0usize;
+        for block in &sorted_blocks {
+            block_start[block.id.0] = idx;
+            if block.insns.is_empty() {
+                block_end[block.id.0] = idx;
+            } else {
+                block_end[block.id.0] = idx + block.insns.len() - 1;
+                idx += block.insns.len();
+            }
+        }
+
+        // Turn block-level liveness plus local uses and defs into per-VReg ranges.
+        let mut live_ranges = LiveRanges::new(num_vregs);
+        for block in &sorted_blocks {
+            let block_id = block.id.0;
+            let start_idx = block_start[block_id];
+            let end_idx = block_end[block_id];
+
+            for param in &block.parameters {
+                if let Opnd::VReg { idx, .. } = param {
+                    let range = &mut live_ranges[*idx];
+                    update_start(range, start_idx);
+                    update_end(range, start_idx);
+                }
+            }
+
+            for vreg_idx in live_out[block_id].iter_indices() {
+                let idx = VRegId(vreg_idx);
+                let range = &mut live_ranges[idx];
+                update_end(range, end_idx);
+                if range.start.is_none() {
+                    update_start(range, start_idx);
+                }
+            }
+
+            let mut insn_idx = start_idx;
+            for insn in &block.insns {
+                for opnd in insn.opnd_iter() {
+                    match *opnd {
+                        Opnd::VReg { idx, .. } => {
+                            let range = &mut live_ranges[idx];
+                            update_end(range, insn_idx);
+                            if range.start.is_none() {
+                                update_start(range, insn_idx);
+                            }
+                        }
+                        Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => {
+                            let range = &mut live_ranges[idx];
+                            update_end(range, insn_idx);
+                            if range.start.is_none() {
+                                update_start(range, insn_idx);
+                            }
+                        }
+                        _ => {}
+                    }
+                }
+
+                if let Some(Opnd::VReg { idx, .. }) = insn.out_opnd() {
+                    let range = &mut live_ranges[*idx];
+                    update_start(range, insn_idx);
+                    if range.end.is_none() {
+                        update_end(range, insn_idx);
+                    }
+                }
+
+                insn_idx += 1;
+            }
+        }
+
+        // Ensure every range has both a start and an end.
+        for range in &mut live_ranges.0 {
+            match (range.start, range.end) {
+                (None, None) => {
+                    range.start = Some(0);
+                    range.end = Some(0);
+                }
+                (None, Some(end)) => {
+                    range.start = Some(end);
+                }
+                (Some(start), None) => {
+                    range.end = Some(start);
+                }
+                _ => {}
+            }
+        }
+
+        live_ranges
+    }
+
     /// Build an Opnd::VReg and initialize its LiveRange
     pub(super) fn new_vreg(&mut self, num_bits: u8) -> Opnd {
         let vreg = Opnd::VReg { idx: VRegId(self.live_ranges.len()), num_bits };
@@ -1846,6 +2058,11 @@ impl Assembler
         vreg
     }
 
+    /// Create a new block parameter VReg without emitting an instruction.
+    pub fn new_param_opnd(&mut self) -> Opnd {
+        self.new_vreg(Opnd::DEFAULT_NUM_BITS)
+    }
+
     /// Append an instruction onto the current list of instructions and update
     /// the live ranges of any instructions whose outputs are being used as
     /// operands to this instruction.
@@ -1960,16 +2177,33 @@ impl Assembler
         // Remember the indexes of Insn::FrameSetup to update the stack size later
         let mut frame_setup_idxs: Vec<(BlockId, usize)> = vec![];
 
-        // live_ranges is indexed by original `index` given by the iterator.
+        // live_ranges is indexed by VRegId.
         let mut asm_local = Assembler::new_with_asm(&self);
         let iterator = &mut self.instruction_iterator();
         let asm = &mut asm_local;
-        let live_ranges = take(&mut self.live_ranges);
+        let live_ranges = self.compute_live_ranges();
+
+        // Block parameters have no defining instruction of their own, so
+        // allocate their registers eagerly when we start processing a block.
+        let mut current_block_id = asm.current_block().id;
+        let alloc_block_params = |block: &BasicBlock, pool: &mut RegisterPool, vreg_opnd: &mut Vec<Option<Opnd>>| {
+            for param in &block.parameters {
+                if let Opnd::VReg { idx, num_bits } = *param {
+                    if vreg_opnd[idx.0].is_none() {
+                        let opnd = pool.alloc_opnd(idx).with_num_bits(num_bits);
+                        vreg_opnd[idx.0] = Some(opnd);
+                    }
+                }
+            }
+        };
+        alloc_block_params(asm.current_block(), &mut pool, &mut vreg_opnd);
 
         while let Some((index, mut insn)) = iterator.next(asm) {
+            if asm.current_block().id != current_block_id {
+                current_block_id = asm.current_block().id;
+                alloc_block_params(asm.current_block(), &mut pool, &mut vreg_opnd);
+            }
             // Remember the index of FrameSetup to bump slot_count when we know the max number of spilled VRegs.
             if let Insn::FrameSetup { .. } = insn {
                 assert!(asm.current_block().is_entry);
diff --git a/zjit/src/backend/tests.rs b/zjit/src/backend/tests.rs
index 32b6fe9b5ef31e..9842a058c2b4c1 100644
--- a/zjit/src/backend/tests.rs
+++ b/zjit/src/backend/tests.rs
@@ -2,6 +2,7 @@ use crate::asm::CodeBlock;
 use crate::backend::lir::*;
 use crate::cruby::*;
 use crate::codegen::c_callable;
+use crate::hir;
 use crate::options::rb_zjit_prepare_options;
 
 #[test]
@@ -63,6 +64,28 @@ fn test_alloc_regs() {
     }
 }
 
+#[test]
+fn test_alloc_regs_across_blocks() {
+    rb_zjit_prepare_options(); // for asm.alloc_regs
+    let mut asm = Assembler::new();
+    let entry = asm.new_block(hir::BlockId(0), true, 0);
+    let exit = asm.new_block(hir::BlockId(1), false, 1);
+
+    asm.set_current_block(entry);
+    let entry_label = asm.new_label("entry_block");
+    asm.write_label(entry_label);
+    let v0 = asm.add(EC, Opnd::UImm(1));
+    asm.jmp(Target::Block(BranchEdge { target: exit, args: vec![] }));
+
+    asm.set_current_block(exit);
+    let exit_label = asm.new_label("exit_block");
+    asm.write_label(exit_label);
+    let v1 = asm.add(v0, Opnd::UImm(2));
+    asm.cret(v1);
+
+    asm.alloc_regs(Assembler::get_alloc_regs()).unwrap();
+}
+
 fn setup_asm() -> (Assembler, CodeBlock) {
     rb_zjit_prepare_options(); // for get_option! on asm.compile
     let mut asm = Assembler::new();
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index 138df69008fe2c..89eec60b6a3d08 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -1,4 +1,4 @@
-use std::mem::{self, take};
+use std::mem;
 
 use crate::asm::*;
 use crate::asm::x86_64::*;
@@ -140,7 +140,7 @@ impl Assembler {
     {
         let mut asm_local = Assembler::new_with_asm(&self);
         let asm = &mut asm_local;
-        let live_ranges = take(&mut self.live_ranges);
+        let live_ranges = self.compute_live_ranges();
         let mut iterator = self.instruction_iterator();
 
         while let Some((index, mut insn)) = iterator.next(asm) {
diff --git a/zjit/src/bitset.rs b/zjit/src/bitset.rs
index b5b69abdeeb62d..2f72ac6bf85629 100644
--- a/zjit/src/bitset.rs
+++ b/zjit/src/bitset.rs
@@ -6,7 +6,7 @@ const ENTRY_NUM_BITS: usize = Entry::BITS as usize;
 
 // TODO(max): Make a `SmallBitSet` and `LargeBitSet` and switch between them if `num_bits` fits in
 // `Entry`.
-#[derive(Clone)]
+#[derive(Clone, PartialEq, Eq)]
 pub struct BitSet<T: Into<usize> + Copy> {
     entries: Vec<Entry>,
     num_bits: usize,
@@ -25,8 +25,8 @@ impl<T: Into<usize> + Copy> BitSet<T> {
         debug_assert!(idx.into() < self.num_bits);
         let entry_idx = idx.into() / ENTRY_NUM_BITS;
         let bit_idx = idx.into() % ENTRY_NUM_BITS;
-        let newly_inserted = (self.entries[entry_idx] & (1 << bit_idx)) == 0;
-        self.entries[entry_idx] |= 1 << bit_idx;
+        let newly_inserted = (self.entries[entry_idx] & (1u128 << bit_idx)) == 0;
+        self.entries[entry_idx] |= 1u128 << bit_idx;
         newly_inserted
     }
 
@@ -37,11 +37,49 @@ impl<T: Into<usize> + Copy> BitSet<T> {
         }
     }
 
+    /// Clear the bit at `idx`. Returns true if the bit was previously set, and false otherwise.
+    pub fn remove(&mut self, idx: T) -> bool {
+        debug_assert!(idx.into() < self.num_bits);
+        let entry_idx = idx.into() / ENTRY_NUM_BITS;
+        let bit_idx = idx.into() % ENTRY_NUM_BITS;
+        let mask = 1u128 << bit_idx;
+        let was_set = (self.entries[entry_idx] & mask) != 0;
+        self.entries[entry_idx] &= !mask;
+        was_set
+    }
+
     pub fn get(&self, idx: T) -> bool {
         debug_assert!(idx.into() < self.num_bits);
         let entry_idx = idx.into() / ENTRY_NUM_BITS;
         let bit_idx = idx.into() % ENTRY_NUM_BITS;
-        (self.entries[entry_idx] & (1 << bit_idx)) != 0
+        (self.entries[entry_idx] & (1u128 << bit_idx)) != 0
+    }
+
+    /// Modify `self` to include any bits set in `other`. Returns true if `self` was modified,
+    /// and false otherwise.
+    /// `self` and `other` must have the same number of bits.
+    pub fn union_with(&mut self, other: &Self) -> bool {
+        assert_eq!(self.num_bits, other.num_bits);
+        let mut changed = false;
+        for i in 0..self.entries.len() {
+            let before = self.entries[i];
+            self.entries[i] |= other.entries[i];
+            changed |= self.entries[i] != before;
+        }
+        changed
+    }
+
+    /// Modify `self` to clear bits that are set in `other`. Returns true if `self` was modified,
+    /// and false otherwise.
+    /// `self` and `other` must have the same number of bits.
+    pub fn subtract_with(&mut self, other: &Self) -> bool {
+        assert_eq!(self.num_bits, other.num_bits);
+        let mut changed = false;
+        for i in 0..self.entries.len() {
+            let before = self.entries[i];
+            self.entries[i] &= !other.entries[i];
+            changed |= self.entries[i] != before;
+        }
+        changed
     }
 
     /// Modify `self` to only have bits set if they are also set in `other`. Returns true if `self`
@@ -57,6 +95,50 @@ impl<T: Into<usize> + Copy> BitSet<T> {
         }
         changed
     }
+
+    /// Iterate over the indices of the set bits, in ascending order.
+    pub fn iter_indices(&self) -> BitSetIter<'_> {
+        BitSetIter {
+            entries: &self.entries,
+            num_bits: self.num_bits,
+            entry_idx: 0,
+            base_idx: 0,
+            remaining: 0,
+        }
+    }
+}
+
+/// Iterator over the indices of the set bits in a `BitSet`.
+pub struct BitSetIter<'a> {
+    entries: &'a [Entry],
+    num_bits: usize,
+    entry_idx: usize,
+    base_idx: usize,
+    remaining: Entry,
+}
+
+impl<'a> Iterator for BitSetIter<'a> {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        loop {
+            if self.remaining != 0 {
+                let tz = self.remaining.trailing_zeros() as usize;
+                let idx = self.base_idx + tz;
+                self.remaining &= !(1u128 << tz);
+                if idx < self.num_bits {
+                    return Some(idx);
+                }
+                continue;
+            }
+
+            if self.entry_idx >= self.entries.len() {
+                return None;
+            }
+
+            self.remaining = self.entries[self.entry_idx];
+            self.base_idx = self.entry_idx * ENTRY_NUM_BITS;
+            self.entry_idx += 1;
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 2fe9958e62e211..b9d85e958b42ab 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -305,7 +305,12 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, version: IseqVersionRef, func
     for (idx, &insn_id) in block.params().enumerate() {
         match function.find(insn_id) {
             Insn::Param => {
-                jit.opnds[insn_id.0] = Some(gen_param(&mut asm, idx));
+                let is_entry_block = function.is_entry_block(block_id);
+                let param_opnd = gen_param(&mut asm, idx, is_entry_block);
+                if !is_entry_block {
+                    asm.current_block().add_parameter(param_opnd);
+                }
+                jit.opnds[insn_id.0] = Some(param_opnd);
             },
             insn => unreachable!("Non-param insn found in block.params: {insn:?}"),
         }
@@ -1269,13 +1274,11 @@ fn gen_const_uint32(val: u32) -> lir::Opnd {
 }
 
 /// Compile a basic block argument
-fn gen_param(asm: &mut Assembler, idx: usize) -> lir::Opnd {
-    // Allocate a register or a stack slot
-    match Assembler::param_opnd(idx) {
-        // If it's a register, insert LiveReg instruction to reserve the register
-        // in the register pool for register allocation.
-        param @ Opnd::Reg(_) => asm.live_reg_opnd(param),
-        param => param,
+fn gen_param(asm: &mut Assembler, idx: usize, is_entry_block: bool) -> lir::Opnd {
+    if is_entry_block {
+        asm.load(Assembler::param_opnd(idx))
+    } else {
+        asm.new_param_opnd()
     }
 }
 
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index e83ba4a6fa8414..711747bb5bc83e 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -5392,9 +5392,8 @@ impl Function {
         Ok(())
     }
 
-    // Validate that every instruction use is from a block-local definition, which is a temporary
-    // constraint until we get a global register allocator.
-    // TODO(tenderworks): Remove this
+    // Validate that every instruction use is from a block-local definition.
+    // No longer called from validate(); kept so tests can exercise the old block-local constraint.
     fn temporary_validate_block_local_definite_assignment(&self) -> Result<(), ValidationError> {
         for block in self.rpo() {
             let mut assigned = InsnSet::with_capacity(self.insns.len());
@@ -5728,7 +5727,6 @@ impl Function {
     pub fn validate(&self) -> Result<(), ValidationError> {
         self.validate_block_terminators_and_jumps()?;
         self.validate_definite_assignment()?;
-        self.temporary_validate_block_local_definite_assignment()?;
         self.validate_insn_uniqueness()?;
         self.validate_types()?;
         Ok(())
@@ -8036,6 +8034,18 @@ mod validation_tests {
         assert_matches_err(function.validate_definite_assignment(), ValidationError::OperandNotDefined(entry, val, dangling));
     }
 
+    #[test]
+    fn allow_cross_block_use_with_dominance() {
+        let mut function = Function::new(std::ptr::null());
+        let entry = function.entry_block;
+        let target = function.new_block(0);
+        let value = function.push_insn(entry, Insn::Const { val: Const::CBool(true) });
+        function.push_insn(entry, Insn::Jump(BranchEdge { target, args: vec![] }));
+        function.push_insn(target, Insn::ArrayDup { val: value, state: value });
+        function.push_insn(target, Insn::Return { val: value });
+        function.validate().unwrap();
+    }
+
     #[test]
     fn not_defined_within_bb_block_local() {
         let mut function = Function::new(std::ptr::null());