From bfbcbdb3f57a45c3faab7d77fe9256aba566b59b Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Tue, 4 Feb 2025 20:39:29 +0000 Subject: [PATCH] Track nested meta-tracing states. Previously we -- well, this one is entirely my fault, so "I"! -- tracked per-thread meta-tracing state as a single `MTThreadState` that we updated as necessary. This doesn't work when we have nested execution, tracing and the like, as the bug in #1571 highlighted. This commit moves us to a stack of `MTThreadState`s. The basic idea is that the stack always has at least one element: `Interpreting`. As we go through other states, we push / pop as appropriate. The test below (from Edd and Lukas) fails on `master` but is fixed by this commit. The implementation is a bit more awkward than one might hope as naive implementations either: 1. Spread lots of knowledge about the stack around the code. That's a disaster waiting to happen. 2. Run into borrow checker problems. This commit gets around this in two phases: 1. We pass closures to `MTThread` which can peek at the current `MTThreadState` (which isn't `Copy`!). 2. Those closures return the "things" we need to update outside that context. This is a bit awkward, and perhaps there's a better API, but this one is at least safe. 
Co-authored-by: Edd Barrett Co-authored-by: Lukas Diekmann --- tests/c/nested_execution.c | 134 ++++++++ ykrt/src/compile/jitc_yk/codegen/x64/deopt.rs | 9 +- ykrt/src/mt.rs | 301 +++++++++++------- 3 files changed, 319 insertions(+), 125 deletions(-) create mode 100644 tests/c/nested_execution.c diff --git a/tests/c/nested_execution.c b/tests/c/nested_execution.c new file mode 100644 index 000000000..f68599065 --- /dev/null +++ b/tests/c/nested_execution.c @@ -0,0 +1,134 @@ +// Run-time: +// env-var: YKD_SERIALISE_COMPILATION=1 +// env-var: YKD_LOG_IR=jit-pre-opt +// env-var: YK_LOG=4 +// stderr: +// enter +// yk-jit-event: start-tracing +// 6 +// enter +// 5 +// 4 +// 3 +// 2 +// 1 +// return +// yk-jit-event: stop-tracing +// --- Begin jit-pre-opt --- +// ... +// guard true, ... +// ... +// guard false, ... +// ... +// guard false, ... +// ... +// guard false, ... +// ... +// guard true, ... +// ... +// guard true, ... +// ... +// --- End jit-pre-opt --- +// 5 +// enter +// yk-jit-event: start-tracing +// 4 +// yk-jit-event: stop-tracing +// --- Begin jit-pre-opt --- +// ... +// guard false, ... +// ... +// guard false, ... +// ... +// guard true, ... +// ... +// --- End jit-pre-opt --- +// 3 +// yk-jit-event: enter-jit-code +// 2 +// 1 +// yk-jit-event: deoptimise +// return +// yk-jit-event: enter-jit-code +// 4 +// enter +// yk-jit-event: enter-jit-code +// 3 +// 2 +// 1 +// yk-jit-event: deoptimise +// return +// yk-jit-event: deoptimise +// c +// 3 +// enter +// yk-jit-event: enter-jit-code +// 2 +// 1 +// yk-jit-event: deoptimise +// return +// yk-jit-event: enter-jit-code +// yk-jit-event: deoptimise +// b +// 2 +// enter +// yk-jit-event: enter-jit-code +// 1 +// yk-jit-event: deoptimise +// return +// yk-jit-event: enter-jit-code +// yk-jit-event: deoptimise +// a +// 1 +// enter +// return +// return + +// Check that recursive execution finds the right guards. 
+ +#include +#include +#include +#include +#include +#include + +void f(YkMT *mt, int who, YkLocation *loc1, YkLocation *loc2, int i) { + fprintf(stderr, "enter\n"); + while (i > 0) { + yk_mt_control_point(mt, loc1); + if (who) { + if (i == 1) { + fprintf(stderr, "a\n"); + } + if (i == 2) { + fprintf(stderr, "b\n"); + } + if (i == 3) { + fprintf(stderr, "c\n"); + } + } + fprintf(stderr, "%d\n", i); + i -= 1; + if (loc2 != NULL) { + f(mt, 0, loc2, NULL, i); + } + } + fprintf(stderr, "return\n"); +} + +int main(int argc, char **argv) { + YkMT *mt = yk_mt_new(NULL); + yk_mt_hot_threshold_set(mt, 0); + YkLocation loc1 = yk_location_new(); + YkLocation loc2 = yk_location_new(); + int i = 6; + NOOPT_VAL(loc1); + NOOPT_VAL(loc2); + NOOPT_VAL(i); + f(mt, 1, &loc1, &loc2, i); + yk_location_drop(loc1); + yk_location_drop(loc2); + yk_mt_shutdown(mt); + return (EXIT_SUCCESS); +} diff --git a/ykrt/src/compile/jitc_yk/codegen/x64/deopt.rs b/ykrt/src/compile/jitc_yk/codegen/x64/deopt.rs index 7cf565dc4..377cc8cde 100644 --- a/ykrt/src/compile/jitc_yk/codegen/x64/deopt.rs +++ b/ykrt/src/compile/jitc_yk/codegen/x64/deopt.rs @@ -101,8 +101,8 @@ pub(crate) extern "C" fn __yk_deopt( let info = &ctr.deoptinfo[&usize::from(gidx)]; let mt = Arc::clone(&ctr.mt); - ctr.mt - .stats + mt.deopt(); + mt.stats .timing_state(crate::log::stats::TimingState::Deopting); mt.log.log(Verbosity::JITEvent, "deoptimise"); @@ -352,9 +352,8 @@ pub(crate) extern "C" fn __yk_deopt( // The `clone` should really be `Arc::clone(&ctr)` but that doesn't play well with type // inference in this (unusual) case. - ctr.mt.guard_failure(ctr.clone(), gidx, frameaddr); - ctr.mt - .stats + mt.guard_failure(ctr.clone(), gidx, frameaddr); + mt.stats .timing_state(crate::log::stats::TimingState::OutsideYk); // Since we won't return from this function, drop `ctr` manually. diff --git a/ykrt/src/mt.rs b/ykrt/src/mt.rs index 390a5d1b7..20c36aa15 100644 --- a/ykrt/src/mt.rs +++ b/ykrt/src/mt.rs @@ -1,7 +1,7 @@ //! 
The main end-user interface to the meta-tracing system. use std::{ - assert_matches::debug_assert_matches, + assert_matches::{assert_matches, debug_assert_matches}, cell::RefCell, cmp, collections::VecDeque, @@ -409,13 +409,10 @@ impl MT { match self.transition_control_point(loc, frameaddr) { TransitionControlPoint::NoAction => (), TransitionControlPoint::AbortTracing => { - let thread_tracer = - MTThread::with( - |mtt| match mtt.tstate.replace(MTThreadState::Interpreting) { - MTThreadState::Tracing { thread_tracer, .. } => thread_tracer, - _ => unreachable!(), - }, - ); + let thread_tracer = MTThread::with(|mtt| match mtt.pop_tstate() { + MTThreadState::Tracing { thread_tracer, .. } => thread_tracer, + _ => unreachable!(), + }); thread_tracer.stop().ok(); self.log.log(Verbosity::JITEvent, "tracing-aborted"); } @@ -434,7 +431,7 @@ impl MT { } let trace_addr = ctr.entry(); MTThread::with(|mtt| { - mtt.set_running_trace(ctr); + mtt.push_tstate(MTThreadState::Executing { ctr }); }); self.stats.timing_state(TimingState::JitExecuting); @@ -450,13 +447,13 @@ impl MT { }; match Arc::clone(&tracer).start_recorder() { Ok(tt) => MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Tracing { + mtt.push_tstate(MTThreadState::Tracing { hl, thread_tracer: tt, promotions: Vec::new(), debug_strs: Vec::new(), frameaddr, - }; + }); }), Err(e) => { // FIXME: start_recorder needs a way of signalling temporary errors. @@ -486,18 +483,16 @@ impl MT { // Assuming no bugs elsewhere, the `unwrap`s cannot fail, because `StartTracing` // will have put a `Some` in the `Rc`. 
let (hl, thread_tracer, promotions, debug_strs) = - MTThread::with( - |mtt| match mtt.tstate.replace(MTThreadState::Interpreting) { - MTThreadState::Tracing { - hl, - thread_tracer, - promotions, - debug_strs, - frameaddr: _, - } => (hl, thread_tracer, promotions, debug_strs), - _ => unreachable!(), - }, - ); + MTThread::with(|mtt| match mtt.pop_tstate() { + MTThreadState::Tracing { + hl, + thread_tracer, + promotions, + debug_strs, + frameaddr: _, + } => (hl, thread_tracer, promotions, debug_strs), + _ => unreachable!(), + }); match thread_tracer.stop() { Ok(utrace) => { self.stats.timing_state(TimingState::None); @@ -523,21 +518,19 @@ impl MT { // Assuming no bugs elsewhere, the `unwrap`s cannot fail, because // `StartSideTracing` will have put a `Some` in the `Rc`. let (hl, thread_tracer, promotions, debug_strs) = - MTThread::with( - |mtt| match mtt.tstate.replace(MTThreadState::Interpreting) { - MTThreadState::Tracing { - hl, - thread_tracer, - promotions, - debug_strs, - frameaddr: tracing_frameaddr, - } => { - assert_eq!(frameaddr, tracing_frameaddr); - (hl, thread_tracer, promotions, debug_strs) - } - _ => unreachable!(), - }, - ); + MTThread::with(|mtt| match mtt.pop_tstate() { + MTThreadState::Tracing { + hl, + thread_tracer, + promotions, + debug_strs, + frameaddr: tracing_frameaddr, + } => { + assert_eq!(frameaddr, tracing_frameaddr); + (hl, thread_tracer, promotions, debug_strs) + } + _ => unreachable!(), + }); self.stats.timing_state(TimingState::TraceMapping); match thread_tracer.stop() { Ok(utrace) => { @@ -643,7 +636,7 @@ impl MT { } HotLocationKind::Tracing => { let hl = loc.hot_location_arc_clone().unwrap(); - match &*mtt.tstate.borrow() { + let (lk_kind, rtn) = mtt.peek_tstate(|tstate| match tstate { MTThreadState::Tracing { hl: thread_hl, frameaddr: tracing_frameaddr, @@ -652,7 +645,7 @@ impl MT { // This thread is tracing something... if !Arc::ptr_eq(thread_hl, &hl) { // ...but not this Location. 
- TransitionControlPoint::NoAction + (None, TransitionControlPoint::NoAction) } else { // ...and it's this location... match STACK_DIRECTION { @@ -663,8 +656,10 @@ impl MT { // within the same frame, or in a recursive // call. Either way, we do not want to unroll // the loop / recursion! - lk.kind = HotLocationKind::Compiling; - TransitionControlPoint::StopTracing + ( + Some(HotLocationKind::Compiling), + TransitionControlPoint::StopTracing, + ) } else { debug_assert!(frameaddr > *tracing_frameaddr); // We fell through to a caller frame. In other @@ -681,12 +676,12 @@ impl MT { // at a more propitious point in the future. self.stats.trace_recorded_err(); match lk.tracecompilation_error(self) { - TraceFailed::KeepTrying => { - lk.kind = HotLocationKind::Counting(0) - } + TraceFailed::KeepTrying => ( + Some(HotLocationKind::Counting(0)), + TransitionControlPoint::AbortTracing, + ), TraceFailed::DontTrace => todo!(), } - TransitionControlPoint::AbortTracing } } } @@ -705,22 +700,28 @@ impl MT { // Another thread was tracing this location but it's terminated. self.stats.trace_recorded_err(); match lk.tracecompilation_error(self) { - TraceFailed::KeepTrying => { - lk.kind = HotLocationKind::Tracing; - TransitionControlPoint::StartTracing(hl) - } + TraceFailed::KeepTrying => ( + Some(HotLocationKind::Tracing), + TransitionControlPoint::StartTracing(hl), + ), TraceFailed::DontTrace => { // FIXME: This is stupidly brutal. - lk.kind = HotLocationKind::DontTrace; - TransitionControlPoint::NoAction + ( + Some(HotLocationKind::DontTrace), + TransitionControlPoint::NoAction, + ) } } } else { // Another thread is tracing this location. 
- TransitionControlPoint::NoAction + (None, TransitionControlPoint::NoAction) } } + }); + if let Some(lk_kind) = lk_kind { + lk.kind = lk_kind; } + rtn } HotLocationKind::SideTracing { ref root_ctr, @@ -728,7 +729,7 @@ impl MT { ref parent_ctr, } => { let hl = loc.hot_location_arc_clone().unwrap(); - match &*mtt.tstate.borrow() { + let (lk_kind, rtn) = mtt.peek_tstate(move |tstate| match tstate { MTThreadState::Tracing { hl: thread_hl, frameaddr: tracing_frameaddr, @@ -737,7 +738,7 @@ impl MT { // This thread is tracing something... if !Arc::ptr_eq(thread_hl, &hl) { // ...but not this Location. - TransitionControlPoint::NoAction + (None, TransitionControlPoint::NoAction) } else { match STACK_DIRECTION { StackDirection::GrowsToHigherAddress => todo!(), @@ -750,15 +751,16 @@ impl MT { // to unroll the loop / recursion! let parent_ctr = Arc::clone(parent_ctr); let root_ctr_cl = Arc::clone(root_ctr); - lk.kind = HotLocationKind::Compiled( - Arc::clone(root_ctr), - ); - drop(lk); - TransitionControlPoint::StopSideTracing { - gidx, - parent_ctr, - root_ctr: root_ctr_cl, - } + ( + Some(HotLocationKind::Compiled( + Arc::clone(root_ctr), + )), + TransitionControlPoint::StopSideTracing { + gidx, + parent_ctr, + root_ctr: root_ctr_cl, + }, + ) } else { debug_assert!(frameaddr > *tracing_frameaddr); // We fell through to a caller frame. In other @@ -774,10 +776,12 @@ impl MT { // instead abort tracing, and hope we can start // at a more propitious point in the future. self.stats.trace_recorded_err(); - lk.kind = HotLocationKind::Compiled( - Arc::clone(root_ctr), - ); - TransitionControlPoint::AbortTracing + ( + Some(HotLocationKind::Compiled( + Arc::clone(root_ctr), + )), + TransitionControlPoint::AbortTracing, + ) } } } @@ -786,9 +790,13 @@ impl MT { _ => { // This thread isn't tracing anything. 
+ assert!(!is_tracing); - TransitionControlPoint::Execute(Arc::clone(root_ctr)) + (None, TransitionControlPoint::Execute(Arc::clone(root_ctr))) } + }); + if let Some(lk_kind) = lk_kind { + lk.kind = lk_kind; } + rtn } HotLocationKind::DontTrace => TransitionControlPoint::NoAction, } @@ -867,6 +875,15 @@ impl MT { } } + /// Inform this `MT` instance that `deopt` has occurred: this updates the stack of + /// [MTThreadState]s. + pub(crate) fn deopt(self: &Arc) { + MTThread::with(|mtt| { + let st = mtt.pop_tstate(); + assert_matches!(st, MTThreadState::Executing { .. }); + }); + } + /// Inform this meta-tracer that guard `gidx` has failed. /// // FIXME: Don't side trace the last guard of a side-trace as this guard always fails. @@ -888,13 +905,13 @@ impl MT { }; match Arc::clone(&tracer).start_recorder() { Ok(tt) => MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Tracing { + mtt.push_tstate(MTThreadState::Tracing { hl, thread_tracer: tt, promotions: Vec::new(), debug_strs: Vec::new(), frameaddr, - }; + }) }), Err(e) => todo!("{e:?}"), } @@ -983,8 +1000,9 @@ enum MTThreadState { /// Meta-tracer per-thread state. Note that this struct is neither `Send` nor `Sync`: it can only /// be accessed from within a single thread. pub struct MTThread { - /// Where in the "interpreting/tracing/executing" is this thread? - tstate: RefCell, + /// Where in the "interpreting/tracing/executing" is this thread? This `Vec` always has at + /// least 1 element in it. It should not be accessed directly: use the `*_tstate` methods. + tstate: RefCell>, // Raw pointers are neither send nor sync. _dont_send_or_sync_me: PhantomData<*mut ()>, } @@ -992,7 +1010,7 @@ pub struct MTThread { impl MTThread { fn new() -> Self { MTThread { - tstate: RefCell::new(MTThreadState::Interpreting), + tstate: RefCell::new(vec![MTThreadState::Interpreting]), _dont_send_or_sync_me: PhantomData, } } @@ -1011,20 +1029,51 @@ impl MTThread { /// Is this thread currently tracing something? 
pub(crate) fn is_tracing(&self) -> bool { - matches!(&*self.tstate.borrow(), &MTThreadState::Tracing { .. }) + matches!( + self.tstate.borrow().last().unwrap(), + &MTThreadState::Tracing { .. } + ) } - /// If a trace is currently running, return a reference to its `CompiledTrace`. + /// If a trace is currently running, return a reference to its [CompiledTrace]. pub(crate) fn running_trace(&self) -> Option> { - match &*self.tstate.borrow() { + match self.tstate.borrow().last().unwrap() { MTThreadState::Executing { ctr } => Some(Arc::clone(ctr)), _ => None, } } - /// Update the currently running trace. - pub(crate) fn set_running_trace(&self, ctr: Arc) { - *self.tstate.borrow_mut() = MTThreadState::Executing { ctr }; + /// Run the closure `f` and pass it an immutable reference to the last element on the stack of + /// [MTThreadState]s. + fn peek_tstate(&self, f: F) -> T + where + F: FnOnce(&MTThreadState) -> T, + { + f(self.tstate.borrow().last().unwrap()) + } + + /// Run the closure `f` and pass it a mutable reference to the last element on the + /// stack of [MTThreadState]s. + fn peek_mut_tstate(&self, f: F) -> T + where + F: FnOnce(&mut MTThreadState) -> T, + { + f(self.tstate.borrow_mut().last_mut().unwrap()) + } + + /// Pop the last element from the stack of [MTThreadState]s and return it. + /// + /// # Panics + /// + /// If this would remove the last [MTThreadState] from the stack. + fn pop_tstate(&self) -> MTThreadState { + debug_assert!(self.tstate.borrow_mut().len() > 1); + self.tstate.borrow_mut().pop().unwrap() + } + + /// Push `tstate` to the end of the stack of [MTThreadState]s. + fn push_tstate(&self, tstate: MTThreadState) { + self.tstate.borrow_mut().push(tstate); } /// Records `val` as a value to be promoted. Returns `true` if either: no trace is being @@ -1034,12 +1083,14 @@ impl MTThread { /// and further calls are probably pointless, though they will not cause the tracer to enter /// undefined behaviour territory. 
pub(crate) fn promote_i32(&self, val: i32) -> bool { - if let MTThreadState::Tracing { - ref mut promotions, .. - } = *self.tstate.borrow_mut() - { - promotions.extend_from_slice(&val.to_ne_bytes()); - } + self.peek_mut_tstate(|tstate| { + if let MTThreadState::Tracing { + ref mut promotions, .. + } = tstate + { + promotions.extend_from_slice(&val.to_ne_bytes()); + } + }); true } @@ -1050,12 +1101,14 @@ impl MTThread { /// and further calls are probably pointless, though they will not cause the tracer to enter /// undefined behaviour territory. pub(crate) fn promote_u32(&self, val: u32) -> bool { - if let MTThreadState::Tracing { - ref mut promotions, .. - } = *self.tstate.borrow_mut() - { - promotions.extend_from_slice(&val.to_ne_bytes()); - } + self.peek_mut_tstate(|tstate| { + if let MTThreadState::Tracing { + ref mut promotions, .. + } = tstate + { + promotions.extend_from_slice(&val.to_ne_bytes()); + } + }); true } @@ -1066,12 +1119,14 @@ impl MTThread { /// and further calls are probably pointless, though they will not cause the tracer to enter /// undefined behaviour territory. pub(crate) fn promote_i64(&self, val: i64) -> bool { - if let MTThreadState::Tracing { - ref mut promotions, .. - } = *self.tstate.borrow_mut() - { - promotions.extend_from_slice(&val.to_ne_bytes()); - } + self.peek_mut_tstate(|tstate| { + if let MTThreadState::Tracing { + ref mut promotions, .. + } = tstate + { + promotions.extend_from_slice(&val.to_ne_bytes()); + } + }); true } @@ -1082,23 +1137,27 @@ impl MTThread { /// and further calls are probably pointless, though they will not cause the tracer to enter /// undefined behaviour territory. pub(crate) fn promote_usize(&self, val: usize) -> bool { - if let MTThreadState::Tracing { - ref mut promotions, .. - } = *self.tstate.borrow_mut() - { - promotions.extend_from_slice(&val.to_ne_bytes()); - } + self.peek_mut_tstate(|tstate| { + if let MTThreadState::Tracing { + ref mut promotions, .. 
+ } = tstate + { + promotions.extend_from_slice(&val.to_ne_bytes()); + } + }); true } /// Record a debug string. pub fn insert_debug_str(&self, msg: String) -> bool { - if let MTThreadState::Tracing { - ref mut debug_strs, .. - } = *self.tstate.borrow_mut() - { - debug_strs.push(msg); - } + self.peek_mut_tstate(|tstate| { + if let MTThreadState::Tracing { + ref mut debug_strs, .. + } = tstate + { + debug_strs.push(msg); + } + }); true } } @@ -1169,13 +1228,13 @@ mod tests { panic!() }; MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Tracing { + mtt.push_tstate(MTThreadState::Tracing { hl, thread_tracer: Box::new(DummyTraceRecorder), promotions: Vec::new(), debug_strs: Vec::new(), frameaddr: ptr::null_mut(), - }; + }); }); } @@ -1185,7 +1244,8 @@ mod tests { panic!() }; MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Interpreting; + mtt.pop_tstate(); + mtt.push_tstate(MTThreadState::Interpreting); }); } @@ -1196,13 +1256,13 @@ mod tests { panic!() }; MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Tracing { + mtt.push_tstate(MTThreadState::Tracing { hl, thread_tracer: Box::new(DummyTraceRecorder), promotions: Vec::new(), debug_strs: Vec::new(), frameaddr: ptr::null_mut(), - }; + }); }); } @@ -1243,7 +1303,8 @@ mod tests { match mt.transition_control_point(&loc, ptr::null_mut()) { TransitionControlPoint::StopSideTracing { .. 
} => { MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Interpreting; + mtt.pop_tstate(); + mtt.push_tstate(MTThreadState::Interpreting); }); assert!(matches!( loc.hot_location().unwrap().lock().kind, @@ -1319,13 +1380,13 @@ mod tests { TransitionControlPoint::NoAction => (), TransitionControlPoint::StartTracing(hl) => { MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Tracing { + mtt.push_tstate(MTThreadState::Tracing { hl, thread_tracer: Box::new(DummyTraceRecorder), promotions: Vec::new(), debug_strs: Vec::new(), frameaddr: ptr::null_mut(), - }; + }); }); break; } @@ -1550,13 +1611,13 @@ mod tests { TransitionControlPoint::StartTracing(hl) => { num_starts.fetch_add(1, Ordering::Relaxed); MTThread::with(|mtt| { - *mtt.tstate.borrow_mut() = MTThreadState::Tracing { + mtt.push_tstate(MTThreadState::Tracing { hl, thread_tracer: Box::new(DummyTraceRecorder), promotions: Vec::new(), debug_strs: Vec::new(), frameaddr: ptr::null_mut(), - }; + }); }); assert!(matches!( loc.hot_location().unwrap().lock().kind,