diff --git a/build/instructions_template.rs b/build/instructions_template.rs index 990daa733..517df8f64 100644 --- a/build/instructions_template.rs +++ b/build/instructions_template.rs @@ -660,6 +660,8 @@ enum InstructionTemplate { // cut instruction #[strum_discriminants(strum(props(Arity = "1", Name = "cut")))] Cut(RegType), + #[strum_discriminants(strum(props(Arity = "1", Name = "cut_prev")))] + CutPrev(RegType), #[strum_discriminants(strum(props(Arity = "1", Name = "get_level")))] GetLevel(RegType), #[strum_discriminants(strum(props(Arity = "1", Name = "get_prev_level")))] @@ -1333,6 +1335,10 @@ fn generate_instruction_preface() -> TokenStream { let rt_stub = reg_type_into_functor(r); functor!(atom!("cut"), [str(h, 0)], [rt_stub]) } + &Instruction::CutPrev(r) => { + let rt_stub = reg_type_into_functor(r); + functor!(atom!("cut_prev"), [str(h, 0)], [rt_stub]) + } &Instruction::GetLevel(r) => { let rt_stub = reg_type_into_functor(r); functor!(atom!("get_level"), [str(h, 0)], [rt_stub]) diff --git a/src/codegen.rs b/src/codegen.rs index e4e75006c..d918de528 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -950,10 +950,15 @@ impl<'b> CodeGenerator<'b> { code.push_back(instr!("proceed")); } } - &QueryTerm::LocalCut(var_num) => { + &QueryTerm::LocalCut { var_num, cut_prev } => { let code = branch_code_stack.code(code); let r = self.marker.get_binding(var_num); - code.push_back(instr!("cut", r)); + + code.push_back(if cut_prev { + instr!("cut_prev", r) + } else { + instr!("cut", r) + }); if self.marker.in_tail_position { if self.marker.var_data.allocates { diff --git a/src/forms.rs b/src/forms.rs index 9280e5350..6b8771976 100644 --- a/src/forms.rs +++ b/src/forms.rs @@ -202,7 +202,7 @@ pub enum QueryTerm { // register, clause type, subterms, clause call policy. Clause(Cell, ClauseType, Vec, CallPolicy), Fail, - LocalCut(usize), // var_num + LocalCut { var_num: usize, cut_prev: bool }, // var_num GlobalCut(usize), // var_num GetCutPoint { var_num: usize, prev_b: bool }, GetLevel(usize), // var_num diff --git a/src/machine/disjuncts.rs b/src/machine/disjuncts.rs index a8a01d364..f68ee75c8 100644 --- a/src/machine/disjuncts.rs +++ b/src/machine/disjuncts.rs @@ -154,8 +154,11 @@ enum TraversalState { Fail, GetCutPoint { var_num: usize, prev_b: bool }, Cut { var_num: usize, is_global: bool }, + CutPrev(usize), ResetCallPolicy(CallPolicy), Term(Term), + OverrideGlobalCutVar(usize), + ResetGlobalCutVarOverride(Option), RemoveBranchNum, // pop the current_branch_num and from the root set. AddBranchNum(BranchNumber), // set current_branch_num, add it to the root set RepBranchNum(BranchNumber), // replace current_branch_num and the latest in the root set @@ -171,6 +174,7 @@ pub struct VariableClassifier { var_num: usize, root_set: RootSet, global_cut_var_num: Option, + global_cut_var_num_override: Option, } #[derive(Debug, Default)] @@ -252,6 +256,7 @@ impl VariableClassifier { root_set: RootSet::new(), var_num: 0, global_cut_var_num: None, + global_cut_var_num_override: None, } } @@ -517,6 +522,12 @@ impl VariableClassifier { self.probe_in_situ_var(var_num); build_stack.push_chunk_term(QueryTerm::GetCutPoint { var_num, prev_b }); } + TraversalState::OverrideGlobalCutVar(var_num) => { + self.global_cut_var_num_override = Some(var_num); + } + TraversalState::ResetGlobalCutVarOverride(old_override) => { + self.global_cut_var_num_override = old_override; + } TraversalState::Cut { var_num, is_global } => { if self.try_set_chunk_at_inlined_boundary() { build_stack.add_chunk(); @@ -527,9 +538,18 @@ impl VariableClassifier { build_stack.push_chunk_term(if is_global { QueryTerm::GlobalCut(var_num) } else { - QueryTerm::LocalCut(var_num) + QueryTerm::LocalCut { var_num, cut_prev: false } }); } + TraversalState::CutPrev(var_num) => { + if self.try_set_chunk_at_inlined_boundary() { + build_stack.add_chunk(); + } + + self.probe_in_situ_var(var_num); + + build_stack.push_chunk_term(QueryTerm::LocalCut { var_num, cut_prev: true }); + } TraversalState::Fail => { build_stack.push_chunk_term(QueryTerm::Fail); } @@ -684,14 +704,13 @@ impl VariableClassifier { ))); state_stack.push(TraversalState::BuildDisjunct(build_stack_len)); state_stack.push(TraversalState::Fail); - state_stack.push(TraversalState::Cut { - var_num: self.var_num, - is_global: false, - }); + state_stack.push(TraversalState::CutPrev(self.var_num)); + state_stack.push(TraversalState::ResetGlobalCutVarOverride(self.global_cut_var_num_override)); state_stack.push(TraversalState::Term(not_term)); + state_stack.push(TraversalState::OverrideGlobalCutVar(self.var_num)); state_stack.push(TraversalState::GetCutPoint { var_num: self.var_num, - prev_b: true, + prev_b: false, }); self.current_chunk_type = ChunkType::Mid; @@ -786,17 +805,23 @@ impl VariableClassifier { )); } Term::Literal(_, Literal::Atom(atom!("!")) | Literal::Char('!')) => { - if self.global_cut_var_num.is_none() { - self.global_cut_var_num = Some(self.var_num); - self.var_num += 1; - } + let (var_num, is_global) = + if let Some(var_num) = self.global_cut_var_num_override { + (var_num, false) + } else if let Some(var_num) = self.global_cut_var_num { + (var_num, true) + } else { + let var_num = self.var_num; - self.probe_in_situ_var(self.global_cut_var_num.unwrap()); + self.global_cut_var_num = Some(var_num); + self.var_num += 1; - state_stack.push(TraversalState::Cut { - var_num: self.global_cut_var_num.unwrap(), - is_global: true, - }); + (var_num, true) + }; + + self.probe_in_situ_var(var_num); + + state_stack.push(TraversalState::Cut { var_num, is_global }); } Term::Literal(_, Literal::Atom(name)) => { if update_chunk_data(self, name, 0) { diff --git a/src/machine/dispatch.rs b/src/machine/dispatch.rs index e04961d93..5722484f2 100644 --- a/src/machine/dispatch.rs +++ b/src/machine/dispatch.rs @@ -1188,6 +1188,21 @@ impl Machine { self.machine_st.p += 1; } + &Instruction::CutPrev(r) => { + let value = self.machine_st[r]; + self.machine_st.cut_prev_body(value); + + if self.machine_st.fail { + self.machine_st.backtrack(); + continue; + } + + if (self.machine_st.run_cleaners_fn)(self) { + continue; + } + + self.machine_st.p += 1; + } &Instruction::Allocate(num_cells) => { self.machine_st.allocate(num_cells); } diff --git a/src/machine/machine_state.rs b/src/machine/machine_state.rs index 55a09ef54..033b30908 100644 --- a/src/machine/machine_state.rs +++ b/src/machine/machine_state.rs @@ -941,6 +941,25 @@ impl MachineState { } ); } + + #[inline(always)] + pub(super) fn cut_prev_body(&mut self, value: HeapCellValue) { + let b = self.b; + + read_heap_cell!(value, + (HeapCellValueTag::CutPoint, b0) => { + let b0 = b0.get_num() as usize; + let b0 = self.stack.index_or_frame(b0).prelude.b; + + if b > b0 { + self.b = b0; + } + } + _ => { + self.fail = true; + } + ); + } } #[derive(Debug)]