From 8b0831f2a3544c03103c755f3a1c34fd8bab7c70 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 9 Oct 2025 17:43:52 +0200 Subject: [PATCH 01/21] Add benchmark for a fixpoint iteration with nested cycles (#1001) * Add benchmark for a fixpoint iteration with nested cycles * Fix clippy warning --- benches/dataflow.rs | 43 ++++++++++++++++++- .../src/unexpected_cycle_recovery.rs | 4 +- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index 24f5a16ee..db099c6b2 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -167,5 +167,46 @@ fn dataflow(criterion: &mut Criterion) { }); } -criterion_group!(benches, dataflow); +/// Emulates a data flow problem of the form: +/// ```py +/// self.x0 = self.x1 + self.x2 + self.x3 + self.x4 +/// self.x1 = self.x0 + self.x2 + self.x3 + self.x4 +/// self.x2 = self.x0 + self.x1 + self.x3 + self.x4 +/// self.x3 = self.x0 + self.x1 + self.x2 + self.x4 +/// self.x4 = 0 +/// ``` +fn nested(criterion: &mut Criterion) { + criterion.bench_function("converge_diverge_nested", |b| { + b.iter_batched_ref( + || { + let mut db = salsa::DatabaseImpl::new(); + + let def_x0 = Definition::new(&db, None, 0); + let def_x1 = Definition::new(&db, None, 0); + let def_x2 = Definition::new(&db, None, 0); + let def_x3 = Definition::new(&db, None, 0); + let def_x4 = Definition::new(&db, None, 0); + + let use_x0 = Use::new(&db, vec![def_x1, def_x2, def_x3, def_x4]); + let use_x1 = Use::new(&db, vec![def_x0, def_x2, def_x3, def_x4]); + let use_x2 = Use::new(&db, vec![def_x0, def_x1, def_x3, def_x4]); + let use_x3 = Use::new(&db, vec![def_x0, def_x1, def_x3, def_x4]); + + def_x0.set_base(&mut db).to(Some(use_x0)); + def_x1.set_base(&mut db).to(Some(use_x1)); + def_x2.set_base(&mut db).to(Some(use_x2)); + def_x3.set_base(&mut db).to(Some(use_x3)); + + (db, def_x0) + }, + |(db, def_x0)| { + // All symbols converge on 0. 
+ assert_eq!(infer_definition(db, *def_x0), Type::Values(Box::from([0]))); + }, + BatchSize::LargeInput, + ); + }); +} + +criterion_group!(benches, dataflow, nested); criterion_main!(benches); diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index a1cd1e73f..8d56d54f3 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -5,7 +5,7 @@ macro_rules! unexpected_cycle_recovery { ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); + std::mem::drop(($($other_inputs,)*)); panic!("cannot recover from cycle") }}; } @@ -14,7 +14,7 @@ macro_rules! unexpected_cycle_recovery { macro_rules! unexpected_cycle_initial { ($db:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); + std::mem::drop(($($other_inputs,)*)); panic!("no cycle initial value") }}; } From ef9f9329be6923acd050c8dddd172e3bc93e8051 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 16 Oct 2025 11:23:03 +0200 Subject: [PATCH 02/21] Run fixpoint per strongly connected component (#999) * Run nested cycles in a single fixpoint iteration Fix serde attribute * Remove inline from `validate_same_iteration` * Nits * Move locking into sync table * More trying * More in progress work * More progress * Fix most parallel tests * More bugfixes * Short circuit in some cases * Short circuit in drop * Delete some unused code * A working solution * Simplify more * Avoid repeated query lookups in `transfer_lock` * Use recursion for unblocking * Fix hang in `maybe_changed_after` * Move claiming of transferred memos into a separate function * More aggressive use of attributes * Make re-entrant a const parameter * Smaller clean-ups * Only collect cycle heads one level deep * More cleanups * More docs * More comments * More 
documentation, cleanups * More documentation, cleanups * Remove inline attribute * Fix failing tracked structs test * Fix panic * Fix persistence test * Add test for panic in nested cycle * Allow cycle initial values same-stack * Try inlining fetch * Remove some inline attributes * Add safety comment * Clippy * Panic if `provisional_retry` runs too many times * Better handling of panics in cycles * Don't use const-generic for `REENTRANT` * More nit improvements * Remove `IterationCount::panicked` * Prefer outer most cycles in `outer_cycle` * Code review feedback * Iterate only once in panic test when running with miri --- Cargo.toml | 2 +- src/active_query.rs | 4 +- src/cancelled.rs | 1 + src/cycle.rs | 308 ++++++++++-- src/function.rs | 65 ++- src/function/execute.rs | 464 +++++++++++++----- src/function/fetch.rs | 82 ++-- src/function/maybe_changed_after.rs | 185 ++++--- src/function/memo.rs | 262 +++++----- src/function/sync.rs | 358 +++++++++++++- src/ingredient.rs | 53 +- src/key.rs | 2 +- src/runtime.rs | 140 +++++- src/runtime/dependency_graph.rs | 406 ++++++++++++++- src/tracing.rs | 16 +- src/zalsa_local.rs | 102 +++- tests/backtrace.rs | 6 +- tests/cycle.rs | 8 +- tests/cycle_tracked.rs | 2 +- tests/parallel/cycle_a_t1_b_t2.rs | 2 +- tests/parallel/cycle_a_t1_b_t2_fallback.rs | 11 +- tests/parallel/cycle_nested_deep.rs | 1 + .../parallel/cycle_nested_deep_conditional.rs | 2 +- .../cycle_nested_deep_conditional_changed.rs | 12 +- tests/parallel/cycle_nested_deep_panic.rs | 142 ++++++ tests/parallel/cycle_nested_three_threads.rs | 15 +- tests/parallel/main.rs | 3 +- 27 files changed, 2093 insertions(+), 561 deletions(-) create mode 100644 tests/parallel/cycle_nested_deep_panic.rs diff --git a/Cargo.toml b/Cargo.toml index cc1cd0347..9c419e339 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ intrusive-collections = "0.9.7" parking_lot = "0.12" portable-atomic = "1" rustc-hash = "2" -smallvec = "1" +smallvec = { version = "1", features = 
["const_new"] } thin-vec = { version = "0.2.14" } tracing = { version = "0.1", default-features = false, features = ["std"] } diff --git a/src/active_query.rs b/src/active_query.rs index 0b2231052..d830fece1 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -498,7 +498,7 @@ impl fmt::Display for Backtrace { if full { write!(fmt, " -> ({changed_at:?}, {durability:#?}")?; if !cycle_heads.is_empty() || !iteration_count.is_initial() { - write!(fmt, ", iteration = {iteration_count:?}")?; + write!(fmt, ", iteration = {iteration_count}")?; } write!(fmt, ")")?; } @@ -517,7 +517,7 @@ impl fmt::Display for Backtrace { } write!( fmt, - "{:?} -> {:?}", + "{:?} -> iteration = {}", head.database_key_index, head.iteration_count )?; } diff --git a/src/cancelled.rs b/src/cancelled.rs index 2f2f315d9..3c31bae5a 100644 --- a/src/cancelled.rs +++ b/src/cancelled.rs @@ -20,6 +20,7 @@ pub enum Cancelled { } impl Cancelled { + #[cold] pub(crate) fn throw(self) -> ! { // We use resume and not panic here to avoid running the panic // hook (that is, to avoid collecting and printing backtrace). diff --git a/src/cycle.rs b/src/cycle.rs index 12cb1cdc9..c9a9b82c1 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -44,14 +44,18 @@ //! result in a stable, converged cycle. If it does not (that is, if the result of another //! iteration of the cycle is not the same as the fallback value), we'll panic. //! -//! In nested cycle cases, the inner cycle head will iterate until its own cycle is resolved, but -//! the "final" value it then returns will still be provisional on the outer cycle head. The outer -//! cycle head may then iterate, which may result in a new set of iterations on the inner cycle, -//! for each iteration of the outer cycle. - +//! In nested cycle cases, the inner cycles are iterated as part of the outer cycle iteration. This helps +//! to significantly reduce the number of iterations needed to reach a fixpoint. For nested cycles, +//! 
the inner cycles head will transfer their lock ownership to the outer cycle. This ensures +//! that, over time, the outer cycle will hold all necessary locks to complete the fixpoint iteration. +//! Without this, different threads would compete for the locks of inner cycle heads, leading to potential +//! hangs (but not deadlocks). + +use std::iter::FusedIterator; use thin_vec::{thin_vec, ThinVec}; use crate::key::DatabaseKeyIndex; +use crate::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use crate::sync::OnceLock; use crate::Revision; @@ -96,14 +100,47 @@ pub enum CycleRecoveryStrategy { /// would be the cycle head. It returns an "initial value" when the cycle is encountered (if /// fixpoint iteration is enabled for that query), and then is responsible for re-iterating the /// cycle until it converges. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Debug)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct CycleHead { pub(crate) database_key_index: DatabaseKeyIndex, - pub(crate) iteration_count: IterationCount, + pub(crate) iteration_count: AtomicIterationCount, + + /// Marks a cycle head as removed within its `CycleHeads` container. + /// + /// Cycle heads are marked as removed when the memo from the last iteration (a provisional memo) + /// is used as the initial value for the next iteration. It's necessary to remove all but its own + /// head from the `CycleHeads` container, because the query might now depend on fewer cycles + /// (in case of conditional dependencies). However, we can't actually remove the cycle head + /// within `fetch_cold_cycle` because we only have a readonly memo. That's what `removed` is used for. 
+ #[cfg_attr(feature = "persistence", serde(skip))] + removed: AtomicBool, +} + +impl CycleHead { + pub const fn new( + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) -> Self { + Self { + database_key_index, + iteration_count: AtomicIterationCount(AtomicU8::new(iteration_count.0)), + removed: AtomicBool::new(false), + } + } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)] +impl Clone for CycleHead { + fn clone(&self) -> Self { + Self { + database_key_index: self.database_key_index, + iteration_count: self.iteration_count.load().into(), + removed: self.removed.load(Ordering::Relaxed).into(), + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default, PartialOrd, Ord)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "persistence", serde(transparent))] pub struct IterationCount(u8); @@ -131,11 +168,69 @@ impl IterationCount { } } +impl std::fmt::Display for IterationCount { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug)] +pub(crate) struct AtomicIterationCount(AtomicU8); + +impl AtomicIterationCount { + pub(crate) fn load(&self) -> IterationCount { + IterationCount(self.0.load(Ordering::Relaxed)) + } + + pub(crate) fn load_mut(&mut self) -> IterationCount { + IterationCount(*self.0.get_mut()) + } + + pub(crate) fn store(&self, value: IterationCount) { + self.0.store(value.0, Ordering::Release); + } + + pub(crate) fn store_mut(&mut self, value: IterationCount) { + *self.0.get_mut() = value.0; + } +} + +impl std::fmt::Display for AtomicIterationCount { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.load().fmt(f) + } +} + +impl From for AtomicIterationCount { + fn from(iteration_count: IterationCount) -> Self { + AtomicIterationCount(iteration_count.0.into()) + } +} + +#[cfg(feature = "persistence")] +impl serde::Serialize for AtomicIterationCount { + fn 
serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.load().serialize(serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for AtomicIterationCount { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + IterationCount::deserialize(deserializer).map(Into::into) + } +} + /// Any provisional value generated by any query in a cycle will track the cycle head(s) (can be /// plural in case of nested cycles) representing the cycles it is part of, and the current /// iteration count for each cycle head. This struct tracks these cycle heads. #[derive(Clone, Debug, Default)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct CycleHeads(ThinVec); impl CycleHeads { @@ -143,15 +238,30 @@ impl CycleHeads { self.0.is_empty() } - pub(crate) fn initial(database_key_index: DatabaseKeyIndex) -> Self { + pub(crate) fn initial( + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) -> Self { Self(thin_vec![CycleHead { database_key_index, - iteration_count: IterationCount::initial(), + iteration_count: iteration_count.into(), + removed: false.into() }]) } - pub(crate) fn iter(&self) -> std::slice::Iter<'_, CycleHead> { - self.0.iter() + pub(crate) fn iter(&self) -> CycleHeadsIterator<'_> { + CycleHeadsIterator { + inner: self.0.iter(), + } + } + + /// Iterates over all cycle heads that aren't equal to `own`. 
+ pub(crate) fn iter_not_eq( + &self, + own: DatabaseKeyIndex, + ) -> impl DoubleEndedIterator { + self.iter() + .filter(move |head| head.database_key_index != own) } pub(crate) fn contains(&self, value: &DatabaseKeyIndex) -> bool { @@ -159,17 +269,25 @@ impl CycleHeads { .any(|head| head.database_key_index == *value) } - pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool { - let found = self - .0 - .iter() - .position(|&head| head.database_key_index == *value); - let Some(found) = found else { return false }; - self.0.swap_remove(found); - true + /// Removes all cycle heads except `except` by marking them as removed. + /// + /// Note that the heads aren't actually removed. They're only marked as removed and will be + /// skipped when iterating. This is because we might not have a mutable reference. + pub(crate) fn remove_all_except(&self, except: DatabaseKeyIndex) { + for head in self.0.iter() { + if head.database_key_index == except { + continue; + } + + head.removed.store(true, Ordering::Release); + } } - pub(crate) fn update_iteration_count( + /// Updates the iteration count for the head `cycle_head_index` to `new_iteration_count`. + /// + /// Unlike [`update_iteration_count`], this method takes a `&mut self` reference. It should + /// be preferred if possible, as it avoids atomic operations. + pub(crate) fn update_iteration_count_mut( &mut self, cycle_head_index: DatabaseKeyIndex, new_iteration_count: IterationCount, @@ -179,7 +297,24 @@ impl CycleHeads { .iter_mut() .find(|cycle_head| cycle_head.database_key_index == cycle_head_index) { - cycle_head.iteration_count = new_iteration_count; + cycle_head.iteration_count.store_mut(new_iteration_count); + } + } + + /// Updates the iteration count for the head `cycle_head_index` to `new_iteration_count`. + /// + /// Unlike [`update_iteration_count_mut`], this method takes a `&self` reference. 
+ pub(crate) fn update_iteration_count( + &self, + cycle_head_index: DatabaseKeyIndex, + new_iteration_count: IterationCount, + ) { + if let Some(cycle_head) = self + .0 + .iter() + .find(|cycle_head| cycle_head.database_key_index == cycle_head_index) + { + cycle_head.iteration_count.store(new_iteration_count); } } @@ -188,15 +323,42 @@ impl CycleHeads { self.0.reserve(other.0.len()); for head in other { - if let Some(existing) = self - .0 - .iter() - .find(|candidate| candidate.database_key_index == head.database_key_index) - { - assert_eq!(existing.iteration_count, head.iteration_count); + debug_assert!(!head.removed.load(Ordering::Relaxed)); + self.insert(head.database_key_index, head.iteration_count.load()); + } + } + + pub(crate) fn insert( + &mut self, + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) -> bool { + if let Some(existing) = self + .0 + .iter_mut() + .find(|candidate| candidate.database_key_index == database_key_index) + { + let removed = existing.removed.get_mut(); + + if *removed { + *removed = false; + + true } else { - self.0.push(*head); + let existing_count = existing.iteration_count.load_mut(); + + assert_eq!( + existing_count, iteration_count, + "Can't merge cycle heads {:?} with different iteration counts ({existing_count:?}, {iteration_count:?})", + existing.database_key_index + ); + + false } + } else { + self.0 + .push(CycleHead::new(database_key_index, iteration_count)); + true } } @@ -206,6 +368,37 @@ impl CycleHeads { } } +#[cfg(feature = "persistence")] +impl serde::Serialize for CycleHeads { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(None)?; + for e in self { + if e.removed.load(Ordering::Relaxed) { + continue; + } + + seq.serialize_element(e)?; + } + seq.end() + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for CycleHeads { + fn deserialize(deserializer: D) -> 
Result + where + D: serde::Deserializer<'de>, + { + let vec: ThinVec = serde::Deserialize::deserialize(deserializer)?; + Ok(CycleHeads(vec)) + } +} + impl IntoIterator for CycleHeads { type Item = CycleHead; type IntoIter = as IntoIterator>::IntoIter; @@ -215,9 +408,44 @@ impl IntoIterator for CycleHeads { } } +pub struct CycleHeadsIterator<'a> { + inner: std::slice::Iter<'a, CycleHead>, +} + +impl<'a> Iterator for CycleHeadsIterator<'a> { + type Item = &'a CycleHead; + + fn next(&mut self) -> Option { + loop { + let next = self.inner.next()?; + + if next.removed.load(Ordering::Relaxed) { + continue; + } + + return Some(next); + } + } +} + +impl FusedIterator for CycleHeadsIterator<'_> {} +impl DoubleEndedIterator for CycleHeadsIterator<'_> { + fn next_back(&mut self) -> Option { + loop { + let next = self.inner.next_back()?; + + if next.removed.load(Ordering::Relaxed) { + continue; + } + + return Some(next); + } + } +} + impl<'a> std::iter::IntoIterator for &'a CycleHeads { type Item = &'a CycleHead; - type IntoIter = std::slice::Iter<'a, CycleHead>; + type IntoIter = CycleHeadsIterator<'a>; fn into_iter(self) -> Self::IntoIter { self.iter() @@ -248,21 +476,3 @@ pub enum ProvisionalStatus { }, FallbackImmediate, } - -impl ProvisionalStatus { - pub(crate) const fn iteration(&self) -> Option { - match self { - ProvisionalStatus::Provisional { iteration, .. } => Some(*iteration), - ProvisionalStatus::Final { iteration, .. } => Some(*iteration), - ProvisionalStatus::FallbackImmediate => None, - } - } - - pub(crate) const fn verified_at(&self) -> Option { - match self { - ProvisionalStatus::Provisional { verified_at, .. } => Some(*verified_at), - ProvisionalStatus::Final { verified_at, .. 
} => Some(*verified_at), - ProvisionalStatus::FallbackImmediate => None, - } - } -} diff --git a/src/function.rs b/src/function.rs index 58f773895..259dff14b 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,5 +1,5 @@ pub(crate) use maybe_changed_after::{VerifyCycleHeads, VerifyResult}; -pub(crate) use sync::SyncGuard; +pub(crate) use sync::{ClaimGuard, ClaimResult, Reentrancy, SyncGuard, SyncOwner, SyncTable}; use std::any::Any; use std::fmt; @@ -8,11 +8,11 @@ use std::sync::atomic::Ordering; use std::sync::OnceLock; use crate::cycle::{ - empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, + empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, + ProvisionalStatus, }; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; -use crate::function::sync::{ClaimResult, SyncTable}; use crate::hash::{FxHashSet, FxIndexSet}; use crate::ingredient::{Ingredient, WaitForResult}; use crate::key::DatabaseKeyIndex; @@ -92,7 +92,18 @@ pub trait Configuration: Any { /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return /// value from the latest iteration of this cycle. `count` is the number of cycle iterations - /// we've already completed. + /// completed so far. + /// + /// # Iteration count semantics + /// + /// The `count` parameter isn't guaranteed to start from zero or to be contiguous: + /// + /// * **Initial value**: `count` may be non-zero on the first call for a given query if that + /// query becomes the outermost cycle head after a nested cycle complete a few iterations. In this case, + /// `count` continues from the nested cycle's iteration count rather than resetting to zero. + /// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle + /// and the value for this query remains unchanged for one iteration. 
But the outer cycle might + /// keep iterating because other heads keep changing. fn recover_from_cycle<'db>( db: &'db Self::DbView, value: &Self::Output<'db>, @@ -358,6 +369,41 @@ where }) } + fn set_cycle_iteration_count(&self, zalsa: &Zalsa, input: Id, iteration_count: IterationCount) { + let Some(memo) = + self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + else { + return; + }; + + memo.revisions + .set_iteration_count(Self::database_key_index(self, input), iteration_count); + } + + fn finalize_cycle_head(&self, zalsa: &Zalsa, input: Id) { + let Some(memo) = + self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + else { + return; + }; + + memo.revisions.verified_final.store(true, Ordering::Release); + } + + fn cycle_converged(&self, zalsa: &Zalsa, input: Id) -> bool { + let Some(memo) = + self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + else { + return true; + }; + + memo.revisions.cycle_converged() + } + + fn mark_as_transfer_target(&self, key_index: Id) -> Option { + self.sync_table.mark_as_transfer_target(key_index) + } + fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads { self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) .map(|memo| memo.cycle_heads()) @@ -372,9 +418,12 @@ where /// * [`WaitResult::Cycle`] Claiming the `key_index` results in a cycle because it's on the current's thread query stack or /// running on another thread that is blocked on this thread. 
fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { - match self.sync_table.try_claim(zalsa, key_index) { + match self + .sync_table + .try_claim(zalsa, key_index, Reentrancy::Deny) + { ClaimResult::Running(blocked_on) => WaitForResult::Running(blocked_on), - ClaimResult::Cycle => WaitForResult::Cycle, + ClaimResult::Cycle { inner } => WaitForResult::Cycle { inner }, ClaimResult::Claimed(_) => WaitForResult::Available, } } @@ -435,10 +484,6 @@ where unreachable!("function does not allocate pages") } - fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { - C::CYCLE_STRATEGY - } - #[cfg(feature = "accumulator")] unsafe fn accumulated<'db>( &'db self, diff --git a/src/function/execute.rs b/src/function/execute.rs index 9521a9dce..67f76e145 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,12 +1,18 @@ +use smallvec::SmallVec; + use crate::active_query::CompletedQuery; -use crate::cycle::{CycleRecoveryStrategy, IterationCount}; +use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; -use crate::function::{Configuration, IngredientImpl}; +use crate::function::sync::ReleaseMode; +use crate::function::{ClaimGuard, Configuration, IngredientImpl}; +use crate::ingredient::WaitForResult; use crate::plumbing::ZalsaLocal; use crate::sync::atomic::{AtomicBool, Ordering}; +use crate::sync::thread; use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; +use crate::{tracing, Cancelled}; use crate::{DatabaseKeyIndex, Event, EventKind, Id}; impl IngredientImpl @@ -26,12 +32,15 @@ where pub(super) fn execute<'db>( &'db self, db: &'db C::DbView, - zalsa: &'db Zalsa, + mut claim_guard: ClaimGuard<'db>, zalsa_local: &'db ZalsaLocal, - database_key_index: DatabaseKeyIndex, opt_old_memo: Option<&Memo<'db, C>>, ) -> &'db Memo<'db, C> { + let database_key_index = 
claim_guard.database_key_index(); + let zalsa = claim_guard.zalsa(); + let id = database_key_index.key_index(); + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); crate::tracing::info!("{:?}: executing query", database_key_index); @@ -40,7 +49,6 @@ where database_key: database_key_index, }) }); - let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); let (new_value, mut completed_query) = match C::CYCLE_STRATEGY { CycleRecoveryStrategy::Panic => Self::execute_query( @@ -94,9 +102,8 @@ where CycleRecoveryStrategy::Fixpoint => self.execute_maybe_iterate( db, opt_old_memo, - zalsa, + &mut claim_guard, zalsa_local, - database_key_index, memo_ingredient_index, ), }; @@ -117,6 +124,7 @@ where // outputs and update the tracked struct IDs for seeding the next revision. self.diff_outputs(zalsa, database_key_index, old_memo, &completed_query); } + self.insert_memo( zalsa, id, @@ -133,25 +141,53 @@ where &'db self, db: &'db C::DbView, opt_old_memo: Option<&Memo<'db, C>>, - zalsa: &'db Zalsa, + claim_guard: &mut ClaimGuard<'db>, zalsa_local: &'db ZalsaLocal, - database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, ) -> (C::Output<'db>, CompletedQuery) { + claim_guard.set_release_mode(ReleaseMode::Default); + + let database_key_index = claim_guard.database_key_index(); + let zalsa = claim_guard.zalsa(); + let id = database_key_index.key_index(); - let mut iteration_count = IterationCount::initial(); - let mut active_query = zalsa_local.push_query(database_key_index, iteration_count); // Our provisional value from the previous iteration, when doing fixpoint iteration. - // Initially it's set to None, because the initial provisional value is created lazily, - // only when a cycle is actually encountered. - let mut opt_last_provisional: Option<&Memo<'db, C>> = None; + // This is different from `opt_old_memo` which might be from a different revision. 
+ let mut last_provisional_memo: Option<&Memo<'db, C>> = None; + + // TODO: Can we seed those somehow? let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); - let _guard = ClearCycleHeadIfPanicking::new(self, zalsa, id, memo_ingredient_index); + let mut iteration_count = IterationCount::initial(); + + if let Some(old_memo) = opt_old_memo { + if old_memo.verified_at.load() == zalsa.current_revision() + && old_memo.cycle_heads().contains(&database_key_index) + { + let memo_iteration_count = old_memo.revisions.iteration(); + + // The `DependencyGraph` locking propagates panics when another thread is blocked on a panicking query. + // However, the locking doesn't handle the case where a thread fetches the result of a panicking + // cycle head query **after** all locks were released. That's what we do here. + // We could consider re-executing the entire cycle but: + // a) It's tricky to ensure that all queries participating in the cycle will re-execute + // (we can't rely on `iteration_count` being updated for nested cycles because the nested cycles may have completed successfully). + // b) It's guaranteed that this query will panic again anyway. + // That's why we simply propagate the panic here. It simplifies our lives and it also avoids duplicate panic messages. + if old_memo.value.is_none() { + tracing::warn!("Propagating panic for cycle head that panicked in an earlier execution in that revision"); + Cancelled::PropagatedPanic.throw(); + } + last_provisional_memo = Some(old_memo); + iteration_count = memo_iteration_count; + } + } - loop { - let previous_memo = opt_last_provisional.or(opt_old_memo); + let _poison_guard = + PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); + let mut active_query = zalsa_local.push_query(database_key_index, iteration_count); + let (new_value, completed_query) = loop { // Tracked struct ids that existed in the previous revision // but weren't recreated in the last iteration. 
It's important that we seed the next // query with these ids because the query might re-create them as part of the next iteration. @@ -160,118 +196,267 @@ where // if they aren't recreated when reaching the final iteration. active_query.seed_tracked_struct_ids(&last_stale_tracked_ids); - let (mut new_value, mut completed_query) = - Self::execute_query(db, zalsa, active_query, previous_memo); + let (mut new_value, mut completed_query) = Self::execute_query( + db, + zalsa, + active_query, + last_provisional_memo.or(opt_old_memo), + ); + + // If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty) + let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else { + claim_guard.set_release_mode(ReleaseMode::SelfOnly); + break (new_value, completed_query); + }; + + // Take the cycle heads to not-fight-rust's-borrow-checker. + let mut cycle_heads = std::mem::take(cycle_heads); + let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> = + SmallVec::new_const(); + let mut max_iteration_count = iteration_count; + let mut depends_on_self = false; + + // Ensure that we resolve the latest cycle heads from any provisional value this query depended on during execution. + // This isn't required in a single-threaded execution, but it's not guaranteed that `cycle_heads` contains all cycles + // in a multi-threaded execution: + // + // t1: a -> b + // t2: c -> b (blocks on t1) + // t1: a -> b -> c (cycle, returns fixpoint initial with c(0) in heads) + // t1: a -> b (completes b, b has c(0) in its cycle heads, releases `b`, which resumes `t2`, and `retry_provisional` blocks on `c` (t2)) + // t2: c -> a (cycle, returns fixpoint initial for a with a(0) in heads) + // t2: completes c, `provisional_retry` blocks on `a` (t2) + // t1: a (completes `b` with `c` in heads) + // + // Note how `a` only depends on `c` but not `a`. 
This is because `a` only saw the initial value of `c` and wasn't updated when `c` completed. + // That's why we need to resolve the cycle heads recursively so `cycle_heads` contains all cycle heads at the moment this query completed. + for head in &cycle_heads { + max_iteration_count = max_iteration_count.max(head.iteration_count.load()); + depends_on_self |= head.database_key_index == database_key_index; + + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + for nested_head in + ingredient.cycle_heads(zalsa, head.database_key_index.key_index()) + { + let nested_as_tuple = ( + nested_head.database_key_index, + nested_head.iteration_count.load(), + ); + + if !cycle_heads.contains(&nested_head.database_key_index) + && !missing_heads.contains(&nested_as_tuple) + { + missing_heads.push(nested_as_tuple); + } + } + } + + for (head_key, iteration_count) in missing_heads { + max_iteration_count = max_iteration_count.max(iteration_count); + depends_on_self |= head_key == database_key_index; + + cycle_heads.insert(head_key, iteration_count); + } + + let outer_cycle = outer_cycle(zalsa, zalsa_local, &cycle_heads, database_key_index); // Did the new result we got depend on our own provisional value, in a cycle? - if let Some(cycle_heads) = completed_query - .revisions - .cycle_heads_mut() - .filter(|cycle_heads| cycle_heads.contains(&database_key_index)) - { - let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { - // We have a last provisional value from our previous time around the loop. - last_provisional.value.as_ref() + // If not, return because this query is not a cycle head. + if !depends_on_self { + // For as long as this query participates in any cycle, don't release its lock, instead + // transfer it to the outermost cycle head (if any). 
This prevents any other thread + // from claiming this query (all cycle heads are potential entry points to the same cycle), + // which would result in them competing for the same locks (we want the locks to converge to a single cycle head). + if let Some(outer_cycle) = outer_cycle { + claim_guard.set_release_mode(ReleaseMode::TransferTo(outer_cycle)); } else { - // This is our first time around the loop; a provisional value must have been - // inserted into the memo table when the cycle was hit, so let's pull our - // initial provisional value from there. - let memo = self - .get_memo_from_table_for(zalsa, id, memo_ingredient_index) - .filter(|memo| memo.verified_at.load() == zalsa.current_revision()) - .unwrap_or_else(|| { - unreachable!( - "{database_key_index:#?} is a cycle head, \ + claim_guard.set_release_mode(ReleaseMode::SelfOnly); + } + + completed_query.revisions.set_cycle_heads(cycle_heads); + break (new_value, completed_query); + } + + // Get the last provisional value for this query so that we can compare it with the new value + // to test if the cycle converged. + let last_provisional_value = if let Some(last_provisional) = last_provisional_memo { + // We have a last provisional value from our previous time around the loop. + last_provisional.value.as_ref() + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. 
+ let memo = self + .get_memo_from_table_for(zalsa, id, memo_ingredient_index) + .unwrap_or_else(|| { + unreachable!( + "{database_key_index:#?} is a cycle head, \ but no provisional memo found" - ) - }); + ) + }); - debug_assert!(memo.may_be_provisional()); - memo.value.as_ref() - }; + debug_assert!(memo.may_be_provisional()); + memo.value.as_ref() + }; - let last_provisional_value = last_provisional_value.expect( - "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", - ); - crate::tracing::debug!( - "{database_key_index:?}: execute: \ - I am a cycle head, comparing last provisional value with new value" - ); - // If the new result is equal to the last provisional result, the cycle has - // converged and we are done. - if !C::values_equal(&new_value, last_provisional_value) { - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle( - db, - &new_value, - iteration_count.as_u32(), - C::id_to_input(zalsa, id), - ) { - crate::CycleRecoveryAction::Iterate => {} - crate::CycleRecoveryAction::Fallback(fallback_value) => { - crate::tracing::debug!( - "{database_key_index:?}: execute: user cycle_fn says to fall back" - ); - new_value = fallback_value; - } - } - // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` - // which is less than `u32::MAX`. - iteration_count = iteration_count.increment().unwrap_or_else(|| { - tracing::warn!( - "{database_key_index:?}: execute: too many cycle iterations" + let last_provisional_value = last_provisional_value.expect( + "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", + ); + tracing::debug!( + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value with new value" + ); + + let this_converged = C::values_equal(&new_value, last_provisional_value); + + // If this is the outermost cycle, use the maximum iteration count of all cycles. 
+ // This is important for when later iterations introduce new cycle heads (that then + // become the outermost cycle). We want to ensure that the iteration count keeps increasing + // for all queries or they won't be re-executed because `validate_same_iteration` would + // pass when we go from 1 -> 0 and then increment by 1 to 1). + iteration_count = if outer_cycle.is_none() { + max_iteration_count + } else { + // Otherwise keep the iteration count because outer cycles + // already have a cycle head with this exact iteration count (and we don't allow + // heads from different iterations). + iteration_count + }; + + if !this_converged { + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle( + db, + &new_value, + iteration_count.as_u32(), + C::id_to_input(zalsa, id), + ) { + crate::CycleRecoveryAction::Iterate => {} + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( + "{database_key_index:?}: execute: user cycle_fn says to fall back" ); - panic!("{database_key_index:?}: execute: too many cycle iterations") - }); - zalsa.event(&|| { - Event::new(EventKind::WillIterateCycle { - database_key: database_key_index, - iteration_count, - }) - }); - cycle_heads.update_iteration_count(database_key_index, iteration_count); - completed_query - .revisions - .update_iteration_count(iteration_count); - crate::tracing::info!("{database_key_index:?}: execute: iterate again...",); - opt_last_provisional = Some(self.insert_memo( - zalsa, - id, - Memo::new( - Some(new_value), - zalsa.current_revision(), - completed_query.revisions, - ), - memo_ingredient_index, - )); - last_stale_tracked_ids = completed_query.stale_tracked_structs; - - active_query = zalsa_local.push_query(database_key_index, iteration_count); - - continue; + new_value = fallback_value; + } } - crate::tracing::debug!( - "{database_key_index:?}: execute: fixpoint iteration has a final value" + } + + if let 
Some(outer_cycle) = outer_cycle { + tracing::info!( + "Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}" ); - cycle_heads.remove(&database_key_index); - - if cycle_heads.is_empty() { - // If there are no more cycle heads, we can mark this as verified. - completed_query - .revisions - .verified_final - .store(true, Ordering::Relaxed); + + completed_query.revisions.set_cycle_heads(cycle_heads); + // Store whether this cycle has converged, so that the outer cycle can check it. + completed_query + .revisions + .set_cycle_converged(this_converged); + + // Transfer ownership of this query to the outer cycle, so that it can claim it + // and other threads don't compete for the same lock. + claim_guard.set_release_mode(ReleaseMode::TransferTo(outer_cycle)); + + break (new_value, completed_query); + } + + // If this is the outermost cycle, test if all inner cycles have converged as well. + let converged = this_converged + && cycle_heads.iter_not_eq(database_key_index).all(|head| { + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + let converged = + ingredient.cycle_converged(zalsa, head.database_key_index.key_index()); + + if !converged { + tracing::debug!("inner cycle {database_key_index:?} has not converged"); + } + + converged + }); + + if converged { + tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value after {iteration_count:?} iterations" + ); + + // Set the nested cycles as verified. This is necessary because + // `validate_provisional` doesn't follow cycle heads recursively (and the memos now depend on all cycle heads). 
+ for head in cycle_heads.iter_not_eq(database_key_index) { + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + ingredient.finalize_cycle_head(zalsa, head.database_key_index.key_index()); } + + *completed_query.revisions.verified_final.get_mut() = true; + + break (new_value, completed_query); + } + + // The fixpoint iteration hasn't converged. Iterate again... + iteration_count = iteration_count.increment().unwrap_or_else(|| { + tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); + panic!("{database_key_index:?}: execute: too many cycle iterations") + }); + + zalsa.event(&|| { + Event::new(EventKind::WillIterateCycle { + database_key: database_key_index, + iteration_count, + }) + }); + + tracing::info!( + "{database_key_index:?}: execute: iterate again ({iteration_count:?})...", + ); + + // Update the iteration count of nested cycles. + for head in cycle_heads.iter_not_eq(database_key_index) { + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + ingredient.set_cycle_iteration_count( + zalsa, + head.database_key_index.key_index(), + iteration_count, + ); } - crate::tracing::debug!( - "{database_key_index:?}: execute: result.revisions = {revisions:#?}", - revisions = &completed_query.revisions + // Update the iteration count of this cycle head, but only after restoring + // the cycle heads array (or this becomes a no-op). 
+ completed_query.revisions.set_cycle_heads(cycle_heads); + completed_query + .revisions + .update_iteration_count_mut(database_key_index, iteration_count); + + let new_memo = self.insert_memo( + zalsa, + id, + Memo::new( + Some(new_value), + zalsa.current_revision(), + completed_query.revisions, + ), + memo_ingredient_index, ); - break (new_value, completed_query); - } + last_provisional_memo = Some(new_memo); + + last_stale_tracked_ids = completed_query.stale_tracked_structs; + active_query = zalsa_local.push_query(database_key_index, iteration_count); + + continue; + }; + + tracing::debug!( + "{database_key_index:?}: execute_maybe_iterate: result.revisions = {revisions:#?}", + revisions = &completed_query.revisions + ); + + (new_value, completed_query) } #[inline] @@ -325,14 +510,14 @@ where /// a new fix point initial value if that happens. /// /// We could insert a fixpoint initial value here, but it seems unnecessary. -struct ClearCycleHeadIfPanicking<'a, C: Configuration> { +struct PoisonProvisionalIfPanicking<'a, C: Configuration> { ingredient: &'a IngredientImpl, zalsa: &'a Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, } -impl<'a, C: Configuration> ClearCycleHeadIfPanicking<'a, C> { +impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> { fn new( ingredient: &'a IngredientImpl, zalsa: &'a Zalsa, @@ -348,9 +533,9 @@ impl<'a, C: Configuration> ClearCycleHeadIfPanicking<'a, C> { } } -impl Drop for ClearCycleHeadIfPanicking<'_, C> { +impl Drop for PoisonProvisionalIfPanicking<'_, C> { fn drop(&mut self) { - if std::thread::panicking() { + if thread::panicking() { let revisions = QueryRevisions::fixpoint_initial(self.ingredient.database_key_index(self.id)); @@ -360,3 +545,44 @@ impl Drop for ClearCycleHeadIfPanicking<'_, C> { } } } + +/// Returns the key of any potential outer cycle head or `None` if there is no outer cycle. 
+/// +/// That is, any query that's currently blocked on the result computed by this query (claiming it results in a cycle). +fn outer_cycle( + zalsa: &Zalsa, + zalsa_local: &ZalsaLocal, + cycle_heads: &CycleHeads, + current_key: DatabaseKeyIndex, +) -> Option { + // First, look for the outer most cycle head on the same thread. + // Using the outer most over the inner most should reduce the need + // for transitive transfers. + // SAFETY: We don't call into with_query_stack recursively + if let Some(same_thread) = unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { + stack + .iter() + .find(|active_query| { + cycle_heads.contains(&active_query.database_key_index) + && active_query.database_key_index != current_key + }) + .map(|active_query| active_query.database_key_index) + }) + } { + return Some(same_thread); + } + + // Check for any outer cycle head running on a different thread. + cycle_heads + .iter_not_eq(current_key) + .rfind(|head| { + let ingredient = zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + matches!( + ingredient.wait_for(zalsa, head.database_key_index.key_index()), + WaitForResult::Cycle { inner: false } + ) + }) + .map(|head| head.database_key_index) +} diff --git a/src/function/fetch.rs b/src/function/fetch.rs index a1b6658f6..ef42708a7 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -4,7 +4,7 @@ use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::maybe_changed_after::VerifyCycleHeads; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; -use crate::function::{Configuration, IngredientImpl}; +use crate::function::{Configuration, IngredientImpl, Reentrancy}; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryRevisions, ZalsaLocal}; use crate::{DatabaseKeyIndex, Id}; @@ -13,6 +13,7 @@ impl IngredientImpl where C: Configuration, { + #[inline] pub fn fetch<'db>( &'db self, db: &'db C::DbView, @@ -57,11 +58,19 @@ 
where id: Id, ) -> &'db Memo<'db, C> { let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); + let mut retry_count = 0; loop { if let Some(memo) = self .fetch_hot(zalsa, id, memo_ingredient_index) .or_else(|| { - self.fetch_cold_with_retry(zalsa, zalsa_local, db, id, memo_ingredient_index) + self.fetch_cold_with_retry( + zalsa, + zalsa_local, + db, + id, + memo_ingredient_index, + &mut retry_count, + ) }) { return memo; @@ -95,7 +104,6 @@ where } } - #[inline(never)] fn fetch_cold_with_retry<'db>( &'db self, zalsa: &'db Zalsa, @@ -103,6 +111,7 @@ where db: &'db C::DbView, id: Id, memo_ingredient_index: MemoIngredientIndex, + retry_count: &mut u32, ) -> Option<&'db Memo<'db, C>> { let memo = self.fetch_cold(zalsa, zalsa_local, db, id, memo_ingredient_index)?; @@ -114,7 +123,7 @@ where // That is only correct for fixpoint cycles, though: `FallbackImmediate` cycles // never have provisional entries. if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate - || !memo.provisional_retry(zalsa, zalsa_local, self.database_key_index(id)) + || !memo.provisional_retry(zalsa, zalsa_local, self.database_key_index(id), retry_count) { Some(memo) } else { @@ -132,21 +141,21 @@ where ) -> Option<&'db Memo<'db, C>> { let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. 
- let claim_guard = match self.sync_table.try_claim(zalsa, id) { + let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) { ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { blocked_on.block_on(zalsa); - let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); - if let Some(memo) = memo { - // This isn't strictly necessary, but if this is a provisional memo for an inner cycle, - // await all outer cycle heads to give the thread driving it a chance to complete - // (we don't want multiple threads competing for the queries participating in the same cycle). - if memo.value.is_some() && memo.may_be_provisional() { - memo.block_on_heads(zalsa, zalsa_local); + if let Some(memo) = memo { + if memo.value.is_some() { + memo.block_on_heads(zalsa, zalsa_local); + } } } + return None; } ClaimResult::Cycle { .. } => { @@ -200,39 +209,10 @@ where // still valid for the current revision. return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; } - - // If this is a provisional memo from the same revision, await all its cycle heads because - // we need to ensure that only one thread is iterating on a cycle at a given time. - // For example, if we have a nested cycle like so: - // ``` - // a -> b -> c -> b - // -> a - // - // d -> b - // ``` - // thread 1 calls `a` and `a` completes the inner cycle `b -> c` but hasn't finished the outer cycle `a` yet. - // thread 2 now calls `b`. We don't want that thread 2 iterates `b` while thread 1 is iterating `a` at the same time - // because it can result in thread b overriding provisional memos that thread a has accessed already and still relies upon. - // - // By waiting, we ensure that thread 1 completes a (based on a provisional value for `b`) and `b` - // becomes the new outer cycle, which thread 2 drives to completion. 
- if old_memo.may_be_provisional() - && old_memo.verified_at.load() == zalsa.current_revision() - { - // Try to claim all cycle heads of the provisional memo. If we can't because - // some head is running on another thread, drop our claim guard to give that thread - // a chance to take ownership of this query and complete it as part of its fixpoint iteration. - // We will then block on the cycle head and retry once all cycle heads completed. - if !old_memo.try_claim_heads(zalsa, zalsa_local) { - drop(claim_guard); - old_memo.block_on_heads(zalsa, zalsa_local); - return None; - } - } } } - let memo = self.execute(db, zalsa, zalsa_local, database_key_index, opt_old_memo); + let memo = self.execute(db, claim_guard, zalsa_local, opt_old_memo); Some(memo) } @@ -257,6 +237,19 @@ where let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); if can_shallow_update.yes() { self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); + + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint { + memo.revisions + .cycle_heads() + .remove_all_except(database_key_index); + } + + crate::tracing::debug!( + "hit cycle at {database_key_index:#?}, \ + returning last provisional value: {:#?}", + memo.revisions + ); + // SAFETY: memo is present in memo_map. return unsafe { self.extend_memo_lifetime(memo) }; } @@ -299,7 +292,10 @@ where let mut completed_query = active_query.pop(); completed_query .revisions - .set_cycle_heads(CycleHeads::initial(database_key_index)); + .set_cycle_heads(CycleHeads::initial( + database_key_index, + IterationCount::initial(), + )); // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. 
*completed_query.revisions.verified_final.get_mut() = false; self.insert_memo( diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 4f69655cd..698285055 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,10 +2,10 @@ use rustc_hash::FxHashMap; #[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::cycle::{CycleRecoveryStrategy, ProvisionalStatus}; -use crate::function::memo::Memo; +use crate::cycle::{CycleHeads, CycleRecoveryStrategy, ProvisionalStatus}; +use crate::function::memo::{Memo, TryClaimCycleHeadsIter, TryClaimHeadsResult}; use crate::function::sync::ClaimResult; -use crate::function::{Configuration, IngredientImpl}; +use crate::function::{Configuration, IngredientImpl, Reentrancy}; use crate::key::DatabaseKeyIndex; use crate::sync::atomic::Ordering; @@ -141,7 +141,10 @@ where ) -> Option { let database_key_index = self.database_key_index(key_index); - let _claim_guard = match self.sync_table.try_claim(zalsa, key_index) { + let claim_guard = match self + .sync_table + .try_claim(zalsa, key_index, Reentrancy::Deny) + { ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { blocked_on.block_on(zalsa); @@ -175,10 +178,8 @@ where // If `validate_maybe_provisional` returns `true`, but only because all cycle heads are from the same iteration, // carry over the cycle heads so that the caller verifies them. - if old_memo.may_be_provisional() { - for head in old_memo.cycle_heads() { - cycle_heads.insert_head(head.database_key_index); - } + for head in old_memo.cycle_heads() { + cycle_heads.insert_head(head.database_key_index); } return Some(if old_memo.revisions.changed_at > revision { @@ -227,7 +228,7 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. 
`deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let memo = self.execute(db, zalsa, zalsa_local, database_key_index, Some(old_memo)); + let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo)); let changed_at = memo.revisions.changed_at; // Always assume that a provisional value has changed. @@ -323,12 +324,11 @@ where } let last_changed = zalsa.last_changed_revision(memo.revisions.durability); - crate::tracing::debug!( - "{database_key_index:?}: check_durability(memo = {memo:#?}, last_changed={:?} <= verified_at={:?}) = {:?}", + crate::tracing::trace!( + "{database_key_index:?}: check_durability({database_key_index:#?}, last_changed={:?} <= verified_at={:?}) = {:?}", last_changed, verified_at, last_changed <= verified_at, - memo = memo.tracing_debug() ); if last_changed <= verified_at { // No input of the suitable durability has changed since last verified. @@ -365,28 +365,48 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, ) -> bool { - !memo.may_be_provisional() - || self.validate_provisional(zalsa, database_key_index, memo) - || self.validate_same_iteration(zalsa, zalsa_local, database_key_index, memo) + if !memo.may_be_provisional() { + return true; + } + + let cycle_heads = memo.cycle_heads(); + + if cycle_heads.is_empty() { + return true; + } + + crate::tracing::trace!( + "{database_key_index:?}: validate_may_be_provisional(memo = {memo:#?})", + memo = memo.tracing_debug() + ); + + let verified_at = memo.verified_at.load(); + + self.validate_provisional(zalsa, database_key_index, memo, verified_at, cycle_heads) + || self.validate_same_iteration( + zalsa, + zalsa_local, + database_key_index, + verified_at, + cycle_heads, + ) } /// Check if this memo's cycle heads have all been finalized. If so, mark it verified final and /// return true, if not return false. 
- #[inline] fn validate_provisional( &self, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, + memo_verified_at: Revision, + cycle_heads: &CycleHeads, ) -> bool { crate::tracing::trace!( - "{database_key_index:?}: validate_provisional(memo = {memo:#?})", - memo = memo.tracing_debug() + "{database_key_index:?}: validate_provisional({database_key_index:?})", ); - let memo_verified_at = memo.verified_at.load(); - - for cycle_head in memo.revisions.cycle_heads() { + for cycle_head in cycle_heads { // Test if our cycle heads (with the same revision) are now finalized. let Some(kind) = zalsa .lookup_ingredient(cycle_head.database_key_index.ingredient_index()) @@ -413,7 +433,7 @@ where // // If we don't account for the iteration, then `a` (from iteration 0) will be finalized // because its cycle head `b` is now finalized, but `b` never pulled `a` in the last iteration. - if iteration != cycle_head.iteration_count { + if iteration != cycle_head.iteration_count.load() { return false; } @@ -449,92 +469,61 @@ where &self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal, - database_key_index: DatabaseKeyIndex, - memo: &Memo<'_, C>, + memo_database_key_index: DatabaseKeyIndex, + memo_verified_at: Revision, + cycle_heads: &CycleHeads, ) -> bool { - crate::tracing::trace!( - "{database_key_index:?}: validate_same_iteration(memo = {memo:#?})", - memo = memo.tracing_debug() - ); - - let cycle_heads = memo.revisions.cycle_heads(); - if cycle_heads.is_empty() { - return true; - } - - let verified_at = memo.verified_at.load(); + crate::tracing::trace!("validate_same_iteration({memo_database_key_index:?})",); // This is an optimization to avoid unnecessary re-execution within the same revision. // Don't apply it when verifying memos from past revisions. We want them to re-execute // to verify their cycle heads and all participating queries. 
- if verified_at != zalsa.current_revision() { + if memo_verified_at != zalsa.current_revision() { return false; } - // SAFETY: We do not access the query stack reentrantly. - unsafe { - zalsa_local.with_query_stack_unchecked(|stack| { - cycle_heads.iter().all(|cycle_head| { + // Always return `false` for cycle initial values "unless" they are running in the same thread. + if cycle_heads + .iter() + .all(|head| head.database_key_index == memo_database_key_index) + { + // SAFETY: We do not access the query stack reentrantly. + let on_stack = unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { stack .iter() .rev() - .find(|query| query.database_key_index == cycle_head.database_key_index) - .map(|query| query.iteration_count()) - .or_else(|| { - // If the cycle head isn't on our stack because: - // - // * another thread holds the lock on the cycle head (but it waits for the current query to complete) - // * we're in `maybe_changed_after` because `maybe_changed_after` doesn't modify the cycle stack - // - // check if the latest memo has the same iteration count. - - // However, we've to be careful to skip over fixpoint initial values: - // If the head is the memo we're trying to validate, always return `None` - // to force a re-execution of the query. This is necessary because the query - // has obviously not completed its iteration yet. - // - // This should be rare but the `cycle_panic` test fails on some platforms (mainly GitHub actions) - // without this check. What happens there is that: - // - // * query a blocks on query b - // * query b tries to claim a, fails to do so and inserts the fixpoint initial value - // * query b completes and has `a` as head. It returns its query result Salsa blocks query b from - // exiting inside `block_on` (or the thread would complete before the cycle iteration is complete) - // * query a resumes but panics because of the fixpoint iteration function - // * query b resumes. 
It rexecutes its own query which then tries to fetch a (which depends on itself because it's a fixpoint initial value). - // Without this check, `validate_same_iteration` would return `true` because the latest memo for `a` is the fixpoint initial value. - // But it should return `false` so that query b's thread re-executes `a` (which then also causes the panic). - // - // That's why we always return `None` if the cycle head is the same as the current database key index. - if cycle_head.database_key_index == database_key_index { - return None; - } + .any(|query| query.database_key_index == memo_database_key_index) + }) + }; - let ingredient = zalsa.lookup_ingredient( - cycle_head.database_key_index.ingredient_index(), - ); - let wait_result = ingredient - .wait_for(zalsa, cycle_head.database_key_index.key_index()); + return on_stack; + } - if !wait_result.is_cycle() { - return None; - } + let cycle_heads_iter = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, cycle_heads); - let provisional_status = ingredient.provisional_status( - zalsa, - cycle_head.database_key_index.key_index(), - )?; + for cycle_head in cycle_heads_iter { + match cycle_head { + TryClaimHeadsResult::Cycle { + head_iteration_count, + memo_iteration_count: current_iteration_count, + verified_at: head_verified_at, + } => { + if head_verified_at != memo_verified_at { + return false; + } - if provisional_status.verified_at() == Some(verified_at) { - provisional_status.iteration() - } else { - None - } - }) - == Some(cycle_head.iteration_count) - }) - }) + if head_iteration_count != current_iteration_count { + return false; + } + } + _ => { + return false; + } + } } + + true } /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up-to-date in the @@ -553,6 +542,12 @@ where cycle_heads: &mut VerifyCycleHeads, can_shallow_update: ShallowUpdate, ) -> VerifyResult { + // If the value is from the same revision but is still provisional, consider it changed + // because we're now in a new 
iteration. + if can_shallow_update == ShallowUpdate::Verified && old_memo.may_be_provisional() { + return VerifyResult::changed(); + } + crate::tracing::debug!( "{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})", old_memo = old_memo.tracing_debug() @@ -562,12 +557,6 @@ where match old_memo.revisions.origin.as_ref() { QueryOriginRef::Derived(edges) => { - // If the value is from the same revision but is still provisional, consider it changed - // because we're now in a new iteration. - if can_shallow_update == ShallowUpdate::Verified && old_memo.may_be_provisional() { - return VerifyResult::changed(); - } - #[cfg(feature = "accumulator")] let mut inputs = InputAccumulatedValues::Empty; let mut child_cycle_heads = Vec::new(); diff --git a/src/function/memo.rs b/src/function/memo.rs index 793f4832a..302ca73c3 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -3,10 +3,11 @@ use std::fmt::{Debug, Formatter}; use std::mem::transmute; use std::ptr::NonNull; -use crate::cycle::{empty_cycle_heads, CycleHead, CycleHeads, IterationCount, ProvisionalStatus}; +use crate::cycle::{ + empty_cycle_heads, CycleHeads, CycleHeadsIterator, IterationCount, ProvisionalStatus, +}; use crate::function::{Configuration, IngredientImpl}; -use crate::hash::FxHashSet; -use crate::ingredient::{Ingredient, WaitForResult}; +use crate::ingredient::WaitForResult; use crate::key::DatabaseKeyIndex; use crate::revision::AtomicRevision; use crate::runtime::Running; @@ -143,21 +144,23 @@ impl<'db, C: Configuration> Memo<'db, C> { zalsa: &Zalsa, zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, + retry_count: &mut u32, ) -> bool { - if self.revisions.cycle_heads().is_empty() { - return false; - } - - if !self.may_be_provisional() { - return false; - }; - if self.block_on_heads(zalsa, zalsa_local) { // If we get here, we are a provisional value of // the cycle head (either initial value, or from a later iteration) and should be // returned to caller to allow 
fixpoint iteration to proceed. false } else { + assert!( + *retry_count <= 20000, + "Provisional memo retry limit exceeded for {database_key_index:?}; \ + this usually indicates a bug in salsa's cycle caching/locking. \ + (retried {retry_count} times)", + ); + + *retry_count += 1; + // all our cycle heads are complete; re-fetch // and we should get a non-provisional memo. crate::tracing::debug!( @@ -176,33 +179,50 @@ impl<'db, C: Configuration> Memo<'db, C> { // IMPORTANT: If you make changes to this function, make sure to run `cycle_nested_deep` with // shuttle with at least 10k iterations. - // The most common case is that the entire cycle is running in the same thread. - // If that's the case, short circuit and return `true` immediately. - if self.all_cycles_on_stack(zalsa_local) { + let cycle_heads = self.cycle_heads(); + if cycle_heads.is_empty() { return true; } - // Otherwise, await all cycle heads, recursively. - return block_on_heads_cold(zalsa, self.cycle_heads()); + return block_on_heads_cold(zalsa, zalsa_local, cycle_heads); #[inline(never)] - fn block_on_heads_cold(zalsa: &Zalsa, heads: &CycleHeads) -> bool { + fn block_on_heads_cold( + zalsa: &Zalsa, + zalsa_local: &ZalsaLocal, + heads: &CycleHeads, + ) -> bool { let _entered = crate::tracing::debug_span!("block_on_heads").entered(); - let mut cycle_heads = TryClaimCycleHeadsIter::new(zalsa, heads); + let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, heads); let mut all_cycles = true; - while let Some(claim_result) = cycle_heads.next() { + for claim_result in cycle_heads { match claim_result { - TryClaimHeadsResult::Cycle => {} - TryClaimHeadsResult::Finalized => { - all_cycles = false; + TryClaimHeadsResult::Cycle { + memo_iteration_count: current_iteration_count, + head_iteration_count, + .. + } => { + // We need to refetch if the head now has a new iteration count. 
+ // This is to avoid a race between thread A and B: + // * thread A is in `blocks_on` (`retry_provisional`) for the memo `c`. It owns the lock for `e` + // * thread B owns `d` and calls `c`. `c` didn't depend on `e` in the first iteration. + // Thread B completes the first iteration (which bumps the iteration count on `c`). + // `c` now depends on E in the second iteration, introducing a new cycle head. + // Thread B transfers ownership of `c` to thread A (which awakes A). + // * Thread A now continues, there are no other cycle heads, so all queries result in a cycle. + // However, `d` has now a new iteration count, so it's important that we refetch `c`. + + if current_iteration_count != head_iteration_count { + all_cycles = false; + } } TryClaimHeadsResult::Available => { all_cycles = false; } TryClaimHeadsResult::Running(running) => { all_cycles = false; - running.block_on(&mut cycle_heads); + running.block_on(zalsa); } } } @@ -211,51 +231,6 @@ impl<'db, C: Configuration> Memo<'db, C> { } } - /// Tries to claim all cycle heads to see if they're finalized or available. - /// - /// Unlike `block_on_heads`, this code does not block on any cycle head. Instead it returns `false` if - /// claiming all cycle heads failed because one of them is running on another thread. 
- pub(super) fn try_claim_heads(&self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal) -> bool { - let _entered = crate::tracing::debug_span!("try_claim_heads").entered(); - if self.all_cycles_on_stack(zalsa_local) { - return true; - } - - let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, self.revisions.cycle_heads()); - - for claim_result in cycle_heads { - match claim_result { - TryClaimHeadsResult::Cycle - | TryClaimHeadsResult::Finalized - | TryClaimHeadsResult::Available => {} - TryClaimHeadsResult::Running(_) => { - return false; - } - } - } - - true - } - - fn all_cycles_on_stack(&self, zalsa_local: &ZalsaLocal) -> bool { - let cycle_heads = self.revisions.cycle_heads(); - if cycle_heads.is_empty() { - return true; - } - - // SAFETY: We do not access the query stack reentrantly. - unsafe { - zalsa_local.with_query_stack_unchecked(|stack| { - cycle_heads.iter().all(|cycle_head| { - stack - .iter() - .rev() - .any(|query| query.database_key_index == cycle_head.database_key_index) - }) - }) - } - } - /// Cycle heads that should be propagated to dependent queries. #[inline(always)] pub(super) fn cycle_heads(&self) -> &CycleHeads { @@ -473,118 +448,111 @@ mod persistence { } pub(super) enum TryClaimHeadsResult<'me> { - /// Claiming every cycle head results in a cycle head. - Cycle, - - /// The cycle head has been finalized. - Finalized, + /// Claiming the cycle head results in a cycle. + Cycle { + head_iteration_count: IterationCount, + memo_iteration_count: IterationCount, + verified_at: Revision, + }, /// The cycle head is not finalized, but it can be claimed. Available, /// The cycle head is currently executed on another thread. 
- Running(RunningCycleHead<'me>), -} - -pub(super) struct RunningCycleHead<'me> { - inner: Running<'me>, - ingredient: &'me dyn Ingredient, -} - -impl<'a> RunningCycleHead<'a> { - fn block_on(self, cycle_heads: &mut TryClaimCycleHeadsIter<'a>) { - let key_index = self.inner.database_key().key_index(); - self.inner.block_on(cycle_heads.zalsa); - - cycle_heads.queue_ingredient_heads(self.ingredient, key_index); - } + Running(Running<'me>), } /// Iterator to try claiming the transitive cycle heads of a memo. -struct TryClaimCycleHeadsIter<'a> { +pub(super) struct TryClaimCycleHeadsIter<'a> { zalsa: &'a Zalsa, - queue: Vec, - queued: FxHashSet, + zalsa_local: &'a ZalsaLocal, + cycle_heads: CycleHeadsIterator<'a>, } impl<'a> TryClaimCycleHeadsIter<'a> { - fn new(zalsa: &'a Zalsa, heads: &CycleHeads) -> Self { - let queue: Vec<_> = heads.iter().copied().collect(); - let queued: FxHashSet<_> = queue.iter().copied().collect(); - + pub(super) fn new( + zalsa: &'a Zalsa, + zalsa_local: &'a ZalsaLocal, + cycle_heads: &'a CycleHeads, + ) -> Self { Self { zalsa, - queue, - queued, + zalsa_local, + cycle_heads: cycle_heads.iter(), } } - - fn queue_ingredient_heads(&mut self, ingredient: &dyn Ingredient, key: Id) { - // Recursively wait for all cycle heads that this head depends on. It's important - // that we fetch those from the updated memo because the cycle heads can change - // between iterations and new cycle heads can be added if a query depeonds on - // some cycle heads depending on a specific condition being met - // (`a` calls `b` and `c` in iteration 0 but `c` and `d` in iteration 1 or later). - // IMPORTANT: It's critical that we get the cycle head from the latest memo - // here, in case the memo has become part of another cycle (we need to block on that too!). 
- self.queue.extend( - ingredient - .cycle_heads(self.zalsa, key) - .iter() - .copied() - .filter(|head| self.queued.insert(*head)), - ) - } } impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { type Item = TryClaimHeadsResult<'me>; fn next(&mut self) -> Option { - let head = self.queue.pop()?; + let head = self.cycle_heads.next()?; let head_database_key = head.database_key_index; + let head_iteration_count = head.iteration_count.load(); + + // The most common case is that the head is already in the query stack. So let's check that first. + // SAFETY: We do not access the query stack reentrantly. + if let Some(current_iteration_count) = unsafe { + self.zalsa_local.with_query_stack_unchecked(|stack| { + stack + .iter() + .rev() + .find(|query| query.database_key_index == head_database_key) + .map(|query| query.iteration_count()) + }) + } { + crate::tracing::trace!( + "Waiting for {head_database_key:?} results in a cycle (because it is already in the query stack)" + ); + return Some(TryClaimHeadsResult::Cycle { + head_iteration_count, + memo_iteration_count: current_iteration_count, + verified_at: self.zalsa.current_revision(), + }); + } + let head_key_index = head_database_key.key_index(); let ingredient = self .zalsa .lookup_ingredient(head_database_key.ingredient_index()); - let cycle_head_kind = ingredient - .provisional_status(self.zalsa, head_key_index) - .unwrap_or(ProvisionalStatus::Provisional { - iteration: IterationCount::initial(), - verified_at: Revision::start(), - }); + match ingredient.wait_for(self.zalsa, head_key_index) { + WaitForResult::Cycle { .. } => { + // We hit a cycle blocking on the cycle head; this means this query actively + // participates in the cycle and some other query is blocked on this thread. 
+ crate::tracing::trace!("Waiting for {head_database_key:?} results in a cycle"); + + let provisional_status = ingredient + .provisional_status(self.zalsa, head_key_index) + .expect("cycle head memo to exist"); + let (current_iteration_count, verified_at) = match provisional_status { + ProvisionalStatus::Provisional { + iteration, + verified_at, + } + | ProvisionalStatus::Final { + iteration, + verified_at, + } => (iteration, verified_at), + ProvisionalStatus::FallbackImmediate => { + (IterationCount::initial(), self.zalsa.current_revision()) + } + }; - match cycle_head_kind { - ProvisionalStatus::Final { .. } | ProvisionalStatus::FallbackImmediate => { - // This cycle is already finalized, so we don't need to wait on it; - // keep looping through cycle heads. - crate::tracing::trace!("Dependent cycle head {head:?} has been finalized."); - Some(TryClaimHeadsResult::Finalized) + Some(TryClaimHeadsResult::Cycle { + memo_iteration_count: current_iteration_count, + head_iteration_count, + verified_at, + }) } - ProvisionalStatus::Provisional { .. } => { - match ingredient.wait_for(self.zalsa, head_key_index) { - WaitForResult::Cycle { .. } => { - // We hit a cycle blocking on the cycle head; this means this query actively - // participates in the cycle and some other query is blocked on this thread. 
- crate::tracing::debug!("Waiting for {head:?} results in a cycle"); - Some(TryClaimHeadsResult::Cycle) - } - WaitForResult::Running(running) => { - crate::tracing::debug!("Ingredient {head:?} is running: {running:?}"); + WaitForResult::Running(running) => { + crate::tracing::trace!("Ingredient {head_database_key:?} is running: {running:?}"); - Some(TryClaimHeadsResult::Running(RunningCycleHead { - inner: running, - ingredient, - })) - } - WaitForResult::Available => { - self.queue_ingredient_heads(ingredient, head_key_index); - Some(TryClaimHeadsResult::Available) - } - } + Some(TryClaimHeadsResult::Running(running)) } + WaitForResult::Available => Some(TryClaimHeadsResult::Available), } } } diff --git a/src/function/sync.rs b/src/function/sync.rs index 0a88844af..97a36262c 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -1,9 +1,13 @@ use rustc_hash::FxHashMap; +use std::collections::hash_map::OccupiedEntry; use crate::key::DatabaseKeyIndex; -use crate::runtime::{BlockResult, Running, WaitResult}; -use crate::sync::thread::{self, ThreadId}; +use crate::runtime::{ + BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult, +}; +use crate::sync::thread::{self}; use crate::sync::Mutex; +use crate::tracing; use crate::zalsa::Zalsa; use crate::{Id, IngredientIndex}; @@ -20,17 +24,36 @@ pub(crate) enum ClaimResult<'a> { /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. - Cycle, + Cycle { + /// `true` if this is a cycle with an inner query. For example, if `a` transferred its ownership to + /// `b`. If the thread claiming `b` tries to claim `a`, then this results in a cycle except when calling + /// [`SyncTable::try_claim`] with [`Reentrant::Allow`]. + inner: bool, + }, /// Successfully claimed the query. 
Claimed(ClaimGuard<'a>), } pub(crate) struct SyncState { - id: ThreadId, + /// The thread id that currently owns this query (actively executing it or iterating it as part of a larger cycle). + id: SyncOwner, /// Set to true if any other queries are blocked, /// waiting for this query to complete. anyone_waiting: bool, + + /// Whether any other query has transferred its lock ownership to this query. + /// This is only an optimization so that the expensive unblocking of transferred queries + /// can be skipped if `false`. This field might be `true` in cases where queries *were* transferred + /// to this query, but have since then been transferred to another query (in a later iteration). + is_transfer_target: bool, + + /// Whether this query has been claimed by the query that currently owns it. + /// + /// If `a` has been transferred to `b` and the stack for t1 is `b -> a`, then `a` can be claimed + /// and `claimed_twice` is set to `true`. However, t2 won't be able to claim `a` because + /// it doesn't own `b`. + claimed_twice: bool, } impl SyncTable { @@ -41,14 +64,34 @@ impl SyncTable { } } - pub(crate) fn try_claim<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> ClaimResult<'me> { + /// Claims the given key index, or blocks if it is running on another thread. 
+ pub(crate) fn try_claim<'me>( + &'me self, + zalsa: &'me Zalsa, + key_index: Id, + reentrant: Reentrancy, + ) -> ClaimResult<'me> { let mut write = self.syncs.lock(); match write.entry(key_index) { std::collections::hash_map::Entry::Occupied(occupied_entry) => { + let id = match occupied_entry.get().id { + SyncOwner::Thread(id) => id, + SyncOwner::Transferred => { + return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) { + Ok(claimed) => claimed, + Err(other_thread) => match other_thread.block(write) { + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + BlockResult::Running(running) => ClaimResult::Running(running), + }, + } + } + }; + let &mut SyncState { - id, ref mut anyone_waiting, + .. } = occupied_entry.into_mut(); + // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this // value. Everything that is written is also protected @@ -62,22 +105,116 @@ impl SyncTable { write, ) { BlockResult::Running(blocked_on) => ClaimResult::Running(blocked_on), - BlockResult::Cycle => ClaimResult::Cycle, + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, } } std::collections::hash_map::Entry::Vacant(vacant_entry) => { vacant_entry.insert(SyncState { - id: thread::current().id(), + id: SyncOwner::Thread(thread::current().id()), anyone_waiting: false, + is_transfer_target: false, + claimed_twice: false, }); ClaimResult::Claimed(ClaimGuard { key_index, zalsa, sync_table: self, + mode: ReleaseMode::Default, }) } } } + + #[cold] + #[inline(never)] + fn try_claim_transferred<'me>( + &'me self, + zalsa: &'me Zalsa, + mut entry: OccupiedEntry, + reentrant: Reentrancy, + ) -> Result, Box>> { + let key_index = *entry.key(); + let database_key_index = DatabaseKeyIndex::new(self.ingredient, key_index); + let thread_id = thread::current().id(); + + match zalsa + .runtime() + .block_transferred(database_key_index, thread_id) + { + BlockTransferredResult::ImTheOwner if reentrant.is_allow() => { + let 
SyncState { + id, claimed_twice, .. + } = entry.into_mut(); + debug_assert!(!*claimed_twice); + + *id = SyncOwner::Thread(thread_id); + *claimed_twice = true; + + Ok(ClaimResult::Claimed(ClaimGuard { + key_index, + zalsa, + sync_table: self, + mode: ReleaseMode::SelfOnly, + })) + } + BlockTransferredResult::ImTheOwner => Ok(ClaimResult::Cycle { inner: true }), + BlockTransferredResult::OwnedBy(other_thread) => { + entry.get_mut().anyone_waiting = true; + Err(other_thread) + } + BlockTransferredResult::Released => { + entry.insert(SyncState { + id: SyncOwner::Thread(thread_id), + anyone_waiting: false, + is_transfer_target: false, + claimed_twice: false, + }); + Ok(ClaimResult::Claimed(ClaimGuard { + key_index, + zalsa, + sync_table: self, + mode: ReleaseMode::Default, + })) + } + } + } + + /// Marks `key_index` as a transfer target. + /// + /// Returns the `SyncOwnerId` of the thread that currently owns this query. + /// + /// Note: The result of this method will immediately become stale unless the thread owning `key_index` + /// is currently blocked on this thread (claiming `key_index` from this thread results in a cycle). + pub(super) fn mark_as_transfer_target(&self, key_index: Id) -> Option { + let mut syncs = self.syncs.lock(); + syncs.get_mut(&key_index).map(|state| { + // We set `anyone_waiting` to true because it is used in `ClaimGuard::release` + // to exit early if the query doesn't need to release any locks. + // However, there are now dependent queries that need to be released, that's why we set `anyone_waiting` to true, + // so that `ClaimGuard::release` no longer exits early. + state.anyone_waiting = true; + state.is_transfer_target = true; + + state.id + }) + } +} + +#[derive(Copy, Clone, Debug)] +pub enum SyncOwner { + /// Query is owned by this thread + Thread(thread::ThreadId), + + /// The query's lock ownership has been transferred to another query. + /// E.g. 
if `a` transfers its ownership to `b`, then only the thread in the critical path + /// to complete b` can claim `a` (in most instances, only the thread owning `b` can claim `a`). + /// + /// The thread owning `a` is stored in the `DependencyGraph`. + /// + /// A query can be marked as `Transferred` even if it has since then been released by the owning query. + /// In that case, the query is effectively unclaimed and the `Transferred` state is stale. The reason + /// for this is that it avoids the need for locking each sync table when releasing the transferred queries. + Transferred, } /// Marks an active 'claim' in the synchronization map. The claim is @@ -87,33 +224,147 @@ pub(crate) struct ClaimGuard<'me> { key_index: Id, zalsa: &'me Zalsa, sync_table: &'me SyncTable, + mode: ReleaseMode, } -impl ClaimGuard<'_> { - fn remove_from_map_and_unblock_queries(&self) { +impl<'me> ClaimGuard<'me> { + pub(crate) const fn zalsa(&self) -> &'me Zalsa { + self.zalsa + } + + pub(crate) const fn database_key_index(&self) -> DatabaseKeyIndex { + DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index) + } + + pub(crate) fn set_release_mode(&mut self, mode: ReleaseMode) { + self.mode = mode; + } + + #[cold] + #[inline(never)] + fn release_panicking(&self) { let mut syncs = self.sync_table.syncs.lock(); + let state = syncs.remove(&self.key_index).expect("key claimed twice?"); + tracing::debug!( + "Release claim on {:?} due to panic", + self.database_key_index() + ); + + self.release(state, WaitResult::Panicked); + } + + #[inline(always)] + fn release(&self, state: SyncState, wait_result: WaitResult) { + let SyncState { + anyone_waiting, + is_transfer_target, + claimed_twice, + .. + } = state; + + if !anyone_waiting { + return; + } + + let runtime = self.zalsa.runtime(); + let database_key_index = self.database_key_index(); - let SyncState { anyone_waiting, .. 
} = - syncs.remove(&self.key_index).expect("key claimed twice?"); - - if anyone_waiting { - let database_key = DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index); - self.zalsa.runtime().unblock_queries_blocked_on( - database_key, - if thread::panicking() { - tracing::info!("Unblocking queries blocked on {database_key:?} after a panick"); - WaitResult::Panicked - } else { - WaitResult::Completed - }, - ) + if claimed_twice { + runtime.undo_transfer_lock(database_key_index); } + + if is_transfer_target { + runtime.unblock_transferred_queries_owned_by(database_key_index, wait_result); + } + + runtime.unblock_queries_blocked_on(database_key_index, wait_result); + } + + #[cold] + #[inline(never)] + fn release_self(&self) { + let mut syncs = self.sync_table.syncs.lock(); + let std::collections::hash_map::Entry::Occupied(mut state) = syncs.entry(self.key_index) + else { + panic!("key should only be claimed/released once"); + }; + + if state.get().claimed_twice { + state.get_mut().claimed_twice = false; + state.get_mut().id = SyncOwner::Transferred; + } else { + self.release(state.remove(), WaitResult::Completed); + } + } + + #[cold] + #[inline(never)] + pub(crate) fn transfer(&self, new_owner: DatabaseKeyIndex) { + let owner_ingredient = self.zalsa.lookup_ingredient(new_owner.ingredient_index()); + + // Get the owning thread of `new_owner`. + // The thread id is guaranteed to not be stale because `new_owner` must be blocked on `self_key` + // or `transfer_lock` will panic (at least in debug builds). 
+ let Some(new_owner_thread_id) = + owner_ingredient.mark_as_transfer_target(new_owner.key_index()) + else { + self.release( + self.sync_table + .syncs + .lock() + .remove(&self.key_index) + .expect("key should only be claimed/released once"), + WaitResult::Panicked, + ); + + panic!("new owner to be a locked query") + }; + + let mut syncs = self.sync_table.syncs.lock(); + + let self_key = self.database_key_index(); + tracing::debug!( + "Transferring lock ownership of {self_key:?} to {new_owner:?} ({new_owner_thread_id:?})" + ); + + let SyncState { + id, claimed_twice, .. + } = syncs + .get_mut(&self.key_index) + .expect("key should only be claimed/released once"); + + self.zalsa + .runtime() + .transfer_lock(self_key, new_owner, new_owner_thread_id); + + *id = SyncOwner::Transferred; + *claimed_twice = false; } } impl Drop for ClaimGuard<'_> { fn drop(&mut self) { - self.remove_from_map_and_unblock_queries() + if thread::panicking() { + self.release_panicking(); + return; + } + + match self.mode { + ReleaseMode::Default => { + let mut syncs = self.sync_table.syncs.lock(); + let state = syncs + .remove(&self.key_index) + .expect("key should only be claimed/released once"); + + self.release(state, WaitResult::Completed); + } + ReleaseMode::SelfOnly => { + self.release_self(); + } + ReleaseMode::TransferTo(new_owner) => { + self.transfer(new_owner); + } + } } } @@ -122,3 +373,60 @@ impl std::fmt::Debug for SyncTable { f.debug_struct("SyncTable").finish() } } + +/// Controls how the lock is released when the `ClaimGuard` is dropped. +#[derive(Copy, Clone, Debug, Default)] +pub(crate) enum ReleaseMode { + /// The default release mode. + /// + /// Releases the query for which this claim guard holds the lock and any queries that have + /// transferred ownership to this query. + #[default] + Default, + + /// Only releases the lock for this query. Any query that has transferred ownership to this query + /// will remain locked. 
+ /// + /// If this thread panics, the query will be released as normal (default mode). + SelfOnly, + + /// Transfers the ownership of the lock to the specified query. + /// + /// The query will remain locked and only the thread owning the transfer target will be resumed. + /// + /// The transfer target must be a query that's blocked on this query to guarantee that the transfer target doesn't complete + /// before the transfer is finished (which would leave this query locked forever). + /// + /// If this thread panics, the query will be released as normal (default mode). + TransferTo(DatabaseKeyIndex), +} + +impl std::fmt::Debug for ClaimGuard<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClaimGuard") + .field("key_index", &self.key_index) + .field("mode", &self.mode) + .finish_non_exhaustive() + } +} + +/// Controls whether this thread can claim a query that transferred its ownership to a query +/// this thread currently holds the lock for. +/// +/// For example: if query `a` transferred its ownership to query `b`, and this thread holds +/// the lock for `b`, then this thread can also claim `a` — but only when using [`Self::Allow`]. +#[derive(Copy, Clone, PartialEq, Eq)] +pub(crate) enum Reentrancy { + /// Allow `try_claim` to reclaim a query's that transferred its ownership to a query + /// hold by this thread. + Allow, + + /// Only allow claiming queries that haven't been claimed by any thread. 
+ Deny, +} + +impl Reentrancy { + const fn is_allow(self) -> bool { + matches!(self, Reentrancy::Allow) + } +} diff --git a/src/ingredient.rs b/src/ingredient.rs index 3cf36ae61..9b377e4d1 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -1,7 +1,7 @@ use std::any::{Any, TypeId}; use std::fmt; -use crate::cycle::{empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, ProvisionalStatus}; +use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::hash::{FxHashSet, FxIndexSet}; @@ -93,9 +93,19 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// on an other thread, it's up to caller to block until the result becomes available if desired. /// A return value of [`WaitForResult::Cycle`] means that a cycle was encountered; the waited-on query is either already claimed /// by the current thread, or by a thread waiting on the current thread. - fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { - _ = (zalsa, key_index); - WaitForResult::Available + fn wait_for<'me>(&'me self, _zalsa: &'me Zalsa, _key_index: Id) -> WaitForResult<'me> { + unreachable!( + "wait_for should only be called on cycle heads and only functions can be cycle heads" + ); + } + + /// Invoked when a query transfers its lock-ownership to `_key_index`. Returns the thread + /// owning the lock for `_key_index` or `None` if `_key_index` is not claimed. + /// + /// Note: The returned `SyncOwnerId` may be outdated as soon as this function returns **unless** + /// it's guaranteed that `_key_index` is blocked on the current thread. + fn mark_as_transfer_target(&self, _key_index: Id) -> Option { + unreachable!("mark_as_transfer_target should only be called on functions"); } /// Invoked when the value `output_key` should be marked as valid in the current revision. 
@@ -157,11 +167,27 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { } // Function ingredient methods - /// If this ingredient is a participant in a cycle, what is its cycle recovery strategy? - /// (Really only relevant to [`crate::function::FunctionIngredient`], - /// since only function ingredients push themselves onto the active query stack.) - fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { - unreachable!("only function ingredients can be part of a cycle") + /// Tests if the (nested) cycle head `_input` has converged in the most recent iteration. + /// + /// Returns `false` if the Memo doesn't exist or if called on a non-cycle head. + fn cycle_converged(&self, _zalsa: &Zalsa, _input: Id) -> bool { + unreachable!("cycle_converged should only be called on cycle heads and only functions can be cycle heads"); + } + + /// Updates the iteration count for the (nested) cycle head `_input` to `iteration_count`. + /// + /// This is a no-op if the memo doesn't exist or if called on a Memo without cycle heads. + fn set_cycle_iteration_count( + &self, + _zalsa: &Zalsa, + _input: Id, + _iteration_count: IterationCount, + ) { + unreachable!("increment_iteration_count should only be called on cycle heads and only functions can be cycle heads"); + } + + fn finalize_cycle_head(&self, _zalsa: &Zalsa, _input: Id) { + unreachable!("finalize_cycle_head should only be called on cycle heads and only functions can be cycle heads"); } /// What were the inputs (if any) that were used to create the value at `key_index`. 
@@ -302,14 +328,9 @@ pub(crate) fn fmt_index(debug_name: &str, id: Id, fmt: &mut fmt::Formatter<'_>) write!(fmt, "{debug_name}({id:?})") } +#[derive(Debug)] pub enum WaitForResult<'me> { Running(Running<'me>), Available, - Cycle, -} - -impl WaitForResult<'_> { - pub const fn is_cycle(&self) -> bool { - matches!(self, WaitForResult::Cycle) - } + Cycle { inner: bool }, } diff --git a/src/key.rs b/src/key.rs index 82d922565..364015756 100644 --- a/src/key.rs +++ b/src/key.rs @@ -18,7 +18,7 @@ pub struct DatabaseKeyIndex { impl DatabaseKeyIndex { #[inline] - pub(crate) fn new(ingredient_index: IngredientIndex, key_index: Id) -> Self { + pub(crate) const fn new(ingredient_index: IngredientIndex, key_index: Id) -> Self { Self { key_index, ingredient_index, diff --git a/src/runtime.rs b/src/runtime.rs index 8436c684d..670d6d62f 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,6 +1,6 @@ use self::dependency_graph::DependencyGraph; use crate::durability::Durability; -use crate::function::SyncGuard; +use crate::function::{SyncGuard, SyncOwner}; use crate::key::DatabaseKeyIndex; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::sync::thread::{self, ThreadId}; @@ -58,6 +58,57 @@ pub(crate) enum BlockResult<'me> { Cycle, } +pub(crate) enum BlockTransferredResult<'me> { + /// The current thread is the owner of the transferred query + /// and it can claim it if it wants to. + ImTheOwner, + + /// The query is owned/running on another thread. + OwnedBy(Box>), + + /// The query has transferred its ownership to another query previously but that query has + /// since then completed and released the lock. + Released, +} + +pub(super) struct BlockOnTransferredOwner<'me> { + dg: crate::sync::MutexGuard<'me, DependencyGraph>, + /// The query that we're trying to claim. + database_key: DatabaseKeyIndex, + /// The thread that currently owns the lock for the transferred query. + other_id: ThreadId, + /// The current thread that is trying to claim the transferred query. 
+ thread_id: ThreadId, +} + +impl<'me> BlockOnTransferredOwner<'me> { + /// Block on the other thread to complete the computation. + pub(super) fn block(self, query_mutex_guard: SyncGuard<'me>) -> BlockResult<'me> { + // Cycle in the same thread. + if self.thread_id == self.other_id { + return BlockResult::Cycle; + } + + if self.dg.depends_on(self.other_id, self.thread_id) { + crate::tracing::debug!( + "block_on: cycle detected for {:?} in thread {thread_id:?} on {:?}", + self.database_key, + self.other_id, + thread_id = self.thread_id + ); + return BlockResult::Cycle; + } + + BlockResult::Running(Running(Box::new(BlockedOnInner { + dg: self.dg, + query_mutex_guard, + database_key: self.database_key, + other_id: self.other_id, + thread_id: self.thread_id, + }))) + } +} + pub struct Running<'me>(Box>); struct BlockedOnInner<'me> { @@ -69,10 +120,6 @@ struct BlockedOnInner<'me> { } impl Running<'_> { - pub(crate) fn database_key(&self) -> DatabaseKeyIndex { - self.0.database_key - } - /// Blocks on the other thread to complete the computation. pub(crate) fn block_on(self, zalsa: &Zalsa) { let BlockedOnInner { @@ -210,7 +257,7 @@ impl Runtime { let r_old = self.current_revision(); let r_new = r_old.next(); self.revisions[0] = r_new; - crate::tracing::debug!("new_revision: {r_old:?} -> {r_new:?}"); + crate::tracing::info!("new_revision: {r_old:?} -> {r_new:?}"); r_new } @@ -253,9 +300,40 @@ impl Runtime { }))) } + /// Tries to claim ownership of a transferred query where `thread_id` is the current thread and `query` + /// is the query (that had its ownership transferred) to claim. + /// + /// For this operation to be reasonable, the caller must ensure that the sync table lock on `query` is not released + /// before this operation completes. 
+ pub(super) fn block_transferred( + &self, + query: DatabaseKeyIndex, + current_id: ThreadId, + ) -> BlockTransferredResult<'_> { + let dg = self.dependency_graph.lock(); + + let owner_thread = dg.thread_id_of_transferred_query(query, None); + + let Some(owner_thread_id) = owner_thread else { + // The query transferred its ownership but the owner has since then released the lock. + return BlockTransferredResult::Released; + }; + + if owner_thread_id == current_id || dg.depends_on(owner_thread_id, current_id) { + BlockTransferredResult::ImTheOwner + } else { + // Lock is owned by another thread, wait for it to be released. + BlockTransferredResult::OwnedBy(Box::new(BlockOnTransferredOwner { + dg, + database_key: query, + other_id: owner_thread_id, + thread_id: current_id, + })) + } + } + /// Invoked when this runtime completed computing `database_key` with - /// the given result `wait_result` (`wait_result` should be `None` if - /// computing `database_key` panicked and could not complete). + /// the given result `wait_result`. /// This function unblocks any dependent queries and allows them /// to continue executing. pub(crate) fn unblock_queries_blocked_on( @@ -268,6 +346,52 @@ impl Runtime { .unblock_runtimes_blocked_on(database_key, wait_result); } + /// Unblocks all transferred queries that are owned by `database_key` recursively. + /// + /// Invoked when a query completes that has been marked as transfer target (it has + /// queries that transferred their lock ownership to it) with the given `wait_result`. + /// + /// This function unblocks any dependent queries and allows them to continue executing. The + /// query `database_key` is not unblocked by this function. 
+ #[cold] + pub(crate) fn unblock_transferred_queries_owned_by( + &self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + self.dependency_graph + .lock() + .unblock_runtimes_blocked_on_transferred_queries_owned_by(database_key, wait_result); + } + + /// Removes the ownership transfer of `query`'s lock if it exists. + /// + /// If `query` has transferred its lock ownership to another query, this function will remove that transfer, + /// so that `query` now owns its lock again. + #[cold] + pub(super) fn undo_transfer_lock(&self, query: DatabaseKeyIndex) { + self.dependency_graph.lock().undo_transfer_lock(query); + } + + /// Transfers ownership of the lock for `query` to `new_owner_key`. + /// + /// For this operation to be reasonable, the caller must ensure that the sync table lock on `query` is not released + /// and that `new_owner_key` is currently blocked on `query`. Otherwise, `new_owner_key` might + /// complete before the lock is transferred, leaving `query` locked forever. + pub(super) fn transfer_lock( + &self, + query: DatabaseKeyIndex, + new_owner_key: DatabaseKeyIndex, + new_owner_id: SyncOwner, + ) { + self.dependency_graph.lock().transfer_lock( + query, + thread::current().id(), + new_owner_key, + new_owner_id, + ); + } + #[cfg(feature = "persistence")] pub(crate) fn deserialize_from(&mut self, other: &mut Runtime) { // The only field that is serialized is `revisions`. 
diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index fd26c04fa..403f7c544 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -3,11 +3,16 @@ use std::pin::Pin; use rustc_hash::FxHashMap; use smallvec::SmallVec; +use crate::function::SyncOwner; use crate::key::DatabaseKeyIndex; use crate::runtime::dependency_graph::edge::EdgeCondvar; use crate::runtime::WaitResult; use crate::sync::thread::ThreadId; use crate::sync::MutexGuard; +use crate::tracing; + +type QueryDependents = FxHashMap>; +type TransferredDependents = FxHashMap>; #[derive(Debug, Default)] pub(super) struct DependencyGraph { @@ -15,16 +20,26 @@ pub(super) struct DependencyGraph { /// `K` is blocked on some query executing in the runtime `V`. /// This encodes a graph that must be acyclic (or else deadlock /// will result). - edges: FxHashMap, + edges: Edges, /// Encodes the `ThreadId` that are blocked waiting for the result /// of a given query. - query_dependents: FxHashMap>, + query_dependents: QueryDependents, /// When a key K completes which had dependent queries Qs blocked on it, /// it stores its `WaitResult` here. As they wake up, each query Q in Qs will /// come here to fetch their results. wait_results: FxHashMap, + + /// A `K -> Q` pair indicates that the query `K`'s lock is now owned by the query + /// `Q`. It's important that `transferred` always forms a tree (must be acyclic), + /// or else deadlock will result. + transferred: FxHashMap, + + /// A `K -> [Q]` pair indicates that the query `K` owns the locks of + /// `Q`. This is the reverse mapping of `transferred` to allow efficient unlocking + /// of all dependent queries when `K` completes. + transferred_dependents: TransferredDependents, } impl DependencyGraph { @@ -32,15 +47,7 @@ impl DependencyGraph { /// /// (i.e., there is a path from `from_id` to `to_id` in the graph.) 
pub(super) fn depends_on(&self, from_id: ThreadId, to_id: ThreadId) -> bool { - let mut p = from_id; - while let Some(q) = self.edges.get(&p).map(|edge| edge.blocked_on_id) { - if q == to_id { - return true; - } - - p = q; - } - p == to_id + self.edges.depends_on(from_id, to_id) } /// Modifies the graph so that `from_id` is blocked @@ -138,6 +145,381 @@ impl DependencyGraph { // notify the thread. edge.notify(); } + + /// Invoked when the query `database_key` completes and it owns the locks of other queries + /// (the queries transferred their locks to `database_key`). + pub(super) fn unblock_runtimes_blocked_on_transferred_queries_owned_by( + &mut self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + fn unblock_recursive( + me: &mut DependencyGraph, + query: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + me.transferred.remove(&query); + + for query in me.transferred_dependents.remove(&query).unwrap_or_default() { + me.unblock_runtimes_blocked_on(query, wait_result); + unblock_recursive(me, query, wait_result); + } + } + + // If `database_key` is `c` and it has been transferred to `b` earlier, remove its entry. + tracing::trace!( + "unblock_runtimes_blocked_on_transferred_queries_owned_by({database_key:?}" + ); + + if let Some((_, owner)) = self.transferred.remove(&database_key) { + // If this query previously transferred its lock ownership to another query, remove + // it from that queries dependents as it is now completing. + self.transferred_dependents + .get_mut(&owner) + .unwrap() + .remove(&database_key); + } + + unblock_recursive(self, database_key, wait_result); + } + + pub(super) fn undo_transfer_lock(&mut self, database_key: DatabaseKeyIndex) { + if let Some((_, owner)) = self.transferred.remove(&database_key) { + self.transferred_dependents + .get_mut(&owner) + .unwrap() + .remove(&database_key); + } + } + + /// Recursively resolves the thread id that currently owns the lock for `database_key`. 
+ /// + /// Returns `None` if `database_key` hasn't transferred its lock (or the transfer has since been released) + /// and the thread id must be looked up in the `SyncTable` instead. + pub(super) fn thread_id_of_transferred_query( + &self, + database_key: DatabaseKeyIndex, + ignore: Option, + ) -> Option { + let &(mut resolved_thread, owner) = self.transferred.get(&database_key)?; + + let mut current_owner = owner; + + while let Some(&(next_thread, next_key)) = self.transferred.get(&current_owner) { + if Some(next_key) == ignore { + break; + } + resolved_thread = next_thread; + current_owner = next_key; + } + + Some(resolved_thread) + } + + /// Modifies the graph so that the lock on `query` (currently owned by `current_thread`) is + /// transferred to `new_owner` (which is owned by `new_owner_id`). + pub(super) fn transfer_lock( + &mut self, + query: DatabaseKeyIndex, + current_thread: ThreadId, + new_owner: DatabaseKeyIndex, + new_owner_id: SyncOwner, + ) { + let new_owner_thread = match new_owner_id { + SyncOwner::Thread(thread) => thread, + SyncOwner::Transferred => { + // Skip over `query` to skip over any existing mapping from `new_owner` to `query` that may + // exist from previous transfers. + self.thread_id_of_transferred_query(new_owner, Some(query)) + .expect("new owner should be blocked on `query`") + } + }; + + debug_assert!( + new_owner_thread == current_thread || self.depends_on(new_owner_thread, current_thread), + "new owner {new_owner:?} ({new_owner_thread:?}) must be blocked on {query:?} ({current_thread:?})" + ); + + let thread_changed = match self.transferred.entry(query) { + std::collections::hash_map::Entry::Vacant(entry) => { + // Transfer `c -> b` and there's no existing entry for `c`. + entry.insert((new_owner_thread, new_owner)); + current_thread != new_owner_thread + } + std::collections::hash_map::Entry::Occupied(mut entry) => { + // If we transfer to the same owner as before, return immediately as this is a no-op. 
+ if entry.get() == &(new_owner_thread, new_owner) { + return; + } + + // Transfer `c -> b` after a previous `c -> d` mapping. + // Update the owner and remove the query from the old owner's dependents. + let &(old_owner_thread, old_owner) = entry.get(); + + // For the example below, remove `d` from `b`'s dependents. + self.transferred_dependents + .get_mut(&old_owner) + .unwrap() + .remove(&query); + + entry.insert((new_owner_thread, new_owner)); + + // If we have `c -> a -> d` and we now insert a mapping `d -> c`, rewrite the mapping to + // `d -> c -> a` to avoid cycles. + // + // Or, starting with `e -> c -> a -> d -> b` insert `d -> c`. We need to rewrite the tree to + // ``` + // e -> c -> a -> b + // d / + // ``` + // + // + // A cycle between transfers can occur when a later iteration has a different outermost query than + // a previous iteration. The second iteration then hits `cycle_initial` for a different head (e.g. for `c` where it previously was `d`). + let mut last_segment = self.transferred.entry(new_owner); + + while let std::collections::hash_map::Entry::Occupied(mut entry) = last_segment { + let source = *entry.key(); + let next_target = entry.get().1; + + // If it's `a -> d`, remove `a -> d` and insert an edge from `a -> b` + if next_target == query { + tracing::trace!( + "Remap edge {source:?} -> {next_target:?} to {source:?} -> {old_owner:?} to prevent a cycle", + ); + + // Remove `a` from the dependents of `d` and remove the mapping from `a -> d`. + self.transferred_dependents + .get_mut(&query) + .unwrap() + .remove(&source); + + // if the old mapping was `c -> d` and we now insert `d -> c`, remove `d -> c` + if old_owner == new_owner { + entry.remove(); + } else { + // otherwise (when `d` pointed to some other query, e.g. 
`b` in the example), + // add an edge from `a` to `b` + entry.insert((old_owner_thread, old_owner)); + self.transferred_dependents + .get_mut(&old_owner) + .unwrap() + .push(source); + } + + break; + } + + last_segment = self.transferred.entry(next_target); + } + + // We simply assume here that the thread has changed because we'd have to walk the entire + // transferred chain of `old_owner` to know if the thread has changed. This won't save us much + // compared to just updating all dependent threads. + true + } + }; + + // Register `c` as a dependent of `b`. + let all_dependents = self.transferred_dependents.entry(new_owner).or_default(); + debug_assert!(!all_dependents.contains(&new_owner)); + all_dependents.push(query); + + if thread_changed { + tracing::debug!("Unblocking new owner of transfer target {new_owner:?}"); + self.unblock_transfer_target(query, new_owner_thread); + self.update_transferred_edges(query, new_owner_thread); + } + } + + /// Finds the one query in the dependents of the `source_query` (the one that is transferred to a new owner) + /// on which the `new_owner_id` thread blocks, and unblocks it, to ensure progress. + fn unblock_transfer_target(&mut self, source_query: DatabaseKeyIndex, new_owner_id: ThreadId) { + /// Finds the thread that's currently blocking the `new_owner_id` thread. + /// + /// Returns `Some` if there's such a thread where the first element is the query + /// that the thread is blocked on (key into `query_dependents`) and the second element + /// is the index in the list of blocked threads (index into the `query_dependents` value) for that query. 
+ fn find_blocked_thread( + me: &DependencyGraph, + query: DatabaseKeyIndex, + new_owner_id: ThreadId, + ) -> Option<(DatabaseKeyIndex, usize)> { + if let Some(blocked_threads) = me.query_dependents.get(&query) { + for (i, id) in blocked_threads.iter().copied().enumerate() { + if id == new_owner_id || me.edges.depends_on(new_owner_id, id) { + return Some((query, i)); + } + } + } + + me.transferred_dependents + .get(&query) + .iter() + .copied() + .flatten() + .find_map(|dependent| find_blocked_thread(me, *dependent, new_owner_id)) + } + + if let Some((query, query_dependents_index)) = + find_blocked_thread(self, source_query, new_owner_id) + { + let blocked_threads = self.query_dependents.get_mut(&query).unwrap(); + + let thread_id = blocked_threads.swap_remove(query_dependents_index); + if blocked_threads.is_empty() { + self.query_dependents.remove(&query); + } + + self.unblock_runtime(thread_id, WaitResult::Completed); + } + } + + fn update_transferred_edges(&mut self, query: DatabaseKeyIndex, new_owner_thread: ThreadId) { + fn update_transferred_edges( + edges: &mut Edges, + query_dependents: &QueryDependents, + transferred_dependents: &TransferredDependents, + query: DatabaseKeyIndex, + new_owner_thread: ThreadId, + ) { + tracing::trace!("update_transferred_edges({query:?}"); + if let Some(dependents) = query_dependents.get(&query) { + for dependent in dependents.iter() { + let edge = edges.get_mut(dependent).unwrap(); + + tracing::trace!( + "Rewrite edge from {:?} to {new_owner_thread:?}", + edge.blocked_on_id + ); + edge.blocked_on_id = new_owner_thread; + debug_assert!( + !edges.depends_on(new_owner_thread, *dependent), + "Circular reference between blocked edges: {:#?}", + edges + ); + } + }; + + if let Some(dependents) = transferred_dependents.get(&query) { + for dependent in dependents { + update_transferred_edges( + edges, + query_dependents, + transferred_dependents, + *dependent, + new_owner_thread, + ) + } + } + } + + update_transferred_edges( + &mut 
self.edges, + &self.query_dependents, + &self.transferred_dependents, + query, + new_owner_thread, + ) + } +} + +#[derive(Debug, Default)] +struct Edges(FxHashMap); + +impl Edges { + fn depends_on(&self, from_id: ThreadId, to_id: ThreadId) -> bool { + let mut p = from_id; + while let Some(q) = self.0.get(&p).map(|edge| edge.blocked_on_id) { + if q == to_id { + return true; + } + + p = q; + } + p == to_id + } + + fn get_mut(&mut self, id: &ThreadId) -> Option<&mut edge::Edge> { + self.0.get_mut(id) + } + + fn contains_key(&self, id: &ThreadId) -> bool { + self.0.contains_key(id) + } + + fn insert(&mut self, id: ThreadId, edge: edge::Edge) { + self.0.insert(id, edge); + } + + fn remove(&mut self, id: &ThreadId) -> Option { + self.0.remove(id) + } +} + +#[derive(Debug)] +struct SmallSet(SmallVec<[T; N]>); + +impl SmallSet +where + T: PartialEq, +{ + const fn new() -> Self { + Self(SmallVec::new_const()) + } + + fn push(&mut self, value: T) { + debug_assert!(!self.0.contains(&value)); + + self.0.push(value); + } + + fn contains(&self, value: &T) -> bool { + self.0.contains(value) + } + + fn remove(&mut self, value: &T) -> bool { + if let Some(index) = self.0.iter().position(|x| x == value) { + self.0.swap_remove(index); + true + } else { + false + } + } + + fn iter(&self) -> std::slice::Iter<'_, T> { + self.0.iter() + } +} + +impl IntoIterator for SmallSet { + type Item = T; + type IntoIter = smallvec::IntoIter<[T; N]>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a SmallSet +where + T: PartialEq, +{ + type Item = &'a T; + type IntoIter = std::slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl Default for SmallSet +where + T: PartialEq, +{ + fn default() -> Self { + Self::new() + } } mod edge { @@ -165,7 +547,7 @@ mod edge { /// Signalled whenever a query with dependents completes. /// Allows those dependents to check if they are ready to unblock. 
- // condvar: unsafe<'stack_frame> Pin<&'stack_frame Condvar>, + /// `condvar: unsafe<'stack_frame> Pin<&'stack_frame Condvar>` condvar: Pin<&'static EdgeCondvar>, } diff --git a/src/tracing.rs b/src/tracing.rs index 47f95d00e..6d3ae8851 100644 --- a/src/tracing.rs +++ b/src/tracing.rs @@ -7,6 +7,12 @@ macro_rules! trace { }; } +macro_rules! warn_event { + ($($x:tt)*) => { + crate::tracing::event!(WARN, $($x)*) + }; +} + macro_rules! info { ($($x:tt)*) => { crate::tracing::event!(INFO, $($x)*) @@ -25,6 +31,13 @@ macro_rules! debug_span { }; } +#[expect(unused_macros)] +macro_rules! info_span { + ($($x:tt)*) => { + crate::tracing::span!(INFO, $($x)*) + }; +} + macro_rules! event { ($level:ident, $($x:tt)*) => {{ let event = { @@ -51,4 +64,5 @@ macro_rules! span { }}; } -pub(crate) use {debug, debug_span, event, info, span, trace}; +#[expect(unused_imports)] +pub(crate) use {debug, debug_span, event, info, info_span, span, trace, warn_event as warn}; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index e332b516f..39d0c489c 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,4 +1,6 @@ use std::cell::{RefCell, UnsafeCell}; +use std::fmt; +use std::fmt::Formatter; use std::panic::UnwindSafe; use std::ptr::{self, NonNull}; @@ -11,7 +13,7 @@ use crate::accumulator::{ Accumulator, }; use crate::active_query::{CompletedQuery, QueryStack}; -use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount}; +use crate::cycle::{empty_cycle_heads, AtomicIterationCount, CycleHeads, IterationCount}; use crate::durability::Durability; use crate::key::DatabaseKeyIndex; use crate::runtime::Stamp; @@ -513,7 +515,8 @@ impl QueryRevisionsExtra { accumulated, cycle_heads, tracked_struct_ids, - iteration, + iteration: iteration.into(), + cycle_converged: false, })) }; @@ -521,7 +524,6 @@ impl QueryRevisionsExtra { } } -#[derive(Debug)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] struct QueryRevisionsExtraInner { #[cfg(feature = 
"accumulator")] @@ -561,7 +563,12 @@ struct QueryRevisionsExtraInner { /// iterate again. cycle_heads: CycleHeads, - iteration: IterationCount, + iteration: AtomicIterationCount, + + /// Stores for nested cycle heads whether they've converged in the last iteration. + /// This value is always `false` for other queries. + #[cfg_attr(feature = "persistence", serde(skip))] + cycle_converged: bool, } impl QueryRevisionsExtraInner { @@ -573,6 +580,7 @@ impl QueryRevisionsExtraInner { tracked_struct_ids, cycle_heads, iteration: _, + cycle_converged: _, } = self; #[cfg(feature = "accumulator")] @@ -583,6 +591,44 @@ impl QueryRevisionsExtraInner { } } +impl fmt::Debug for QueryRevisionsExtraInner { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + struct FmtTrackedStructIds<'a>(&'a ThinVec<(Identity, Id)>); + + impl fmt::Debug for FmtTrackedStructIds<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut f = f.debug_list(); + + if self.0.len() > 5 { + f.entries(&self.0[..5]); + f.finish_non_exhaustive() + } else { + f.entries(self.0); + f.finish() + } + } + } + + let mut f = f.debug_struct("QueryRevisionsExtraInner"); + + f.field("cycle_heads", &self.cycle_heads) + .field("iteration", &self.iteration) + .field("cycle_converged", &self.cycle_converged); + + #[cfg(feature = "accumulator")] + { + f.field("accumulated", &self.accumulated); + } + + f.field( + "tracked_struct_ids", + &FmtTrackedStructIds(&self.tracked_struct_ids), + ); + + f.finish() + } +} + #[cfg(not(feature = "shuttle"))] #[cfg(target_pointer_width = "64")] const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; 4]>()]; @@ -605,7 +651,7 @@ impl QueryRevisions { #[cfg(feature = "accumulator")] AccumulatedMap::default(), ThinVec::default(), - CycleHeads::initial(query), + CycleHeads::initial(query, IterationCount::initial()), IterationCount::initial(), ), } @@ -654,17 +700,55 @@ impl QueryRevisions { }; } - pub(crate) const fn iteration(&self) -> IterationCount { + 
pub(crate) fn cycle_converged(&self) -> bool { match &self.extra.0 { - Some(extra) => extra.iteration, + Some(extra) => extra.cycle_converged, + None => false, + } + } + + pub(crate) fn set_cycle_converged(&mut self, cycle_converged: bool) { + if let Some(extra) = &mut self.extra.0 { + extra.cycle_converged = cycle_converged + } + } + + pub(crate) fn iteration(&self) -> IterationCount { + match &self.extra.0 { + Some(extra) => extra.iteration.load(), None => IterationCount::initial(), } } + pub(crate) fn set_iteration_count( + &self, + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) { + let Some(extra) = &self.extra.0 else { + return; + }; + debug_assert!(extra.iteration.load() <= iteration_count); + + extra.iteration.store(iteration_count); + + extra + .cycle_heads + .update_iteration_count(database_key_index, iteration_count); + } + /// Updates the iteration count if this query has any cycle heads. Otherwise it's a no-op. - pub(crate) fn update_iteration_count(&mut self, iteration_count: IterationCount) { + pub(crate) fn update_iteration_count_mut( + &mut self, + cycle_head_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) { if let Some(extra) = &mut self.extra.0 { - extra.iteration = iteration_count + extra.iteration.store_mut(iteration_count); + + extra + .cycle_heads + .update_iteration_count_mut(cycle_head_index, iteration_count); } } diff --git a/tests/backtrace.rs b/tests/backtrace.rs index 74124c1ab..b611cac86 100644 --- a/tests/backtrace.rs +++ b/tests/backtrace.rs @@ -108,7 +108,7 @@ fn backtrace_works() { at tests/backtrace.rs:32 1: query_cycle(Id(2)) at tests/backtrace.rs:45 - cycle heads: query_cycle(Id(2)) -> IterationCount(0) + cycle heads: query_cycle(Id(2)) -> iteration = 0 2: query_f(Id(2)) at tests/backtrace.rs:40 "#]] @@ -119,9 +119,9 @@ fn backtrace_works() { query stacktrace: 0: query_e(Id(3)) -> (R1, Durability::LOW) at tests/backtrace.rs:32 - 1: query_cycle(Id(3)) -> (R1, Durability::HIGH, 
iteration = IterationCount(0)) + 1: query_cycle(Id(3)) -> (R1, Durability::HIGH, iteration = 0) at tests/backtrace.rs:45 - cycle heads: query_cycle(Id(3)) -> IterationCount(0) + cycle heads: query_cycle(Id(3)) -> iteration = 0 2: query_f(Id(3)) -> (R1, Durability::HIGH) at tests/backtrace.rs:40 "#]] diff --git a/tests/cycle.rs b/tests/cycle.rs index 7a7e26a07..5e46cc0be 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -95,18 +95,22 @@ impl Input { } } + #[track_caller] fn assert(&self, db: &dyn Db, expected: Value) { assert_eq!(self.eval(db), expected) } + #[track_caller] fn assert_value(&self, db: &dyn Db, expected: u8) { self.assert(db, Value::N(expected)) } + #[track_caller] fn assert_bounds(&self, db: &dyn Db) { self.assert(db, Value::OutOfBounds) } + #[track_caller] fn assert_count(&self, db: &dyn Db) { self.assert(db, Value::TooManyIterations) } @@ -893,7 +897,7 @@ fn cycle_unchanged() { /// /// If nothing in a nested cycle changed in the new revision, no part of the cycle should /// re-execute. -#[test] +#[test_log::test] fn cycle_unchanged_nested() { let mut db = ExecuteValidateLoggerDatabase::default(); let a_in = Inputs::new(&db, vec![]); @@ -978,7 +982,7 @@ fn cycle_unchanged_nested_intertwined() { e.assert_value(&db, 60); } - db.assert_logs_len(15 + i); + db.assert_logs_len(13 + i); // next revision, we change only A, which is not part of the cycle and the cycle does not // depend on. 
diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 154ba3370..2e0c2cfd0 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -269,7 +269,7 @@ fn cycle_recover_with_structs<'db>( CycleRecoveryAction::Iterate } -#[test] +#[test_log::test] fn test_cycle_with_fixpoint_structs() { let mut db = EventLoggerDatabase::default(); diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs index d9d5ca365..ad21b7963 100644 --- a/tests/parallel/cycle_a_t1_b_t2.rs +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -62,7 +62,7 @@ fn initial(_db: &dyn KnobsDatabase) -> CycleValue { #[test_log::test] fn the_test() { crate::sync::check(|| { - tracing::debug!("New run"); + tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); diff --git a/tests/parallel/cycle_a_t1_b_t2_fallback.rs b/tests/parallel/cycle_a_t1_b_t2_fallback.rs index 8005a9c23..b2d6631cc 100644 --- a/tests/parallel/cycle_a_t1_b_t2_fallback.rs +++ b/tests/parallel/cycle_a_t1_b_t2_fallback.rs @@ -55,11 +55,18 @@ fn the_test() { use crate::Knobs; crate::sync::check(|| { + tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); - let t1 = thread::spawn(move || query_a(&db_t1)); - let t2 = thread::spawn(move || query_b(&db_t2)); + let t1 = thread::spawn(move || { + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + query_a(&db_t1) + }); + let t2 = thread::spawn(move || { + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + query_b(&db_t2) + }); let (r_t1, r_t2) = (t1.join(), t2.join()); diff --git a/tests/parallel/cycle_nested_deep.rs b/tests/parallel/cycle_nested_deep.rs index 7b7c2f42a..f2b355616 100644 --- a/tests/parallel/cycle_nested_deep.rs +++ b/tests/parallel/cycle_nested_deep.rs @@ -63,6 +63,7 @@ fn initial(_db: &dyn KnobsDatabase) -> CycleValue { #[test_log::test] fn the_test() { crate::sync::check(|| { + 
tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); let db_t3 = db_t1.clone(); diff --git a/tests/parallel/cycle_nested_deep_conditional.rs b/tests/parallel/cycle_nested_deep_conditional.rs index 316612845..4eff75189 100644 --- a/tests/parallel/cycle_nested_deep_conditional.rs +++ b/tests/parallel/cycle_nested_deep_conditional.rs @@ -72,7 +72,7 @@ fn initial(_db: &dyn KnobsDatabase) -> CycleValue { #[test_log::test] fn the_test() { crate::sync::check(|| { - tracing::debug!("New run"); + tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); let db_t3 = db_t1.clone(); diff --git a/tests/parallel/cycle_nested_deep_conditional_changed.rs b/tests/parallel/cycle_nested_deep_conditional_changed.rs index 7c96d808d..51d506456 100644 --- a/tests/parallel/cycle_nested_deep_conditional_changed.rs +++ b/tests/parallel/cycle_nested_deep_conditional_changed.rs @@ -81,7 +81,7 @@ fn the_test() { use crate::sync; use salsa::Setter as _; sync::check(|| { - tracing::debug!("New run"); + tracing::debug!("Starting new run"); // This is a bit silly but it works around https://github.com/awslabs/shuttle/issues/192 static INITIALIZE: sync::Mutex> = @@ -108,36 +108,36 @@ fn the_test() { } let t1 = thread::spawn(move || { + let _span = tracing::info_span!("t1", thread_id = ?thread::current().id()).entered(); let (db, input) = get_db(|db, input| { query_a(db, input); }); - let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); - query_a(&db, input) }); let t2 = thread::spawn(move || { + let _span = tracing::info_span!("t2", thread_id = ?thread::current().id()).entered(); let (db, input) = get_db(|db, input| { query_b(db, input); }); - let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered(); query_b(&db, input) }); let t3 = thread::spawn(move || { + let _span = tracing::info_span!("t3", thread_id = ?thread::current().id()).entered(); let (db, input) = 
get_db(|db, input| { query_d(db, input); }); - let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); query_d(&db, input) }); let t4 = thread::spawn(move || { + let _span = tracing::info_span!("t4", thread_id = ?thread::current().id()).entered(); + let (db, input) = get_db(|db, input| { query_e(db, input); }); - let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered(); query_e(&db, input) }); diff --git a/tests/parallel/cycle_nested_deep_panic.rs b/tests/parallel/cycle_nested_deep_panic.rs new file mode 100644 index 000000000..8b89f362a --- /dev/null +++ b/tests/parallel/cycle_nested_deep_panic.rs @@ -0,0 +1,142 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Tests that salsa doesn't get stuck after a panic in a nested cycle function. + +use crate::sync::thread; +use crate::{Knobs, KnobsDatabase}; +use std::fmt; +use std::panic::catch_unwind; + +use salsa::CycleRecoveryAction; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + let c_value = query_c(db); + CycleValue(c_value.0 + 1).min(MAX) +} + +#[salsa::tracked] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + let d_value = query_d(db); + + if d_value > CycleValue(0) { + let e_value = query_e(db); + let b_value = query_b(db); + CycleValue(d_value.0.max(e_value.0).max(b_value.0)) + } else { + let a_value = query_a(db); + CycleValue(d_value.0.max(a_value.0)) + } +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_d(db: &dyn KnobsDatabase) -> CycleValue { + query_b(db) +} + 
+#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_e(db: &dyn KnobsDatabase) -> CycleValue { + query_c(db) +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +fn run() { + tracing::debug!("Starting new run"); + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + let db_t3 = db_t1.clone(); + let db_t4 = db_t1.clone(); + + let t1 = thread::spawn(move || { + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + catch_unwind(|| { + db_t1.wait_for(1); + query_a(&db_t1) + }) + }); + let t2 = thread::spawn(move || { + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + catch_unwind(|| { + db_t2.wait_for(1); + + query_b(&db_t2) + }) + }); + let t3 = thread::spawn(move || { + let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered(); + catch_unwind(|| { + db_t3.signal(2); + query_d(&db_t3) + }) + }); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + + assert_is_set_cycle_error(r_t1); + assert_is_set_cycle_error(r_t2); + assert_is_set_cycle_error(r_t3); + + // Pulling the cycle again at a later point should still result in a panic. 
+ assert_is_set_cycle_error(catch_unwind(|| query_d(&db_t4))); +} + +#[test_log::test] +fn the_test() { + let count = if cfg!(miri) { 1 } else { 200 }; + + for _ in 0..count { + run() + } +} + +#[track_caller] +fn assert_is_set_cycle_error(result: Result>) +where + T: fmt::Debug, +{ + let err = result.expect_err("expected an error"); + + if let Some(message) = err.downcast_ref::<&str>() { + assert!( + message.contains("set cycle_fn/cycle_initial to fixpoint iterate"), + "Expected error message to contain 'set cycle_fn/cycle_initial to fixpoint iterate', but got: {}", + message + ); + } else if let Some(message) = err.downcast_ref::() { + assert!( + message.contains("set cycle_fn/cycle_initial to fixpoint iterate"), + "Expected error message to contain 'set cycle_fn/cycle_initial to fixpoint iterate', but got: {}", + message + ); + } else if err.downcast_ref::().is_some() { + // This is okay, because Salsa throws a Cancelled::PropagatedPanic when a panic occurs in a query + // that it blocks on. 
+ } else { + std::panic::resume_unwind(err); + } +} diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs index c761a80f4..22232bd85 100644 --- a/tests/parallel/cycle_nested_three_threads.rs +++ b/tests/parallel/cycle_nested_three_threads.rs @@ -76,9 +76,18 @@ fn the_test() { let db_t2 = db_t1.clone(); let db_t3 = db_t1.clone(); - let t1 = thread::spawn(move || query_a(&db_t1)); - let t2 = thread::spawn(move || query_b(&db_t2)); - let t3 = thread::spawn(move || query_c(&db_t3)); + let t1 = thread::spawn(move || { + let _span = tracing::info_span!("t1", thread_id = ?thread::current().id()).entered(); + query_a(&db_t1) + }); + let t2 = thread::spawn(move || { + let _span = tracing::info_span!("t2", thread_id = ?thread::current().id()).entered(); + query_b(&db_t2) + }); + let t3 = thread::spawn(move || { + let _span = tracing::info_span!("t3", thread_id = ?thread::current().id()).entered(); + query_c(&db_t3) + }); let r_t1 = t1.join().unwrap(); let r_t2 = t2.join().unwrap(); diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index a764a864c..6bc89d2a2 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -9,6 +9,7 @@ mod cycle_ab_peeping_c; mod cycle_nested_deep; mod cycle_nested_deep_conditional; mod cycle_nested_deep_conditional_changed; +mod cycle_nested_deep_panic; mod cycle_nested_three_threads; mod cycle_nested_three_threads_changed; mod cycle_panic; @@ -33,7 +34,7 @@ pub(crate) mod sync { pub use shuttle::thread; pub fn check(f: impl Fn() + Send + Sync + 'static) { - shuttle::check_pct(f, 1000, 50); + shuttle::check_pct(f, 2500, 50); } } From 9cfe41c343ff43f258520967081e32caad467bc0 Mon Sep 17 00:00:00 2001 From: Ben Beasley Date: Sun, 19 Oct 2025 11:17:05 +0100 Subject: [PATCH 03/21] Fix missing license files in published macros/macro-rules crates (#1009) --- components/salsa-macro-rules/LICENSE-APACHE | 1 + components/salsa-macro-rules/LICENSE-MIT | 1 + 
components/salsa-macros/LICENSE-APACHE | 1 + components/salsa-macros/LICENSE-MIT | 1 + 4 files changed, 4 insertions(+) create mode 120000 components/salsa-macro-rules/LICENSE-APACHE create mode 120000 components/salsa-macro-rules/LICENSE-MIT create mode 120000 components/salsa-macros/LICENSE-APACHE create mode 120000 components/salsa-macros/LICENSE-MIT diff --git a/components/salsa-macro-rules/LICENSE-APACHE b/components/salsa-macro-rules/LICENSE-APACHE new file mode 120000 index 000000000..1cd601d0a --- /dev/null +++ b/components/salsa-macro-rules/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/components/salsa-macro-rules/LICENSE-MIT b/components/salsa-macro-rules/LICENSE-MIT new file mode 120000 index 000000000..b2cfbdc7b --- /dev/null +++ b/components/salsa-macro-rules/LICENSE-MIT @@ -0,0 +1 @@ +../../LICENSE-MIT \ No newline at end of file diff --git a/components/salsa-macros/LICENSE-APACHE b/components/salsa-macros/LICENSE-APACHE new file mode 120000 index 000000000..1cd601d0a --- /dev/null +++ b/components/salsa-macros/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/components/salsa-macros/LICENSE-MIT b/components/salsa-macros/LICENSE-MIT new file mode 120000 index 000000000..b2cfbdc7b --- /dev/null +++ b/components/salsa-macros/LICENSE-MIT @@ -0,0 +1 @@ +../../LICENSE-MIT \ No newline at end of file From a4113cd472539fdbc44a4e6a139d0124211da921 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 20 Oct 2025 13:32:46 +0200 Subject: [PATCH 04/21] Simplify `WaitGroup` implementation (#958) * Simplify `WaitGroup` implementation * Slightly cheaper `get_mut` Co-authored-by: Ibraheem Ahmed --------- Co-authored-by: Ibraheem Ahmed --- src/storage.rs | 41 ++++++++++++++++++++--------------------- src/table.rs | 5 ++++- src/views.rs | 6 +++++- src/zalsa_local.rs | 4 ++-- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 
f63981e4f..443b53221 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -25,8 +25,6 @@ pub struct StorageHandle { impl Clone for StorageHandle { fn clone(&self) -> Self { - *self.coordinate.clones.lock() += 1; - Self { zalsa_impl: self.zalsa_impl.clone(), coordinate: CoordinateDrop(Arc::clone(&self.coordinate)), @@ -53,7 +51,7 @@ impl StorageHandle { Self { zalsa_impl: Arc::new(Zalsa::new::(event_callback, jars)), coordinate: CoordinateDrop(Arc::new(Coordinate { - clones: Mutex::new(1), + coordinate_lock: Mutex::default(), cvar: Default::default(), })), phantom: PhantomData, @@ -95,17 +93,6 @@ impl Drop for Storage { } } -struct Coordinate { - /// Counter of the number of clones of actor. Begins at 1. - /// Incremented when cloned, decremented when dropped. - clones: Mutex, - cvar: Condvar, -} - -// We cannot panic while holding a lock to `clones: Mutex` and therefore we cannot enter an -// inconsistent state. -impl RefUnwindSafe for Coordinate {} - impl Default for Storage { fn default() -> Self { Self::new(None) @@ -168,12 +155,15 @@ impl Storage { .zalsa_impl .event(&|| Event::new(EventKind::DidSetCancellationFlag)); - let mut clones = self.handle.coordinate.clones.lock(); - while *clones != 1 { - clones = self.handle.coordinate.cvar.wait(clones); - } - // The ref count on the `Arc` should now be 1 - let zalsa = Arc::get_mut(&mut self.handle.zalsa_impl).unwrap(); + let mut coordinate_lock = self.handle.coordinate.coordinate_lock.lock(); + let zalsa = loop { + if Arc::strong_count(&self.handle.zalsa_impl) == 1 { + // SAFETY: The strong count is 1, and we never create any weak pointers, + // so we have a unique reference. 
+ break unsafe { &mut *(Arc::as_ptr(&self.handle.zalsa_impl).cast_mut()) }; + } + coordinate_lock = self.handle.coordinate.cvar.wait(coordinate_lock); + }; // cancellation is done, so reset the flag zalsa.runtime_mut().reset_cancellation_flag(); zalsa @@ -260,6 +250,16 @@ impl Clone for Storage { } } +/// A simplified `WaitGroup`, this is used together with `Arc` as the actual counter +struct Coordinate { + coordinate_lock: Mutex<()>, + cvar: Condvar, +} + +// We cannot panic while holding a lock to `clones: Mutex` and therefore we cannot enter an +// inconsistent state. +impl RefUnwindSafe for Coordinate {} + struct CoordinateDrop(Arc); impl std::ops::Deref for CoordinateDrop { @@ -272,7 +272,6 @@ impl std::ops::Deref for CoordinateDrop { impl Drop for CoordinateDrop { fn drop(&mut self) { - *self.0.clones.lock() -= 1; self.0.cvar.notify_all(); } } diff --git a/src/table.rs b/src/table.rs index 53cf10cce..5505c1c05 100644 --- a/src/table.rs +++ b/src/table.rs @@ -252,7 +252,10 @@ impl Table { } let allocated_idx = self.push_page::(ingredient, memo_types.clone()); - assert_eq!(allocated_idx, page_idx); + assert_eq!( + allocated_idx, page_idx, + "allocated index does not match requested index" + ); } }; } diff --git a/src/views.rs b/src/views.rs index d449779c3..d58f349f0 100644 --- a/src/views.rs +++ b/src/views.rs @@ -108,7 +108,11 @@ impl Views { &self, func: fn(NonNull) -> NonNull, ) -> &DatabaseDownCaster { - assert_eq!(self.source_type_id, TypeId::of::()); + assert_eq!( + self.source_type_id, + TypeId::of::(), + "mismatched source type" + ); let target_type_id = TypeId::of::(); if let Some((_, caster)) = self .view_casters diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 39d0c489c..7b0399178 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1173,7 +1173,7 @@ impl ActiveQueryGuard<'_> { unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); + 
assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); let frame = stack.last_mut().unwrap(); frame.tracked_struct_ids_mut().seed(tracked_struct_ids); }) @@ -1195,7 +1195,7 @@ impl ActiveQueryGuard<'_> { unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); + assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); let frame = stack.last_mut().unwrap(); frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids); }) From ffa811dca2352c6c54e10346df363fdc0d51dd46 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 22 Oct 2025 13:41:56 +0200 Subject: [PATCH 05/21] Remove experimental parallel feature (#1013) --- src/lib.rs | 5 - src/parallel.rs | 91 ------------ tests/parallel/main.rs | 3 - tests/parallel/parallel_cancellation.rs | 67 --------- tests/parallel/parallel_join.rs | 176 ------------------------ tests/parallel/parallel_map.rs | 100 -------------- 6 files changed, 442 deletions(-) delete mode 100644 src/parallel.rs delete mode 100644 tests/parallel/parallel_cancellation.rs delete mode 100644 tests/parallel/parallel_join.rs delete mode 100644 tests/parallel/parallel_map.rs diff --git a/src/lib.rs b/src/lib.rs index 8ab47379d..8c50c9052 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,11 +37,6 @@ mod zalsa_local; #[cfg(not(feature = "inventory"))] mod nonce; -#[cfg(feature = "rayon")] -mod parallel; - -#[cfg(feature = "rayon")] -pub use parallel::{join, par_map}; #[cfg(feature = "macros")] pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update}; diff --git a/src/parallel.rs b/src/parallel.rs deleted file mode 100644 index 8a0bde655..000000000 --- a/src/parallel.rs +++ /dev/null @@ -1,91 +0,0 @@ -use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; - -use crate::{database::RawDatabase, views::DatabaseDownCaster, Database}; - -pub fn par_map(db: &Db, inputs: impl 
IntoParallelIterator, op: F) -> C -where - Db: Database + ?Sized + Send, - F: Fn(&Db, T) -> R + Sync + Send, - T: Send, - R: Send + Sync, - C: FromParallelIterator, -{ - let views = db.zalsa().views(); - let caster = &views.downcaster_for::(); - let db_caster = &views.downcaster_for::(); - inputs - .into_par_iter() - .map_with( - DbForkOnClone(db.fork_db(), caster, db_caster), - |db, element| op(db.as_view(), element), - ) - .collect() -} - -struct DbForkOnClone<'views, Db: Database + ?Sized>( - RawDatabase<'static>, - &'views DatabaseDownCaster, - &'views DatabaseDownCaster, -); - -// SAFETY: `T: Send` -> `&own T: Send`, `DbForkOnClone` is an owning pointer -unsafe impl Send for DbForkOnClone<'_, Db> {} - -impl DbForkOnClone<'_, Db> { - fn as_view(&self) -> &Db { - // SAFETY: The downcaster ensures that the pointer is valid for the lifetime of the view. - unsafe { self.1.downcast_unchecked(self.0) } - } -} - -impl Drop for DbForkOnClone<'_, Db> { - fn drop(&mut self) { - // SAFETY: `caster` is derived from a `db` fitting for our database clone - let db = unsafe { self.1.downcast_mut_unchecked(self.0) }; - // SAFETY: `db` has been box allocated and leaked by `fork_db` - _ = unsafe { Box::from_raw(db) }; - } -} - -impl Clone for DbForkOnClone<'_, Db> { - fn clone(&self) -> Self { - DbForkOnClone( - // SAFETY: `caster` is derived from a `db` fitting for our database clone - unsafe { self.2.downcast_unchecked(self.0) }.fork_db(), - self.1, - self.2, - ) - } -} - -pub fn join(db: &Db, a: A, b: B) -> (RA, RB) -where - A: FnOnce(&Db) -> RA + Send, - B: FnOnce(&Db) -> RB + Send, - RA: Send, - RB: Send, -{ - #[derive(Copy, Clone)] - struct AssertSend(T); - // SAFETY: We send owning pointers over, which are Send, given the `Db` type parameter above is Send - unsafe impl Send for AssertSend {} - - let caster = &db.zalsa().views().downcaster_for::(); - // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get - // moved to another thread 
before the closure is executed - let db_a = AssertSend(db.fork_db()); - let db_b = AssertSend(db.fork_db()); - let res = rayon::join( - // SAFETY: `caster` is derived from a `db` fitting for our database clone - move || a(unsafe { caster.downcast_unchecked({ db_a }.0) }), - // SAFETY: `caster` is derived from a `db` fitting for our database clone - move || b(unsafe { caster.downcast_unchecked({ db_b }.0) }), - ); - - // SAFETY: `db` has been box allocated and leaked by `fork_db` - // FIXME: Clean this mess up, RAII - _ = unsafe { Box::from_raw(caster.downcast_mut_unchecked(db_a.0)) }; - // SAFETY: `db` has been box allocated and leaked by `fork_db` - _ = unsafe { Box::from_raw(caster.downcast_mut_unchecked(db_b.0)) }; - res -} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 6bc89d2a2..859d14f47 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -14,9 +14,6 @@ mod cycle_nested_three_threads; mod cycle_nested_three_threads_changed; mod cycle_panic; mod cycle_provisional_depending_on_itself; -mod parallel_cancellation; -mod parallel_join; -mod parallel_map; #[cfg(not(feature = "shuttle"))] pub(crate) mod sync { diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs deleted file mode 100644 index a82437d54..000000000 --- a/tests/parallel/parallel_cancellation.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Shuttle doesn't like panics inside of its runtime. -#![cfg(not(feature = "shuttle"))] - -//! Test for thread cancellation. -use salsa::{Cancelled, Setter}; - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input(debug)] -struct MyInput { - field: i32, -} - -#[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: MyInput) -> MyInput { - db.signal(1); - db.wait_for(2); - dummy(db, input) -} - -#[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> MyInput { - panic!("should never get here!") -} - -// Cancellation signalling test -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 -// | wait for stage 1 -// signal stage 1 set input, triggers cancellation -// wait for stage 2 (blocks) triggering cancellation sends stage 2 -// | -// (unblocked) -// dummy -// panics - -#[test] -fn execute() { - let mut db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - db.signal_on_did_cancel(2); - input.set_field(&mut db).to(2); - - // Assert thread A *should* was cancelled - let cancelled = thread_a - .join() - .unwrap_err() - .downcast::() - .unwrap(); - - // and inspect the output - expect_test::expect![[r#" - PendingWrite - "#]] - .assert_debug_eq(&cancelled); -} diff --git a/tests/parallel/parallel_join.rs b/tests/parallel/parallel_join.rs deleted file mode 100644 index f39e9a5fc..000000000 --- a/tests/parallel/parallel_join.rs +++ /dev/null @@ -1,176 +0,0 @@ -#![cfg(all(feature = "rayon", not(feature = "shuttle")))] - -// test for rayon-like join interactions. 
- -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; - -use salsa::{Cancelled, Database, Setter, Storage}; - -use crate::signal::Signal; - -#[salsa::input] -struct ParallelInput { - a: u32, - b: u32, -} - -#[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) { - salsa::join(db, |db| input.a(db) + 1, |db| input.b(db) - 1) -} - -#[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) { - db.signal(1); - salsa::join( - db, - |db| { - db.wait_for(2); - input.a(db) + dummy(db) - }, - |db| { - db.wait_for(2); - input.b(db) + dummy(db) - }, - ) -} - -#[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase) -> u32 { - panic!("should never get here!") -} - -#[test] -#[cfg_attr(miri, ignore)] -fn execute() { - let db = salsa::DatabaseImpl::new(); - - let input = ParallelInput::new(&db, 10, 20); - - tracked_fn(&db, input); -} - -// we expect this to panic, as `salsa::par_map` needs to be called from a query. -#[test] -#[cfg_attr(miri, ignore)] -#[should_panic] -fn direct_calls_panic() { - let db = salsa::DatabaseImpl::new(); - - let input = ParallelInput::new(&db, 10, 20); - let (_, _) = salsa::join(&db, |db| input.a(db) + 1, |db| input.b(db) - 1); -} - -// Cancellation signalling test -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 -// | wait for stage 1 -// signal stage 1 set input, triggers cancellation -// wait for stage 2 (blocks) triggering cancellation sends stage 2 -// | -// (unblocked) -// dummy -// panics - -#[test] -#[cfg_attr(miri, ignore)] -fn execute_cancellation() { - let mut db = Knobs::default(); - - let input = ParallelInput::new(&db, 10, 20); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - db.signal_on_did_cancel(2); - input.set_a(&mut db).to(30); - - // Assert thread A was cancelled - let cancelled = thread_a - .join() - .unwrap_err() - .downcast::() - .unwrap(); - - // and inspect the output - expect_test::expect![[r#" - PendingWrite - "#]] - .assert_debug_eq(&cancelled); -} - -#[salsa::db] -trait KnobsDatabase: Database { - fn signal(&self, stage: usize); - fn wait_for(&self, stage: usize); -} - -/// A copy of `tests\parallel\setup.rs` that does not assert, as the assert is incorrect for the -/// purposes of this test. 
-#[salsa::db] -struct Knobs { - storage: salsa::Storage, - signal: Arc, - signal_on_did_cancel: Arc, -} - -impl Knobs { - pub fn signal_on_did_cancel(&self, stage: usize) { - self.signal_on_did_cancel.store(stage, Ordering::Release); - } -} - -impl Clone for Knobs { - #[track_caller] - fn clone(&self) -> Self { - Self { - storage: self.storage.clone(), - signal: self.signal.clone(), - signal_on_did_cancel: self.signal_on_did_cancel.clone(), - } - } -} - -impl Default for Knobs { - fn default() -> Self { - let signal = >::default(); - let signal_on_did_cancel = Arc::new(AtomicUsize::new(0)); - - Self { - storage: Storage::new(Some(Box::new({ - let signal = signal.clone(); - let signal_on_did_cancel = signal_on_did_cancel.clone(); - move |event| { - if let salsa::EventKind::DidSetCancellationFlag = event.kind { - signal.signal(signal_on_did_cancel.load(Ordering::Acquire)); - } - } - }))), - signal, - signal_on_did_cancel, - } - } -} - -#[salsa::db] -impl salsa::Database for Knobs {} - -#[salsa::db] -impl KnobsDatabase for Knobs { - fn signal(&self, stage: usize) { - self.signal.signal(stage); - } - - fn wait_for(&self, stage: usize) { - self.signal.wait_for(stage); - } -} diff --git a/tests/parallel/parallel_map.rs b/tests/parallel/parallel_map.rs deleted file mode 100644 index f05b73363..000000000 --- a/tests/parallel/parallel_map.rs +++ /dev/null @@ -1,100 +0,0 @@ -#![cfg(all(feature = "rayon", not(feature = "shuttle")))] -// test for rayon-like parallel map interactions. 
- -use salsa::{Cancelled, Setter}; - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -struct ParallelInput { - field: Vec, -} - -#[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec { - salsa::par_map(db, input.field(db), |_db, field| field + 1) -} - -#[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec { - db.signal(1); - salsa::par_map(db, input.field(db), |db, field| { - db.wait_for(2); - field + dummy(db) - }) -} - -#[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase) -> u32 { - panic!("should never get here!") -} - -#[test] -#[cfg_attr(miri, ignore)] -fn execute() { - let db = salsa::DatabaseImpl::new(); - - let counts = (1..=10).collect::>(); - let input = ParallelInput::new(&db, counts); - - tracked_fn(&db, input); -} - -// we expect this to panic, as `salsa::par_map` needs to be called from a query. -#[test] -#[cfg_attr(miri, ignore)] -#[should_panic] -fn direct_calls_panic() { - let db = salsa::DatabaseImpl::new(); - - let counts = (1..=10).collect::>(); - let input = ParallelInput::new(&db, counts); - let _: Vec = salsa::par_map(&db, input.field(&db), |_db, field| field + 1); -} - -// Cancellation signalling test -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 -// | wait for stage 1 -// signal stage 1 set input, triggers cancellation -// wait for stage 2 (blocks) triggering cancellation sends stage 2 -// | -// (unblocked) -// dummy -// panics - -#[test] -#[cfg_attr(miri, ignore)] -fn execute_cancellation() { - let mut db = Knobs::default(); - - let counts = (1..=10).collect::>(); - let input = ParallelInput::new(&db, counts); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - let counts = (2..=20).collect::>(); - - db.signal_on_did_cancel(2); - input.set_field(&mut db).to(counts); - - // Assert thread A *should* was cancelled - let cancelled = thread_a - .join() - .unwrap_err() - .downcast::() - .unwrap(); - - // and inspect the output - expect_test::expect![[r#" - PendingWrite - "#]] - .assert_debug_eq(&cancelled); -} From 16d51d63d515aca6b646529a73c037633b7c1ec4 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 23 Oct 2025 08:15:01 +0200 Subject: [PATCH 06/21] Fix hangs in multithreaded fixpoint iteration (#1010) * Fix race condition between releasing a transferred query's lock and the same query blocking on the outer head in `provisional_retry` * Fix infinite loop in `provisional_retry --- src/active_query.rs | 4 - src/function/execute.rs | 19 +++- src/function/fetch.rs | 46 +------- src/function/maybe_changed_after.rs | 4 +- src/function/memo.rs | 89 ++------------- src/function/sync.rs | 50 ++++++--- src/runtime.rs | 10 +- src/runtime/dependency_graph.rs | 68 +++++++---- tests/parallel/cycle_iteration_mismatch.rs | 124 +++++++++++++++++++++ tests/parallel/main.rs | 1 + 10 files changed, 243 insertions(+), 172 deletions(-) create mode 100644 tests/parallel/cycle_iteration_mismatch.rs diff --git a/src/active_query.rs b/src/active_query.rs index d830fece1..bb5987fcd 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -158,10 +158,6 @@ impl ActiveQuery { } } - pub(super) fn iteration_count(&self) -> 
IterationCount { - self.iteration_count - } - pub(crate) fn tracked_struct_ids(&self) -> &IdentityMap { &self.tracked_struct_ids } diff --git a/src/function/execute.rs b/src/function/execute.rs index 67f76e145..aa4339bef 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -28,6 +28,11 @@ where /// * `db`, the database. /// * `active_query`, the active stack frame for the query to execute. /// * `opt_old_memo`, the older memo, if any existed. Used for backdating. + /// + /// # Returns + /// The newly computed memo or `None` if this query is part of a larger cycle + /// and `execute` blocked on a cycle head running on another thread. In this case, + /// the memo is potentially outdated and needs to be refetched. #[inline(never)] pub(super) fn execute<'db>( &'db self, @@ -35,7 +40,7 @@ where mut claim_guard: ClaimGuard<'db>, zalsa_local: &'db ZalsaLocal, opt_old_memo: Option<&Memo<'db, C>>, - ) -> &'db Memo<'db, C> { + ) -> Option<&'db Memo<'db, C>> { let database_key_index = claim_guard.database_key_index(); let zalsa = claim_guard.zalsa(); @@ -80,7 +85,7 @@ where // We need to mark the memo as finalized so other cycle participants that have fallbacks // will be verified (participants that don't have fallbacks will not be verified). memo.revisions.verified_final.store(true, Ordering::Release); - return memo; + return Some(memo); } // If we're in the middle of a cycle and we have a fallback, use it instead. 
@@ -125,7 +130,7 @@ where self.diff_outputs(zalsa, database_key_index, old_memo, &completed_query); } - self.insert_memo( + let memo = self.insert_memo( zalsa, id, Memo::new( @@ -134,7 +139,13 @@ where completed_query.revisions, ), memo_ingredient_index, - ) + ); + + if claim_guard.drop() { + None + } else { + Some(memo) + } } fn execute_maybe_iterate<'db>( diff --git a/src/function/fetch.rs b/src/function/fetch.rs index ef42708a7..a3f3705f4 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -58,20 +58,11 @@ where id: Id, ) -> &'db Memo<'db, C> { let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); - let mut retry_count = 0; + loop { if let Some(memo) = self .fetch_hot(zalsa, id, memo_ingredient_index) - .or_else(|| { - self.fetch_cold_with_retry( - zalsa, - zalsa_local, - db, - id, - memo_ingredient_index, - &mut retry_count, - ) - }) + .or_else(|| self.fetch_cold(zalsa, zalsa_local, db, id, memo_ingredient_index)) { return memo; } @@ -104,33 +95,6 @@ where } } - fn fetch_cold_with_retry<'db>( - &'db self, - zalsa: &'db Zalsa, - zalsa_local: &'db ZalsaLocal, - db: &'db C::DbView, - id: Id, - memo_ingredient_index: MemoIngredientIndex, - retry_count: &mut u32, - ) -> Option<&'db Memo<'db, C>> { - let memo = self.fetch_cold(zalsa, zalsa_local, db, id, memo_ingredient_index)?; - - // If we get back a provisional cycle memo, and it's provisional on any cycle heads - // that are claimed by a different thread, we can't propagate the provisional memo - // any further (it could escape outside the cycle); we need to block on the other - // thread completing fixpoint iteration of the cycle, and then we can re-query for - // our no-longer-provisional memo. - // That is only correct for fixpoint cycles, though: `FallbackImmediate` cycles - // never have provisional entries. 
- if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate - || !memo.provisional_retry(zalsa, zalsa_local, self.database_key_index(id), retry_count) - { - Some(memo) - } else { - None - } - } - fn fetch_cold<'db>( &'db self, zalsa: &'db Zalsa, @@ -151,7 +115,7 @@ where if let Some(memo) = memo { if memo.value.is_some() { - memo.block_on_heads(zalsa, zalsa_local); + memo.block_on_heads(zalsa); } } } @@ -212,9 +176,7 @@ where } } - let memo = self.execute(db, claim_guard, zalsa_local, opt_old_memo); - - Some(memo) + self.execute(db, claim_guard, zalsa_local, opt_old_memo) } #[cold] diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 698285055..62839e865 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -228,7 +228,7 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo)); + let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo))?; let changed_at = memo.revisions.changed_at; // Always assume that a provisional value has changed. 
@@ -500,7 +500,7 @@ where return on_stack; } - let cycle_heads_iter = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, cycle_heads); + let cycle_heads_iter = TryClaimCycleHeadsIter::new(zalsa, cycle_heads); for cycle_head in cycle_heads_iter { match cycle_head { diff --git a/src/function/memo.rs b/src/function/memo.rs index 302ca73c3..2e84bc04f 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -14,7 +14,7 @@ use crate::runtime::Running; use crate::sync::atomic::Ordering; use crate::table::memo::MemoTableWithTypesMut; use crate::zalsa::{MemoIngredientIndex, Zalsa}; -use crate::zalsa_local::{QueryOriginRef, QueryRevisions, ZalsaLocal}; +use crate::zalsa_local::{QueryOriginRef, QueryRevisions}; use crate::{Event, EventKind, Id, Revision}; impl IngredientImpl { @@ -132,50 +132,12 @@ impl<'db, C: Configuration> Memo<'db, C> { !self.revisions.verified_final.load(Ordering::Relaxed) } - /// Invoked when `refresh_memo` is about to return a memo to the caller; if that memo is - /// provisional, and its cycle head is claimed by another thread, we need to wait for that - /// other thread to complete the fixpoint iteration, and then retry fetching our own memo. - /// - /// Return `true` if the caller should retry, `false` if the caller should go ahead and return - /// this memo to the caller. - #[inline(always)] - pub(super) fn provisional_retry( - &self, - zalsa: &Zalsa, - zalsa_local: &ZalsaLocal, - database_key_index: DatabaseKeyIndex, - retry_count: &mut u32, - ) -> bool { - if self.block_on_heads(zalsa, zalsa_local) { - // If we get here, we are a provisional value of - // the cycle head (either initial value, or from a later iteration) and should be - // returned to caller to allow fixpoint iteration to proceed. - false - } else { - assert!( - *retry_count <= 20000, - "Provisional memo retry limit exceeded for {database_key_index:?}; \ - this usually indicates a bug in salsa's cycle caching/locking. 
\ - (retried {retry_count} times)", - ); - - *retry_count += 1; - - // all our cycle heads are complete; re-fetch - // and we should get a non-provisional memo. - crate::tracing::debug!( - "Retrying provisional memo {database_key_index:?} after awaiting cycle heads." - ); - true - } - } - /// Blocks on all cycle heads (recursively) that this memo depends on. /// /// Returns `true` if awaiting all cycle heads results in a cycle. This means, they're all waiting /// for us to make progress. #[inline(always)] - pub(super) fn block_on_heads(&self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal) -> bool { + pub(super) fn block_on_heads(&self, zalsa: &Zalsa) -> bool { // IMPORTANT: If you make changes to this function, make sure to run `cycle_nested_deep` with // shuttle with at least 10k iterations. @@ -184,16 +146,12 @@ impl<'db, C: Configuration> Memo<'db, C> { return true; } - return block_on_heads_cold(zalsa, zalsa_local, cycle_heads); + return block_on_heads_cold(zalsa, cycle_heads); #[inline(never)] - fn block_on_heads_cold( - zalsa: &Zalsa, - zalsa_local: &ZalsaLocal, - heads: &CycleHeads, - ) -> bool { + fn block_on_heads_cold(zalsa: &Zalsa, heads: &CycleHeads) -> bool { let _entered = crate::tracing::debug_span!("block_on_heads").entered(); - let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, heads); + let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, heads); let mut all_cycles = true; for claim_result in cycle_heads { @@ -447,6 +405,7 @@ mod persistence { } } +#[derive(Debug)] pub(super) enum TryClaimHeadsResult<'me> { /// Claiming the cycle head results in a cycle. Cycle { @@ -465,19 +424,15 @@ pub(super) enum TryClaimHeadsResult<'me> { /// Iterator to try claiming the transitive cycle heads of a memo. 
pub(super) struct TryClaimCycleHeadsIter<'a> { zalsa: &'a Zalsa, - zalsa_local: &'a ZalsaLocal, + cycle_heads: CycleHeadsIterator<'a>, } impl<'a> TryClaimCycleHeadsIter<'a> { - pub(super) fn new( - zalsa: &'a Zalsa, - zalsa_local: &'a ZalsaLocal, - cycle_heads: &'a CycleHeads, - ) -> Self { + pub(super) fn new(zalsa: &'a Zalsa, cycle_heads: &'a CycleHeads) -> Self { Self { zalsa, - zalsa_local, + cycle_heads: cycle_heads.iter(), } } @@ -488,31 +443,7 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { fn next(&mut self) -> Option { let head = self.cycle_heads.next()?; - let head_database_key = head.database_key_index; - let head_iteration_count = head.iteration_count.load(); - - // The most common case is that the head is already in the query stack. So let's check that first. - // SAFETY: We do not access the query stack reentrantly. - if let Some(current_iteration_count) = unsafe { - self.zalsa_local.with_query_stack_unchecked(|stack| { - stack - .iter() - .rev() - .find(|query| query.database_key_index == head_database_key) - .map(|query| query.iteration_count()) - }) - } { - crate::tracing::trace!( - "Waiting for {head_database_key:?} results in a cycle (because it is already in the query stack)" - ); - return Some(TryClaimHeadsResult::Cycle { - head_iteration_count, - memo_iteration_count: current_iteration_count, - verified_at: self.zalsa.current_revision(), - }); - } - let head_key_index = head_database_key.key_index(); let ingredient = self .zalsa @@ -543,7 +474,7 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { Some(TryClaimHeadsResult::Cycle { memo_iteration_count: current_iteration_count, - head_iteration_count, + head_iteration_count: head.iteration_count.load(), verified_at, }) } diff --git a/src/function/sync.rs b/src/function/sync.rs index 97a36262c..02f1bffd0 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -273,11 +273,11 @@ impl<'me> ClaimGuard<'me> { runtime.undo_transfer_lock(database_key_index); } + 
runtime.unblock_queries_blocked_on(database_key_index, wait_result); + if is_transfer_target { runtime.unblock_transferred_queries_owned_by(database_key_index, wait_result); } - - runtime.unblock_queries_blocked_on(database_key_index, wait_result); } #[cold] @@ -299,7 +299,7 @@ impl<'me> ClaimGuard<'me> { #[cold] #[inline(never)] - pub(crate) fn transfer(&self, new_owner: DatabaseKeyIndex) { + pub(crate) fn transfer(&self, new_owner: DatabaseKeyIndex) -> bool { let owner_ingredient = self.zalsa.lookup_ingredient(new_owner.ingredient_index()); // Get the owning thread of `new_owner`. @@ -333,22 +333,27 @@ impl<'me> ClaimGuard<'me> { .get_mut(&self.key_index) .expect("key should only be claimed/released once"); - self.zalsa - .runtime() - .transfer_lock(self_key, new_owner, new_owner_thread_id); - *id = SyncOwner::Transferred; *claimed_twice = false; + + self.zalsa + .runtime() + .transfer_lock(self_key, new_owner, new_owner_thread_id, syncs) } -} -impl Drop for ClaimGuard<'_> { - fn drop(&mut self) { - if thread::panicking() { - self.release_panicking(); - return; - } + /// Drops the claim on the memo. + /// + /// Returns `true` if the lock was transferred to another query and + /// this thread blocked waiting for the new owner's lock to be released. + /// In that case, any computed memo need to be refetched because they may have + /// changed since `drop` was called. 
+ pub(crate) fn drop(mut self) -> bool { + let refetch = self.drop_impl(); + std::mem::forget(self); + refetch + } + fn drop_impl(&mut self) -> bool { match self.mode { ReleaseMode::Default => { let mut syncs = self.sync_table.syncs.lock(); @@ -357,17 +362,28 @@ impl Drop for ClaimGuard<'_> { .expect("key should only be claimed/released once"); self.release(state, WaitResult::Completed); + false } ReleaseMode::SelfOnly => { self.release_self(); + false } - ReleaseMode::TransferTo(new_owner) => { - self.transfer(new_owner); - } + ReleaseMode::TransferTo(new_owner) => self.transfer(new_owner), } } } +impl Drop for ClaimGuard<'_> { + fn drop(&mut self) { + if thread::panicking() { + self.release_panicking(); + return; + } + + self.drop_impl(); + } +} + impl std::fmt::Debug for SyncTable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SyncTable").finish() diff --git a/src/runtime.rs b/src/runtime.rs index 670d6d62f..48caf53ec 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -383,13 +383,17 @@ impl Runtime { query: DatabaseKeyIndex, new_owner_key: DatabaseKeyIndex, new_owner_id: SyncOwner, - ) { - self.dependency_graph.lock().transfer_lock( + guard: SyncGuard, + ) -> bool { + let dg = self.dependency_graph.lock(); + DependencyGraph::transfer_lock( + dg, query, thread::current().id(), new_owner_key, new_owner_id, - ); + guard, + ) } #[cfg(feature = "persistence")] diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index 403f7c544..9b8cbe221 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use rustc_hash::FxHashMap; use smallvec::SmallVec; -use crate::function::SyncOwner; +use crate::function::{SyncGuard, SyncOwner}; use crate::key::DatabaseKeyIndex; use crate::runtime::dependency_graph::edge::EdgeCondvar; use crate::runtime::WaitResult; @@ -199,18 +199,23 @@ impl DependencyGraph { pub(super) fn thread_id_of_transferred_query( &self, 
database_key: DatabaseKeyIndex, - ignore: Option, + skip_over: Option, ) -> Option { let &(mut resolved_thread, owner) = self.transferred.get(&database_key)?; let mut current_owner = owner; while let Some(&(next_thread, next_key)) = self.transferred.get(¤t_owner) { - if Some(next_key) == ignore { - break; + current_owner = next_key; + + // Ignore the `skip_over` key. E.g. if we have `a -> b -> c` and we want to resolve `a` but are transferring `b` to `c`, then + // we don't want to resolve `a` to the owner of `c`. But for `a -> c -> b`, we want resolve `a` to the owner of `c` and not `b` + // (because `b` will be owned by `a`). + if Some(next_key) == skip_over { + continue; } + resolved_thread = next_thread; - current_owner = next_key; } Some(resolved_thread) @@ -218,29 +223,36 @@ impl DependencyGraph { /// Modifies the graph so that the lock on `query` (currently owned by `current_thread`) is /// transferred to `new_owner` (which is owned by `new_owner_id`). + /// + /// Note, this function will block if `new_owner` runs on a different thread, unless `new_owner` is blocked + /// on current thread after transferring the query ownership. + /// + /// Returns `true` if the transfer blocked on `new_owner` (in which case it might be necessary to refetch any previously computed memos). pub(super) fn transfer_lock( - &mut self, + mut me: MutexGuard, query: DatabaseKeyIndex, current_thread: ThreadId, new_owner: DatabaseKeyIndex, new_owner_id: SyncOwner, - ) { + guard: SyncGuard, + ) -> bool { + let dg = &mut *me; let new_owner_thread = match new_owner_id { SyncOwner::Thread(thread) => thread, SyncOwner::Transferred => { // Skip over `query` to skip over any existing mapping from `new_owner` to `query` that may // exist from previous transfers. 
- self.thread_id_of_transferred_query(new_owner, Some(query)) + dg.thread_id_of_transferred_query(new_owner, Some(query)) .expect("new owner should be blocked on `query`") } }; debug_assert!( - new_owner_thread == current_thread || self.depends_on(new_owner_thread, current_thread), + new_owner_thread == current_thread || dg.depends_on(new_owner_thread, current_thread), "new owner {new_owner:?} ({new_owner_thread:?}) must be blocked on {query:?} ({current_thread:?})" ); - let thread_changed = match self.transferred.entry(query) { + let thread_changed = match dg.transferred.entry(query) { std::collections::hash_map::Entry::Vacant(entry) => { // Transfer `c -> b` and there's no existing entry for `c`. entry.insert((new_owner_thread, new_owner)); @@ -249,7 +261,7 @@ impl DependencyGraph { std::collections::hash_map::Entry::Occupied(mut entry) => { // If we transfer to the same owner as before, return immediately as this is a no-op. if entry.get() == &(new_owner_thread, new_owner) { - return; + return false; } // `Transfer `c -> b` after a previous `c -> d` mapping. @@ -257,7 +269,7 @@ impl DependencyGraph { let &(old_owner_thread, old_owner) = entry.get(); // For the example below, remove `d` from `b`'s dependents.` - self.transferred_dependents + dg.transferred_dependents .get_mut(&old_owner) .unwrap() .remove(&query); @@ -273,10 +285,9 @@ impl DependencyGraph { // d / // ``` // - // // A cycle between transfers can occur when a later iteration has a different outer most query than // a previous iteration. The second iteration then hits `cycle_initial` for a different head, (e.g. for `c` where it previously was `d`). - let mut last_segment = self.transferred.entry(new_owner); + let mut last_segment = dg.transferred.entry(new_owner); while let std::collections::hash_map::Entry::Occupied(mut entry) = last_segment { let source = *entry.key(); @@ -289,19 +300,19 @@ impl DependencyGraph { ); // Remove `a` from the dependents of `d` and remove the mapping from `a -> d`. 
- self.transferred_dependents + dg.transferred_dependents .get_mut(&query) .unwrap() .remove(&source); - // if the old mapping was `c -> d` and we now insert `d -> c`, remove `d -> c` + // if the old mapping was `c -> d` and we now insert `d -> c`, remove `c -> d` if old_owner == new_owner { entry.remove(); } else { // otherwise (when `d` pointed to some other query, e.g. `b` in the example), // add an edge from `a` to `b` entry.insert((old_owner_thread, old_owner)); - self.transferred_dependents + dg.transferred_dependents .get_mut(&old_owner) .unwrap() .push(source); @@ -310,7 +321,7 @@ impl DependencyGraph { break; } - last_segment = self.transferred.entry(next_target); + last_segment = dg.transferred.entry(next_target); } // We simply assume here that the thread has changed because we'd have to walk the entire @@ -321,15 +332,30 @@ impl DependencyGraph { }; // Register `c` as a dependent of `b`. - let all_dependents = self.transferred_dependents.entry(new_owner).or_default(); + let all_dependents = dg.transferred_dependents.entry(new_owner).or_default(); debug_assert!(!all_dependents.contains(&new_owner)); all_dependents.push(query); if thread_changed { tracing::debug!("Unblocking new owner of transfer target {new_owner:?}"); - self.unblock_transfer_target(query, new_owner_thread); - self.update_transferred_edges(query, new_owner_thread); + dg.unblock_transfer_target(query, new_owner_thread); + dg.update_transferred_edges(query, new_owner_thread); + + // Block on the new owner, unless new owner is blocked on this query. + // This is necessary to avoid a race between `fetch` completing and `provisional_retry` blocking on the + // first cycle head. 
+ if current_thread != new_owner_thread + && !dg.depends_on(new_owner_thread, current_thread) + { + crate::tracing::info!( + "block_on: thread {current_thread:?} is blocking on {new_owner:?} in thread {new_owner_thread:?}", + ); + Self::block_on(me, current_thread, new_owner, new_owner_thread, guard); + return true; + } } + + false } /// Finds the one query in the dependents of the `source_query` (the one that is transferred to a new owner) diff --git a/tests/parallel/cycle_iteration_mismatch.rs b/tests/parallel/cycle_iteration_mismatch.rs new file mode 100644 index 000000000..17cc60108 --- /dev/null +++ b/tests/parallel/cycle_iteration_mismatch.rs @@ -0,0 +1,124 @@ +//! Test for iteration count mismatch bug where cycle heads have different iteration counts +//! +//! This test aims to reproduce the scenario where: +//! 1. A memo has multiple cycle heads with different iteration counts +//! 2. When validating, iteration counts mismatch causes re-execution +//! 3. After re-execution, the memo still has the same mismatched iteration counts + +use crate::sync::thread; +use crate::{Knobs, KnobsDatabase}; +use salsa::CycleRecoveryAction; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(5); + +// Query A: First cycle head - will iterate multiple times +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + let b = query_b(db); + CycleValue(b.0 + 1).min(MAX) +} + +// Query B: Depends on C and D, creating complex dependencies +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + let c = query_c(db); + let d = query_d(db); + CycleValue(c.0.max(d.0) + 1).min(MAX) +} + +// Query C: Creates a cycle back to A +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + 
let a = query_a(db); + // Also depends on E to create more complex cycle structure + let e = query_e(db); + CycleValue(a.0.max(e.0)) +} + +// Query D: Part of a separate cycle with E +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_d(db: &dyn KnobsDatabase) -> CycleValue { + let e = query_e(db); + CycleValue(e.0 + 1).min(MAX) +} + +// Query E: Depends back on D and F +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_e(db: &dyn KnobsDatabase) -> CycleValue { + let d = query_d(db); + let f = query_f(db); + CycleValue(d.0.max(f.0) + 1).min(MAX) +} + +// Query F: Creates another cycle that might have different iteration count +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_f(db: &dyn KnobsDatabase) -> CycleValue { + // Create a cycle that depends on earlier queries + let b = query_b(db); + let e = query_e(db); + CycleValue(b.0.max(e.0)) +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[test_log::test] +fn test_iteration_count_mismatch() { + crate::sync::check(|| { + tracing::debug!("Starting new run"); + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + let db_t3 = db_t1.clone(); + let db_t4 = db_t1.clone(); + + // Thread 1: Starts with query_a - main cycle head + let t1 = thread::spawn(move || { + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + query_a(&db_t1) + }); + + // Thread 2: Starts with query_d - separate cycle that will have different iteration + let t2 = thread::spawn(move || { + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + query_d(&db_t2) + }); + + // Thread 3: Starts with query_f after others have started + let t3 = thread::spawn(move || { + let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered(); + 
query_f(&db_t3) + }); + + // Thread 4: Queries b which depends on multiple cycles + let t4 = thread::spawn(move || { + let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered(); + query_b(&db_t4) + }); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + let r_t4 = t4.join().unwrap(); + + // All queries should converge to the same value + assert_eq!(r_t1, r_t2); + assert_eq!(r_t2, r_t3); + assert_eq!(r_t3, r_t4); + + // They should have computed a non-initial value + assert!(r_t1.0 > MIN.0); + }); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 859d14f47..1062d4899 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -6,6 +6,7 @@ mod signal; mod cycle_a_t1_b_t2; mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c; +mod cycle_iteration_mismatch; mod cycle_nested_deep; mod cycle_nested_deep_conditional; mod cycle_nested_deep_conditional_changed; From d38145c29574758de7ffbe8a13cd4584c3b09161 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 23 Oct 2025 08:36:56 +0200 Subject: [PATCH 07/21] Expose the query ID and the last provisional value to the cycle recovery function (#1012) * Expose the query ID and the last provisional value to the cycle recovery function * Mark cycle as converged if fallback value is the same as the last provisional * Make `cycle_fn` optional --- benches/dataflow.rs | 4 +++ book/src/cycles.md | 2 ++ .../salsa-macro-rules/src/setup_tracked_fn.rs | 6 +++-- .../src/unexpected_cycle_recovery.rs | 6 ++--- components/salsa-macros/src/tracked_fn.rs | 7 ++--- src/cycle.rs | 6 ++++- src/function.rs | 24 ++++++++++++----- src/function/execute.rs | 6 ++++- src/function/memo.rs | 2 ++ tests/backtrace.rs | 11 +------- tests/cycle.rs | 27 ++++--------------- tests/cycle_accumulate.rs | 2 ++ tests/cycle_initial_call_back_into_cycle.rs | 10 +------ tests/cycle_initial_call_query.rs | 10 +------ tests/cycle_maybe_changed_after.rs | 26 
+++--------------- tests/cycle_output.rs | 11 +------- tests/cycle_recovery_call_back_into_cycle.rs | 8 +++++- tests/cycle_recovery_call_query.rs | 2 ++ tests/cycle_regression_455.rs | 12 +-------- tests/cycle_tracked.rs | 25 +++-------------- tests/cycle_tracked_own_input.rs | 13 ++------- tests/dataflow.rs | 4 +++ tests/parallel/cycle_a_t1_b_t2.rs | 14 ++-------- tests/parallel/cycle_ab_peeping_c.rs | 14 ++-------- tests/parallel/cycle_iteration_mismatch.rs | 21 +++++---------- tests/parallel/cycle_nested_deep.rs | 20 ++++---------- .../parallel/cycle_nested_deep_conditional.rs | 20 ++++---------- .../cycle_nested_deep_conditional_changed.rs | 21 ++++----------- tests/parallel/cycle_nested_deep_panic.rs | 18 +++---------- tests/parallel/cycle_nested_three_threads.rs | 16 +++-------- .../cycle_nested_three_threads_changed.rs | 17 +++--------- tests/parallel/cycle_panic.rs | 8 +++++- .../cycle_provisional_depending_on_itself.rs | 15 +++-------- 33 files changed, 127 insertions(+), 281 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index db099c6b2..d1acfd27b 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -76,6 +76,8 @@ fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: &Type, count: u32, _def: Definition, @@ -89,6 +91,8 @@ fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: &Type, count: u32, _use: Use, diff --git a/book/src/cycles.md b/book/src/cycles.md index 2215b8ff3..2e2c6e7b8 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -21,6 +21,8 @@ fn initial(_db: &dyn KnobsDatabase) -> u32 { } ``` +The `cycle_fn` is optional. The default implementation always returns `Iterate`. 
+ If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should resume starting with the given value (which should be a value that will converge quickly). The cycle will iterate until it converges: that is, until two successive iterations produce the same result. diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 945021f3a..961b5b4f8 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -308,11 +308,13 @@ macro_rules! 
setup_tracked_fn { fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, + id: salsa::Id, + last_provisional_value: &Self::Output<$db_lt>, value: &Self::Output<$db_lt>, - count: u32, + iteration_count: u32, ($($input_id),*): ($($interned_input_ty),*) ) -> $zalsa::CycleRecoveryAction> { - $($cycle_recovery_fn)*(db, value, count, $($input_id),*) + $($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*) } fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index 8d56d54f3..aa6161d28 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -3,10 +3,10 @@ // a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ - std::mem::drop($db); + ($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{ + let (_db, _id, _last_provisional_value, _new_value, _count) = ($db, $id, $last_provisional_value, $new_value, $count); std::mem::drop(($($other_inputs,)*)); - panic!("cannot recover from cycle") + salsa::CycleRecoveryAction::Iterate }}; } diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 5c6fab7d2..12f9170c7 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -286,9 +286,10 @@ impl Macro { self.args.cycle_fn.as_ref().unwrap(), "must provide `cycle_initial` along with `cycle_fn`", )), - (None, Some(_), None) => Err(syn::Error::new_spanned( - self.args.cycle_initial.as_ref().unwrap(), - "must provide `cycle_fn` along with `cycle_initial`", + (None, Some(cycle_initial), None) => 
Ok(( + quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((#cycle_initial)), + quote!(Fixpoint), )), (None, None, Some(cycle_result)) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), diff --git a/src/cycle.rs b/src/cycle.rs index c9a9b82c1..09ec51525 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -70,7 +70,11 @@ pub enum CycleRecoveryAction { /// Iterate the cycle again to look for a fixpoint. Iterate, - /// Cut off iteration and use the given result value for this query. + /// Use the given value as the result for the current iteration instead + /// of the value computed by the query function. + /// + /// Returning `Fallback` doesn't stop the fixpoint iteration. It only + /// allows the iterate function to return a different value. Fallback(T), } diff --git a/src/function.rs b/src/function.rs index 259dff14b..1cf3e9478 100644 --- a/src/function.rs +++ b/src/function.rs @@ -94,20 +94,32 @@ pub trait Configuration: Any { /// value from the latest iteration of this cycle. `count` is the number of cycle iterations /// completed so far. /// - /// # Iteration count semantics + /// # Id /// - /// The `count` parameter isn't guaranteed to start from zero or to be contiguous: + /// The id can be used to uniquely identify the query instance. This can be helpful + /// if the cycle function has to re-identify a value it returned previously. /// - /// * **Initial value**: `count` may be non-zero on the first call for a given query if that + /// # Values + /// + /// The `last_provisional_value` is the value from the previous iteration of this cycle + /// and `value` is the new value that was computed in the current iteration. + /// + /// # Iteration count + /// + /// The `iteration` parameter isn't guaranteed to start from zero or to be contiguous: + /// + /// * **Initial value**: `iteration` may be non-zero on the first call for a given query if that /// query becomes the outermost cycle head after a nested cycle complete a few iterations. 
In this case, - /// `count` continues from the nested cycle's iteration count rather than resetting to zero. + /// `iteration` continues from the nested cycle's iteration count rather than resetting to zero. /// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle /// and the value for this query remains unchanged for one iteration. But the outer cycle might /// keep iterating because other heads keep changing. fn recover_from_cycle<'db>( db: &'db Self::DbView, - value: &Self::Output<'db>, - count: u32, + id: Id, + last_provisional_value: &Self::Output<'db>, + new_value: &Self::Output<'db>, + iteration: u32, input: Self::Input<'db>, ) -> CycleRecoveryAction>; diff --git a/src/function/execute.rs b/src/function/execute.rs index aa4339bef..3acfaadc8 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -320,7 +320,7 @@ where I am a cycle head, comparing last provisional value with new value" ); - let this_converged = C::values_equal(&new_value, last_provisional_value); + let mut this_converged = C::values_equal(&new_value, last_provisional_value); // If this is the outermost cycle, use the maximum iteration count of all cycles. 
// This is important for when later iterations introduce new cycle heads (that then @@ -341,6 +341,8 @@ where // cycle-recovery function what to do: match C::recover_from_cycle( db, + id, + last_provisional_value, &new_value, iteration_count.as_u32(), C::id_to_input(zalsa, id), @@ -351,6 +353,8 @@ where "{database_key_index:?}: execute: user cycle_fn says to fall back" ); new_value = fallback_value; + + this_converged = C::values_equal(&new_value, last_provisional_value); } } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 2e84bc04f..8fe0c1dd8 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -557,6 +557,8 @@ mod _memory_usage { fn recover_from_cycle<'db>( _: &'db Self::DbView, + _: Id, + _: &Self::Output<'db>, _: &Self::Output<'db>, _: u32, _: Self::Input<'db>, diff --git a/tests/backtrace.rs b/tests/backtrace.rs index b611cac86..0adf517cd 100644 --- a/tests/backtrace.rs +++ b/tests/backtrace.rs @@ -42,7 +42,7 @@ fn query_f(db: &dyn Database, thing: Thing) -> String { query_cycle(db, thing) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_cycle(db: &dyn Database, thing: Thing) -> String { let backtrace = query_cycle(db, thing); if backtrace.is_empty() { @@ -56,15 +56,6 @@ fn cycle_initial(_db: &dyn salsa::Database, _thing: Thing) -> String { String::new() } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &str, - _count: u32, - _thing: Thing, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[test] fn backtrace_works() { let db = DatabaseImpl::default(); diff --git a/tests/cycle.rs b/tests/cycle.rs index 5e46cc0be..0c4d686af 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -125,6 +125,8 @@ const MAX_ITERATIONS: u32 = 3; /// iterating again. 
fn cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Value, value: &Value, count: u32, _inputs: Inputs, @@ -440,7 +442,6 @@ fn two_fallback_count() { /// /// Two-query cycle, falls back but fallback does not converge. #[test] -#[should_panic(expected = "too many cycle iterations")] fn two_fallback_diverge() { let mut db = DbImpl::new(); let a_in = Inputs::new(&db, vec![]); @@ -1167,7 +1168,7 @@ fn repeat_query_participating_in_cycle() { value: u32, } - #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + #[salsa::tracked(cycle_initial=initial)] fn head(db: &dyn Db, input: Input) -> u32 { let a = query_a(db, input); @@ -1178,15 +1179,6 @@ fn repeat_query_participating_in_cycle() { 0 } - fn cycle_recover( - _db: &dyn Db, - _value: &u32, - _count: u32, - _input: Input, - ) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate - } - #[salsa::tracked] fn query_a(db: &dyn Db, input: Input) -> u32 { let _ = query_b(db, input); @@ -1281,7 +1273,7 @@ fn repeat_query_participating_in_cycle2() { value: u32, } - #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + #[salsa::tracked(cycle_initial=initial)] fn head(db: &dyn Db, input: Input) -> u32 { let a = query_a(db, input); @@ -1292,16 +1284,7 @@ fn repeat_query_participating_in_cycle2() { 0 } - fn cycle_recover( - _db: &dyn Db, - _value: &u32, - _count: u32, - _input: Input, - ) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate - } - - #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + #[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn Db, input: Input) -> u32 { let _ = query_hot(db, input); query_b(db, input) diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index e06fe033b..8148e952d 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -50,6 +50,8 @@ fn cycle_initial(_db: &dyn LogDatabase, _file: File) -> Vec { fn cycle_fn( _db: &dyn LogDatabase, + _id: salsa::Id, + _last_provisional_value: &[u32], 
_value: &[u32], _count: u32, _file: File, diff --git a/tests/cycle_initial_call_back_into_cycle.rs b/tests/cycle_initial_call_back_into_cycle.rs index 326fd46c7..e56c4c4d1 100644 --- a/tests/cycle_initial_call_back_into_cycle.rs +++ b/tests/cycle_initial_call_back_into_cycle.rs @@ -7,7 +7,7 @@ fn initial_value(db: &dyn salsa::Database) -> u32 { query(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query(db: &dyn salsa::Database) -> u32 { let val = query(db); if val < 5 { @@ -21,14 +21,6 @@ fn cycle_initial(db: &dyn salsa::Database) -> u32 { initial_value(db) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &u32, - _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[test_log::test] #[should_panic(expected = "dependency graph cycle")] fn the_test() { diff --git a/tests/cycle_initial_call_query.rs b/tests/cycle_initial_call_query.rs index cb10e77e1..2212ef958 100644 --- a/tests/cycle_initial_call_query.rs +++ b/tests/cycle_initial_call_query.rs @@ -7,7 +7,7 @@ fn initial_value(_db: &dyn salsa::Database) -> u32 { 0 } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query(db: &dyn salsa::Database) -> u32 { let val = query(db); if val < 5 { @@ -21,14 +21,6 @@ fn cycle_initial(db: &dyn salsa::Database) -> u32 { initial_value(db) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &u32, - _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[test_log::test] fn the_test() { let db = salsa::DatabaseImpl::default(); diff --git a/tests/cycle_maybe_changed_after.rs b/tests/cycle_maybe_changed_after.rs index 6ee42d3a5..8c00c484a 100644 --- a/tests/cycle_maybe_changed_after.rs +++ b/tests/cycle_maybe_changed_after.rs @@ -4,7 +4,7 @@ mod common; use crate::common::EventLoggerDatabase; -use salsa::{CycleRecoveryAction, Database, Durability, Setter}; 
+use salsa::{Database, Durability, Setter}; #[salsa::input(debug)] struct Input { @@ -17,7 +17,7 @@ struct Output<'db> { value: u32, } -#[salsa::tracked(cycle_fn=query_a_recover, cycle_initial=query_a_initial)] +#[salsa::tracked(cycle_initial=query_a_initial)] fn query_c<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { query_d(db, input) } @@ -40,21 +40,12 @@ fn query_a_initial(_db: &dyn Database, _input: Input) -> u32 { 0 } -fn query_a_recover( - _db: &dyn Database, - _output: &u32, - _count: u32, - _input: Input, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - /// Only the first iteration depends on `input.value`. It's important that the entire query /// reruns if `input.value` changes. That's why salsa has to carry-over the inputs and outputs /// from the previous iteration. #[test_log::test] fn first_iteration_input_only() { - #[salsa::tracked(cycle_fn=query_a_recover, cycle_initial=query_a_initial)] + #[salsa::tracked(cycle_initial=query_a_initial)] fn query_a<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { query_b(db, input) } @@ -126,7 +117,7 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() { scope: Scope<'db>, } - #[salsa::tracked(cycle_fn=head_recover, cycle_initial=head_initial)] + #[salsa::tracked(cycle_initial=head_initial)] fn cycle_head<'db>(db: &'db dyn salsa::Database, input: Input) -> Option> { let b = cycle_outer(db, input); tracing::info!("query_b = {b:?}"); @@ -141,15 +132,6 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() { None } - fn head_recover<'db>( - _db: &'db dyn Database, - _output: &Option>, - _count: u32, - _input: Input, - ) -> CycleRecoveryAction>> { - CycleRecoveryAction::Iterate - } - #[salsa::tracked] fn cycle_outer<'db>(db: &'db dyn salsa::Database, input: Input) -> Option> { cycle_participant(db, input) diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 59b789aa4..02a3b569f 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -35,7 +35,7 @@ fn 
query_a(db: &dyn Db, input: InputValue) -> u32 { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_b(db: &dyn Db, input: InputValue) -> u32 { query_a(db, input) } @@ -44,15 +44,6 @@ fn cycle_initial(_db: &dyn Db, _input: InputValue) -> u32 { 0 } -fn cycle_fn( - _db: &dyn Db, - _value: &u32, - _count: u32, - _input: InputValue, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[salsa::tracked] fn query_c(db: &dyn Db, input: InputValue) -> u32 { input.value(db) diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 805a2be7b..358f988ad 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -25,7 +25,13 @@ fn cycle_initial(_db: &dyn ValueDatabase) -> u32 { 0 } -fn cycle_fn(db: &dyn ValueDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { +fn cycle_fn( + db: &dyn ValueDatabase, + _id: salsa::Id, + _last_provisional_value: &u32, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { salsa::CycleRecoveryAction::Fallback(fallback_value(db)) } diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index dcc31abeb..37341a202 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -23,6 +23,8 @@ fn cycle_initial(_db: &dyn salsa::Database) -> u32 { fn cycle_fn( db: &dyn salsa::Database, + _id: salsa::Id, + _last_provisional_value: &u32, _value: &u32, _count: u32, ) -> salsa::CycleRecoveryAction { diff --git a/tests/cycle_regression_455.rs b/tests/cycle_regression_455.rs index 99c193ab9..a083cb996 100644 --- a/tests/cycle_regression_455.rs +++ b/tests/cycle_regression_455.rs @@ -7,21 +7,11 @@ fn memoized(db: &dyn Database, input: MyInput) -> u32 { memoized_a(db, MyTracked::new(db, input.field(db))) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] 
+#[salsa::tracked(cycle_initial=cycle_initial)] fn memoized_a<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> u32 { MyTracked::new(db, 0); memoized_b(db, tracked) } - -fn cycle_fn<'db>( - _db: &'db dyn Database, - _value: &u32, - _count: u32, - _input: MyTracked<'db>, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - fn cycle_initial(_db: &dyn Database, _input: MyTracked) -> u32 { 0 } diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 2e0c2cfd0..5ee4e1620 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -4,7 +4,7 @@ mod common; use crate::common::{EventLoggerDatabase, LogDatabase}; use expect_test::expect; -use salsa::{CycleRecoveryAction, Database, Setter}; +use salsa::{Database, Setter}; #[derive(Clone, Debug, Eq, PartialEq, Hash, salsa::Update)] struct Graph<'db> { @@ -86,7 +86,7 @@ fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> { } /// Computes the minimum cost from the node with offset `0` to the given node. -#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=max_initial)] +#[salsa::tracked(cycle_initial=max_initial)] fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize { let mut min_cost = usize::MAX; let graph = create_graph(db, node.graph(db)); @@ -114,15 +114,6 @@ fn max_initial(_db: &dyn Database, _node: Node) -> usize { usize::MAX } -fn cycle_recover( - _db: &dyn Database, - _value: &usize, - _count: u32, - _inputs: Node, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - /// Tests for cycles where the cycle head is stored on a tracked struct /// and that tracked struct is freed in a later revision. #[test] @@ -215,7 +206,7 @@ struct IterationNode<'db> { /// 3. Second iteration: returns `[iter_0, iter_1]` /// 4. Third iteration (only for variant=1): returns `[iter_0, iter_1, iter_2]` /// 5. 
Further iterations: no change, fixpoint reached -#[salsa::tracked(cycle_fn=cycle_recover_with_structs, cycle_initial=initial_with_structs)] +#[salsa::tracked(cycle_initial=initial_with_structs)] fn create_tracked_in_cycle<'db>( db: &'db dyn Database, input: GraphInput, @@ -259,16 +250,6 @@ fn initial_with_structs(_db: &dyn Database, _input: GraphInput) -> Vec( - _db: &'db dyn Database, - _value: &Vec>, - _iteration: u32, - _input: GraphInput, -) -> CycleRecoveryAction>> { - CycleRecoveryAction::Iterate -} - #[test_log::test] fn test_cycle_with_fixpoint_structs() { let mut db = EventLoggerDatabase::default(); diff --git a/tests/cycle_tracked_own_input.rs b/tests/cycle_tracked_own_input.rs index 38218f1a7..79035bab5 100644 --- a/tests/cycle_tracked_own_input.rs +++ b/tests/cycle_tracked_own_input.rs @@ -11,7 +11,7 @@ mod common; use crate::common::{EventLoggerDatabase, LogDatabase}; use expect_test::expect; -use salsa::{CycleRecoveryAction, Database, Setter}; +use salsa::{Database, Setter}; #[salsa::input(debug)] struct ClassNode { @@ -52,7 +52,7 @@ impl Type<'_> { } } -#[salsa::tracked(cycle_fn=infer_class_recover, cycle_initial=infer_class_initial)] +#[salsa::tracked(cycle_initial=infer_class_initial)] fn infer_class<'db>(db: &'db dyn salsa::Database, node: ClassNode) -> Type<'db> { Type::Class(Class::new( db, @@ -85,15 +85,6 @@ fn infer_class_initial(_db: &'_ dyn Database, _node: ClassNode) -> Type<'_> { Type::Unknown } -fn infer_class_recover<'db>( - _db: &'db dyn Database, - _type: &Type<'db>, - _count: u32, - _inputs: ClassNode, -) -> CycleRecoveryAction> { - CycleRecoveryAction::Iterate -} - #[test] fn main() { let mut db = EventLoggerDatabase::default(); diff --git a/tests/dataflow.rs b/tests/dataflow.rs index 960cc33f5..793870322 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -77,6 +77,8 @@ fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: 
&Type, count: u32, _def: Definition, @@ -90,6 +92,8 @@ fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: &Type, count: u32, _use: Use, diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs index ad21b7963..6a434099e 100644 --- a/tests/parallel/cycle_a_t1_b_t2.rs +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -15,8 +15,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -26,7 +24,7 @@ const MAX: CycleValue = CycleValue(3); // Signal 1: T1 has entered `query_a` // Signal 2: T2 has entered `query_b` -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { db.signal(1); @@ -36,7 +34,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { // Wait for Thread T1 to enter `query_a` before we continue. 
db.wait_for(1); @@ -47,14 +45,6 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0 + 1).min(MAX) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs index 134fe7429..8ed2b4fb6 100644 --- a/tests/parallel/cycle_ab_peeping_c.rs +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -9,8 +9,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -18,7 +16,7 @@ const MIN: CycleValue = CycleValue(0); const MID: CycleValue = CycleValue(5); const MAX: CycleValue = CycleValue(10); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { let b_value = query_b(db); @@ -32,19 +30,11 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { b_value } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let a_value = query_a(db); diff --git a/tests/parallel/cycle_iteration_mismatch.rs b/tests/parallel/cycle_iteration_mismatch.rs index 17cc60108..61d1da01d 100644 --- a/tests/parallel/cycle_iteration_mismatch.rs +++ b/tests/parallel/cycle_iteration_mismatch.rs @@ -7,7 +7,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct 
CycleValue(u32); @@ -16,14 +15,14 @@ const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(5); // Query A: First cycle head - will iterate multiple times -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { let b = query_b(db); CycleValue(b.0 + 1).min(MAX) } // Query B: Depends on C and D, creating complex dependencies -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c = query_c(db); let d = query_d(db); @@ -31,7 +30,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { } // Query C: Creates a cycle back to A -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { let a = query_a(db); // Also depends on E to create more complex cycle structure @@ -40,14 +39,14 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { } // Query D: Part of a separate cycle with E -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { let e = query_e(db); CycleValue(e.0 + 1).min(MAX) } // Query E: Depends back on D and F -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { let d = query_d(db); let f = query_f(db); @@ -55,7 +54,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { } // Query F: Creates another cycle that might have different iteration count -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_f(db: &dyn KnobsDatabase) -> CycleValue { // Create a cycle that depends on earlier queries let b = query_b(db); @@ -63,14 +62,6 @@ fn query_f(db: &dyn KnobsDatabase) -> CycleValue { 
CycleValue(b.0.max(e.0)) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep.rs b/tests/parallel/cycle_nested_deep.rs index f2b355616..3d46bbbc5 100644 --- a/tests/parallel/cycle_nested_deep.rs +++ b/tests/parallel/cycle_nested_deep.rs @@ -9,26 +9,24 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(3); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c_value = query_c(db); CycleValue(c_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { let d_value = query_d(db); let e_value = query_e(db); @@ -38,24 +36,16 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(d_value.0.max(e_value.0).max(b_value.0).max(a_value.0)) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN 
} diff --git a/tests/parallel/cycle_nested_deep_conditional.rs b/tests/parallel/cycle_nested_deep_conditional.rs index 4eff75189..544342e07 100644 --- a/tests/parallel/cycle_nested_deep_conditional.rs +++ b/tests/parallel/cycle_nested_deep_conditional.rs @@ -14,26 +14,24 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(3); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c_value = query_c(db); CycleValue(c_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { let d_value = query_d(db); @@ -47,24 +45,16 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_conditional_changed.rs b/tests/parallel/cycle_nested_deep_conditional_changed.rs index 51d506456..03423b09a 100644 --- a/tests/parallel/cycle_nested_deep_conditional_changed.rs +++ 
b/tests/parallel/cycle_nested_deep_conditional_changed.rs @@ -15,8 +15,6 @@ //! Specifically, the maybe_changed_after flow. use crate::sync::thread; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -28,18 +26,18 @@ struct Input { value: u32, } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn salsa::Database, input: Input) -> CycleValue { query_b(db, input) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn salsa::Database, input: Input) -> CycleValue { let c_value = query_c(db, input); CycleValue(c_value.0 + input.value(db).max(1)).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { let d_value = query_d(db, input); @@ -53,25 +51,16 @@ fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn salsa::Database, input: Input) -> CycleValue { query_c(db, input) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn salsa::Database, input: Input) -> CycleValue { query_c(db, input) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &CycleValue, - _count: u32, - _input: Input, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_panic.rs b/tests/parallel/cycle_nested_deep_panic.rs index 8b89f362a..4356489c3 100644 --- a/tests/parallel/cycle_nested_deep_panic.rs +++ b/tests/parallel/cycle_nested_deep_panic.rs @@ -8,20 +8,18 @@ use crate::{Knobs, KnobsDatabase}; use 
std::fmt; use std::panic::catch_unwind; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(3); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c_value = query_c(db); CycleValue(c_value.0 + 1).min(MAX) @@ -41,24 +39,16 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs index 22232bd85..728fc3e70 100644 --- a/tests/parallel/cycle_nested_three_threads.rs +++ b/tests/parallel/cycle_nested_three_threads.rs @@ -17,8 +17,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -29,7 +27,7 @@ const MAX: CycleValue = CycleValue(3); // Signal 2: T2 has entered `query_b` // Signal 3: T3 has entered `query_c` -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> 
CycleValue { db.signal(1); db.wait_for(3); @@ -37,7 +35,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { db.wait_for(1); db.signal(2); @@ -47,7 +45,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(c_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { db.wait_for(2); db.signal(3); @@ -57,14 +55,6 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0.max(b_value.0)) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads_changed.rs b/tests/parallel/cycle_nested_three_threads_changed.rs index ccd92a407..626b3ef90 100644 --- a/tests/parallel/cycle_nested_three_threads_changed.rs +++ b/tests/parallel/cycle_nested_three_threads_changed.rs @@ -19,7 +19,7 @@ use crate::sync; use crate::sync::thread; -use salsa::{CycleRecoveryAction, DatabaseImpl, Setter as _}; +use salsa::{DatabaseImpl, Setter as _}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -36,33 +36,24 @@ struct Input { // Signal 2: T2 has entered `query_b` // Signal 3: T3 has entered `query_c` -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn salsa::Database, input: Input) -> CycleValue { query_b(db, input) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn salsa::Database, input: Input) -> CycleValue { let c_value = query_c(db, input); CycleValue(c_value.0 + 
input.value(db)).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { let a_value = query_a(db, input); let b_value = query_b(db, input); CycleValue(a_value.0.max(b_value.0)) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &CycleValue, - _count: u32, - _input: Input, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index a713809b7..13c988f8f 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -18,7 +18,13 @@ fn query_b(db: &dyn KnobsDatabase) -> u32 { query_a(db) + 1 } -fn cycle_fn(_db: &dyn KnobsDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { +fn cycle_fn( + _db: &dyn KnobsDatabase, + _id: salsa::Id, + _last_provisional_value: &u32, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { panic!("cancel!") } diff --git a/tests/parallel/cycle_provisional_depending_on_itself.rs b/tests/parallel/cycle_provisional_depending_on_itself.rs index ba3645fd5..bb615210e 100644 --- a/tests/parallel/cycle_provisional_depending_on_itself.rs +++ b/tests/parallel/cycle_provisional_depending_on_itself.rs @@ -19,7 +19,6 @@ //! 3. 
`t1`: Iterates on `a`, finalizes the memo use crate::sync::thread; -use salsa::CycleRecoveryAction; use crate::setup::{Knobs, KnobsDatabase}; @@ -29,12 +28,12 @@ struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(1); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { // Wait for thread 2 to have entered `query_c`. tracing::debug!("Wait for signal 1 from thread 2"); @@ -55,7 +54,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { tracing::debug!("query_c: signaling thread1 to call c"); db.signal(1); @@ -68,14 +67,6 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { b } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } From 25b3ef146cfa2615f4ec82760bd0c22b454d0a12 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 24 Oct 2025 17:58:45 +0200 Subject: [PATCH 08/21] Fix cache invalidation when cycle head becomes non-head (#1014) * Fix cache invalidation when cycle head becomes non-head * Discard changes to src/function/fetch.rs * Inline comment --- src/cycle.rs | 15 +++++++++++++-- src/function.rs | 18 +++++++----------- src/function/execute.rs | 10 +++++++--- src/function/maybe_changed_after.rs | 28 ++++++++++++++++++++++++++-- src/function/memo.rs | 21 ++++++++++++++------- src/ingredient.rs | 14 ++++++-------- 6 files changed, 73 insertions(+), 33 deletions(-) diff 
--git a/src/cycle.rs b/src/cycle.rs index 09ec51525..fcbadf891 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -346,6 +346,7 @@ impl CycleHeads { if *removed { *removed = false; + existing.iteration_count.store_mut(iteration_count); true } else { @@ -468,11 +469,12 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads { EMPTY_CYCLE_HEADS.get_or_init(|| CycleHeads(ThinVec::new())) } -#[derive(Debug, PartialEq, Eq)] -pub enum ProvisionalStatus { +#[derive(Debug)] +pub enum ProvisionalStatus<'db> { Provisional { iteration: IterationCount, verified_at: Revision, + cycle_heads: &'db CycleHeads, }, Final { iteration: IterationCount, @@ -480,3 +482,12 @@ pub enum ProvisionalStatus { }, FallbackImmediate, } + +impl<'db> ProvisionalStatus<'db> { + pub(crate) fn cycle_heads(&self) -> &'db CycleHeads { + match self { + ProvisionalStatus::Provisional { cycle_heads, .. } => cycle_heads, + _ => empty_cycle_heads(), + } + } +} diff --git a/src/function.rs b/src/function.rs index 1cf3e9478..512c8ba70 100644 --- a/src/function.rs +++ b/src/function.rs @@ -7,10 +7,7 @@ use std::ptr::NonNull; use std::sync::atomic::Ordering; use std::sync::OnceLock; -use crate::cycle::{ - empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, - ProvisionalStatus, -}; +use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::hash::{FxHashSet, FxIndexSet}; @@ -357,7 +354,11 @@ where /// /// Otherwise, the value is still provisional. For both final and provisional, it also /// returns the iteration in which this memo was created (always 0 except for cycle heads). 
- fn provisional_status(&self, zalsa: &Zalsa, input: Id) -> Option { + fn provisional_status<'db>( + &self, + zalsa: &'db Zalsa, + input: Id, + ) -> Option> { let memo = self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input))?; @@ -377,6 +378,7 @@ where ProvisionalStatus::Provisional { iteration, verified_at: memo.verified_at.load(), + cycle_heads: memo.cycle_heads(), } }) } @@ -416,12 +418,6 @@ where self.sync_table.mark_as_transfer_target(key_index) } - fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads { - self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) - .map(|memo| memo.cycle_heads()) - .unwrap_or(empty_cycle_heads()) - } - /// Attempts to claim `key_index` without blocking. /// /// * [`WaitForResult::Running`] if the `key_index` is running on another thread. It's up to the caller to block on the other thread diff --git a/src/function/execute.rs b/src/function/execute.rs index 3acfaadc8..5e3c226be 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -248,9 +248,11 @@ where let ingredient = zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); - for nested_head in - ingredient.cycle_heads(zalsa, head.database_key_index.key_index()) - { + let provisional_status = ingredient + .provisional_status(zalsa, head.database_key_index.key_index()) + .expect("cycle head memo must have been created during the execution"); + + for nested_head in provisional_status.cycle_heads() { let nested_as_tuple = ( nested_head.database_key_index, nested_head.iteration_count.load(), @@ -442,6 +444,8 @@ where // Update the iteration count of this cycle head, but only after restoring // the cycle heads array (or this becomes a no-op). 
+ // We don't call the same method on `cycle_heads` because that one doens't update + // the `memo.iteration_count` completed_query.revisions.set_cycle_heads(cycle_heads); completed_query .revisions diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 62839e865..4198631b9 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -484,8 +484,9 @@ where // Always return `false` for cycle initial values "unless" they are running in the same thread. if cycle_heads - .iter() - .all(|head| head.database_key_index == memo_database_key_index) + .iter_not_eq(memo_database_key_index) + .next() + .is_none() { // SAFETY: We do not access the query stack reentrantly. let on_stack = unsafe { @@ -508,6 +509,8 @@ where head_iteration_count, memo_iteration_count: current_iteration_count, verified_at: head_verified_at, + cycle_heads, + database_key_index: head_database_key, } => { if head_verified_at != memo_verified_at { return false; @@ -516,6 +519,27 @@ where if head_iteration_count != current_iteration_count { return false; } + + // Check if the memo is still a cycle head and hasn't changed + // to a normal cycle participant. This is to force re-execution in + // a scenario like this: + // + // * There's a nested cycle with the outermost query A + // * B participates in the cycle and is a cycle head in the first few iterations + // * B becomes a non-cycle head in a later iteration + // * There's a query `C` that has `B` as its cycle head + // + // The crucial point is that `B` switches from being a cycle head to being a regular cycle participant. + // The issue with that is that `A` doesn't update `B`'s `iteration_count `when the iteration completes + // because it only does that for cycle heads (and collecting all queries participating in a query would be sort of expensive?). 
+ // + // When we now pull `C` in a later iteration, `validate_same_iteration` iterates over all its cycle heads (`B`), + // and check if the iteration count still matches. Which is the case because `A` didn't update `B`'s iteration count. + // + // That's why we also check if `B` is still a cycle head in the current iteration. + if !cycle_heads.contains(&head_database_key) { + return false; + } } _ => { return false; diff --git a/src/function/memo.rs b/src/function/memo.rs index 8fe0c1dd8..200f83a4d 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -409,9 +409,11 @@ mod persistence { pub(super) enum TryClaimHeadsResult<'me> { /// Claiming the cycle head results in a cycle. Cycle { + database_key_index: DatabaseKeyIndex, head_iteration_count: IterationCount, memo_iteration_count: IterationCount, verified_at: Revision, + cycle_heads: &'me CycleHeads, }, /// The cycle head is not finalized, but it can be claimed. @@ -458,23 +460,28 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { let provisional_status = ingredient .provisional_status(self.zalsa, head_key_index) .expect("cycle head memo to exist"); - let (current_iteration_count, verified_at) = match provisional_status { + let (current_iteration_count, verified_at, cycle_heads) = match provisional_status { ProvisionalStatus::Provisional { iteration, verified_at, - } - | ProvisionalStatus::Final { + cycle_heads, + } => (iteration, verified_at, cycle_heads), + ProvisionalStatus::Final { iteration, verified_at, - } => (iteration, verified_at), - ProvisionalStatus::FallbackImmediate => { - (IterationCount::initial(), self.zalsa.current_revision()) - } + } => (iteration, verified_at, empty_cycle_heads()), + ProvisionalStatus::FallbackImmediate => ( + IterationCount::initial(), + self.zalsa.current_revision(), + empty_cycle_heads(), + ), }; Some(TryClaimHeadsResult::Cycle { + database_key_index: head_database_key, memo_iteration_count: current_iteration_count, head_iteration_count: 
head.iteration_count.load(), + cycle_heads, verified_at, }) } diff --git a/src/ingredient.rs b/src/ingredient.rs index 9b377e4d1..6fe525c4f 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -1,7 +1,7 @@ use std::any::{Any, TypeId}; use std::fmt; -use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount, ProvisionalStatus}; +use crate::cycle::{IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::hash::{FxHashSet, FxIndexSet}; @@ -74,16 +74,14 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// Is it a provisional value or has it been finalized and in which iteration. /// /// Returns `None` if `input` doesn't exist. - fn provisional_status(&self, _zalsa: &Zalsa, _input: Id) -> Option { + fn provisional_status<'db>( + &self, + _zalsa: &'db Zalsa, + _input: Id, + ) -> Option> { unreachable!("provisional_status should only be called on cycle heads and only functions can be cycle heads"); } - /// Returns the cycle heads for this ingredient. - fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads { - _ = (zalsa, input); - empty_cycle_heads() - } - /// Invoked when the current thread needs to wait for a result for the given `key_index`. /// This call doesn't block the current thread. Instead, it's up to the caller to block /// in case `key_index` is [running](`WaitForResult::Running`) on another thread. 
From e8ddb4dbf7f0adbfa951a6f6e793a2ce3b165355 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sun, 26 Oct 2025 15:44:49 +0100 Subject: [PATCH 09/21] pref: Add `SyncTable::peek_claim` fast path for `function::Ingredient::wait_for` (#1011) --- src/function.rs | 4 +-- src/function/sync.rs | 82 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/src/function.rs b/src/function.rs index 512c8ba70..434a895a5 100644 --- a/src/function.rs +++ b/src/function.rs @@ -428,11 +428,11 @@ where fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { match self .sync_table - .try_claim(zalsa, key_index, Reentrancy::Deny) + .peek_claim(zalsa, key_index, Reentrancy::Deny) { ClaimResult::Running(blocked_on) => WaitForResult::Running(blocked_on), ClaimResult::Cycle { inner } => WaitForResult::Cycle { inner }, - ClaimResult::Claimed(_) => WaitForResult::Available, + ClaimResult::Claimed(()) => WaitForResult::Available, } } diff --git a/src/function/sync.rs b/src/function/sync.rs index 02f1bffd0..c9a74a307 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -20,7 +20,7 @@ pub(crate) struct SyncTable { ingredient: IngredientIndex, } -pub(crate) enum ClaimResult<'a> { +pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> { /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. @@ -31,7 +31,7 @@ pub(crate) enum ClaimResult<'a> { inner: bool, }, /// Successfully claimed the query. - Claimed(ClaimGuard<'a>), + Claimed(Guard), } pub(crate) struct SyncState { @@ -87,10 +87,7 @@ impl SyncTable { } }; - let &mut SyncState { - ref mut anyone_waiting, - .. - } = occupied_entry.into_mut(); + let SyncState { anyone_waiting, .. 
} = occupied_entry.into_mut(); // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this @@ -125,6 +122,51 @@ impl SyncTable { } } + /// Claims the given key index, or blocks if it is running on another thread. + pub(crate) fn peek_claim<'me>( + &'me self, + zalsa: &'me Zalsa, + key_index: Id, + reentrant: Reentrancy, + ) -> ClaimResult<'me, ()> { + let mut write = self.syncs.lock(); + match write.entry(key_index) { + std::collections::hash_map::Entry::Occupied(occupied_entry) => { + let id = match occupied_entry.get().id { + SyncOwner::Thread(id) => id, + SyncOwner::Transferred => { + return match self.peek_claim_transferred(zalsa, occupied_entry, reentrant) { + Ok(claimed) => claimed, + Err(other_thread) => match other_thread.block(write) { + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + BlockResult::Running(running) => ClaimResult::Running(running), + }, + } + } + }; + + let SyncState { anyone_waiting, .. } = occupied_entry.into_mut(); + + // NB: `Ordering::Relaxed` is sufficient here, + // as there are no loads that are "gated" on this + // value. Everything that is written is also protected + // by a lock that must be acquired. The role of this + // boolean is to decide *whether* to acquire the lock, + // not to gate future atomic reads. 
+ *anyone_waiting = true; + match zalsa.runtime().block( + DatabaseKeyIndex::new(self.ingredient, key_index), + id, + write, + ) { + BlockResult::Running(blocked_on) => ClaimResult::Running(blocked_on), + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + } + } + std::collections::hash_map::Entry::Vacant(_) => ClaimResult::Claimed(()), + } + } + #[cold] #[inline(never)] fn try_claim_transferred<'me>( @@ -179,6 +221,34 @@ impl SyncTable { } } + #[cold] + #[inline(never)] + fn peek_claim_transferred<'me>( + &'me self, + zalsa: &'me Zalsa, + mut entry: OccupiedEntry, + reentrant: Reentrancy, + ) -> Result, Box>> { + let key_index = *entry.key(); + let database_key_index = DatabaseKeyIndex::new(self.ingredient, key_index); + let thread_id = thread::current().id(); + + match zalsa + .runtime() + .block_transferred(database_key_index, thread_id) + { + BlockTransferredResult::ImTheOwner if reentrant.is_allow() => { + Ok(ClaimResult::Claimed(())) + } + BlockTransferredResult::ImTheOwner => Ok(ClaimResult::Cycle { inner: true }), + BlockTransferredResult::OwnedBy(other_thread) => { + entry.get_mut().anyone_waiting = true; + Err(other_thread) + } + BlockTransferredResult::Released => Ok(ClaimResult::Claimed(())), + } + } + /// Marks `key_index` as a transfer target. /// /// Returns the `SyncOwnerId` of the thread that currently owns this query. 
From cdd0b85516a52c18b8a6d17a2279a96ed6c3e198 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Mon, 27 Oct 2025 21:27:09 +0900 Subject: [PATCH 10/21] Expose the Input query Id with cycle_initial (#1015) --- benches/dataflow.rs | 4 ++-- components/salsa-macro-rules/src/setup_tracked_fn.rs | 4 ++-- .../salsa-macro-rules/src/unexpected_cycle_recovery.rs | 2 +- src/function.rs | 6 +++++- src/function/execute.rs | 2 +- src/function/fetch.rs | 4 ++-- src/function/memo.rs | 6 +++++- tests/backtrace.rs | 2 +- tests/cycle.rs | 8 ++++---- tests/cycle_accumulate.rs | 2 +- tests/cycle_fallback_immediate.rs | 4 ++-- tests/cycle_initial_call_back_into_cycle.rs | 2 +- tests/cycle_initial_call_query.rs | 2 +- tests/cycle_maybe_changed_after.rs | 4 ++-- tests/cycle_output.rs | 2 +- tests/cycle_recovery_call_back_into_cycle.rs | 2 +- tests/cycle_recovery_call_query.rs | 2 +- tests/cycle_regression_455.rs | 2 +- tests/cycle_result_dependencies.rs | 2 +- tests/cycle_tracked.rs | 8 ++++++-- tests/cycle_tracked_own_input.rs | 2 +- tests/dataflow.rs | 4 ++-- tests/parallel/cycle_a_t1_b_t2.rs | 2 +- tests/parallel/cycle_a_t1_b_t2_fallback.rs | 4 ++-- tests/parallel/cycle_ab_peeping_c.rs | 2 +- tests/parallel/cycle_iteration_mismatch.rs | 2 +- tests/parallel/cycle_nested_deep.rs | 2 +- tests/parallel/cycle_nested_deep_conditional.rs | 2 +- tests/parallel/cycle_nested_deep_conditional_changed.rs | 2 +- tests/parallel/cycle_nested_deep_panic.rs | 2 +- tests/parallel/cycle_nested_three_threads.rs | 2 +- tests/parallel/cycle_nested_three_threads_changed.rs | 2 +- tests/parallel/cycle_panic.rs | 2 +- tests/parallel/cycle_provisional_depending_on_itself.rs | 2 +- 34 files changed, 57 insertions(+), 45 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index d1acfd27b..a548c806a 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -70,7 +70,7 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn 
def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { +fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { Type::Bottom } @@ -85,7 +85,7 @@ fn def_cycle_recover( cycle_recover(value, count) } -fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { +fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { Type::Bottom } diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 961b5b4f8..8ea4e5e33 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -302,8 +302,8 @@ macro_rules! setup_tracked_fn { $inner($db, $($input_id),*) } - fn cycle_initial<$db_lt>(db: &$db_lt Self::DbView, ($($input_id),*): ($($interned_input_ty),*)) -> Self::Output<$db_lt> { - $($cycle_recovery_initial)*(db, $($input_id),*) + fn cycle_initial<$db_lt>(db: &$db_lt Self::DbView, id: salsa::Id, ($($input_id),*): ($($interned_input_ty),*)) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db, id, $($input_id),*) } fn recover_from_cycle<$db_lt>( diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index aa6161d28..ff03c02a2 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -12,7 +12,7 @@ macro_rules! unexpected_cycle_recovery { #[macro_export] macro_rules! 
unexpected_cycle_initial { - ($db:ident, $($other_inputs:ident),*) => {{ + ($db:ident, $id:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); std::mem::drop(($($other_inputs,)*)); panic!("no cycle initial value") diff --git a/src/function.rs b/src/function.rs index 434a895a5..045825e19 100644 --- a/src/function.rs +++ b/src/function.rs @@ -85,7 +85,11 @@ pub trait Configuration: Any { fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; /// Get the cycle recovery initial value. - fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; + fn cycle_initial<'db>( + db: &'db Self::DbView, + id: Id, + input: Self::Input<'db>, + ) -> Self::Output<'db>; /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return /// value from the latest iteration of this cycle. `count` is the number of cycle iterations diff --git a/src/function/execute.rs b/src/function/execute.rs index 5e3c226be..d299b0966 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -94,7 +94,7 @@ where let cycle_heads = std::mem::take(cycle_heads); let active_query = zalsa_local.push_query(database_key_index, IterationCount::initial()); - new_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + new_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); completed_query = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. // When verifying this, we will see we have fallback and mark ourselves verified. 
diff --git a/src/function/fetch.rs b/src/function/fetch.rs index a3f3705f4..14d7a93d7 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -236,7 +236,7 @@ where inserting and returning fixpoint initial value" ); let revisions = QueryRevisions::fixpoint_initial(database_key_index); - let initial_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + let initial_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); self.insert_memo( zalsa, id, @@ -250,7 +250,7 @@ where ); let active_query = zalsa_local.push_query(database_key_index, IterationCount::initial()); - let fallback_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + let fallback_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); let mut completed_query = active_query.pop(); completed_query .revisions diff --git a/src/function/memo.rs b/src/function/memo.rs index 200f83a4d..fd830ced3 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -558,7 +558,11 @@ mod _memory_usage { unimplemented!() } - fn cycle_initial<'db>(_: &'db Self::DbView, _: Self::Input<'db>) -> Self::Output<'db> { + fn cycle_initial<'db>( + _: &'db Self::DbView, + _: Id, + _: Self::Input<'db>, + ) -> Self::Output<'db> { unimplemented!() } diff --git a/tests/backtrace.rs b/tests/backtrace.rs index 0adf517cd..3cc5bbad0 100644 --- a/tests/backtrace.rs +++ b/tests/backtrace.rs @@ -52,7 +52,7 @@ fn query_cycle(db: &dyn Database, thing: Thing) -> String { } } -fn cycle_initial(_db: &dyn salsa::Database, _thing: Thing) -> String { +fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _thing: Thing) -> String { String::new() } diff --git a/tests/cycle.rs b/tests/cycle.rs index 0c4d686af..dbe0bdc19 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -173,7 +173,7 @@ fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { fold_values(inputs.values(db), u8::min) } -fn min_initial(_db: &dyn Db, _inputs: Inputs) -> Value { +fn min_initial(_db: &dyn Db, _id: salsa::Id, _inputs: Inputs) -> 
Value { Value::N(255) } @@ -183,7 +183,7 @@ fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { fold_values(inputs.values(db), u8::max) } -fn max_initial(_db: &dyn Db, _inputs: Inputs) -> Value { +fn max_initial(_db: &dyn Db, _id: salsa::Id, _inputs: Inputs) -> Value { Value::N(0) } @@ -1175,7 +1175,7 @@ fn repeat_query_participating_in_cycle() { a.min(2) } - fn initial(_db: &dyn Db, _input: Input) -> u32 { + fn initial(_db: &dyn Db, _id: salsa::Id, _input: Input) -> u32 { 0 } @@ -1280,7 +1280,7 @@ fn repeat_query_participating_in_cycle2() { a.min(2) } - fn initial(_db: &dyn Db, _input: Input) -> u32 { + fn initial(_db: &dyn Db, _id: salsa::Id, _input: Input) -> u32 { 0 } diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index 8148e952d..49f1d06d9 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -44,7 +44,7 @@ fn check_file(db: &dyn LogDatabase, file: File) -> Vec { sorted_issues } -fn cycle_initial(_db: &dyn LogDatabase, _file: File) -> Vec { +fn cycle_initial(_db: &dyn LogDatabase, _id: salsa::Id, _file: File) -> Vec { vec![] } diff --git a/tests/cycle_fallback_immediate.rs b/tests/cycle_fallback_immediate.rs index 374978d81..64f872ad1 100644 --- a/tests/cycle_fallback_immediate.rs +++ b/tests/cycle_fallback_immediate.rs @@ -11,7 +11,7 @@ fn one_o_one(db: &dyn salsa::Database) -> u32 { val + 1 } -fn cycle_result(_db: &dyn salsa::Database) -> u32 { +fn cycle_result(_db: &dyn salsa::Database, _id: salsa::Id) -> u32 { 100 } @@ -38,7 +38,7 @@ fn two_queries2(db: &dyn salsa::Database) -> i32 { CALLS_COUNT.fetch_add(1, Ordering::Relaxed) } -fn two_queries_cycle_result(_db: &dyn salsa::Database) -> i32 { +fn two_queries_cycle_result(_db: &dyn salsa::Database, _id: salsa::Id) -> i32 { 1 } diff --git a/tests/cycle_initial_call_back_into_cycle.rs b/tests/cycle_initial_call_back_into_cycle.rs index e56c4c4d1..ab7a473a2 100644 --- a/tests/cycle_initial_call_back_into_cycle.rs +++ 
b/tests/cycle_initial_call_back_into_cycle.rs @@ -17,7 +17,7 @@ fn query(db: &dyn salsa::Database) -> u32 { } } -fn cycle_initial(db: &dyn salsa::Database) -> u32 { +fn cycle_initial(db: &dyn salsa::Database, _id: salsa::Id) -> u32 { initial_value(db) } diff --git a/tests/cycle_initial_call_query.rs b/tests/cycle_initial_call_query.rs index 2212ef958..b16b72711 100644 --- a/tests/cycle_initial_call_query.rs +++ b/tests/cycle_initial_call_query.rs @@ -17,7 +17,7 @@ fn query(db: &dyn salsa::Database) -> u32 { } } -fn cycle_initial(db: &dyn salsa::Database) -> u32 { +fn cycle_initial(db: &dyn salsa::Database, _id: salsa::Id) -> u32 { initial_value(db) } diff --git a/tests/cycle_maybe_changed_after.rs b/tests/cycle_maybe_changed_after.rs index 8c00c484a..f411404d5 100644 --- a/tests/cycle_maybe_changed_after.rs +++ b/tests/cycle_maybe_changed_after.rs @@ -36,7 +36,7 @@ fn query_d<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { } } -fn query_a_initial(_db: &dyn Database, _input: Input) -> u32 { +fn query_a_initial(_db: &dyn Database, _id: salsa::Id, _input: Input) -> u32 { 0 } @@ -128,7 +128,7 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() { }) } - fn head_initial(_db: &dyn Database, _input: Input) -> Option> { + fn head_initial(_db: &dyn Database, _id: salsa::Id, _input: Input) -> Option> { None } diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 02a3b569f..c4a9384e0 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -40,7 +40,7 @@ fn query_b(db: &dyn Db, input: InputValue) -> u32 { query_a(db, input) } -fn cycle_initial(_db: &dyn Db, _input: InputValue) -> u32 { +fn cycle_initial(_db: &dyn Db, _id: salsa::Id, _input: InputValue) -> u32 { 0 } diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 358f988ad..4ab236565 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -21,7 +21,7 @@ fn query(db: &dyn 
ValueDatabase) -> u32 { } } -fn cycle_initial(_db: &dyn ValueDatabase) -> u32 { +fn cycle_initial(_db: &dyn ValueDatabase, _id: salsa::Id) -> u32 { 0 } diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index 37341a202..a227d6122 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -17,7 +17,7 @@ fn query(db: &dyn salsa::Database) -> u32 { } } -fn cycle_initial(_db: &dyn salsa::Database) -> u32 { +fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id) -> u32 { 0 } diff --git a/tests/cycle_regression_455.rs b/tests/cycle_regression_455.rs index a083cb996..2957e5284 100644 --- a/tests/cycle_regression_455.rs +++ b/tests/cycle_regression_455.rs @@ -12,7 +12,7 @@ fn memoized_a<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> u32 { MyTracked::new(db, 0); memoized_b(db, tracked) } -fn cycle_initial(_db: &dyn Database, _input: MyTracked) -> u32 { +fn cycle_initial(_db: &dyn Database, _id: salsa::Id, _input: MyTracked) -> u32 { 0 } diff --git a/tests/cycle_result_dependencies.rs b/tests/cycle_result_dependencies.rs index 8e025f998..d614f956e 100644 --- a/tests/cycle_result_dependencies.rs +++ b/tests/cycle_result_dependencies.rs @@ -12,7 +12,7 @@ fn has_cycle(db: &dyn Database, input: Input) -> i32 { has_cycle(db, input) } -fn cycle_result(db: &dyn Database, input: Input) -> i32 { +fn cycle_result(db: &dyn Database, _id: salsa::Id, input: Input) -> i32 { input.value(db) } diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 5ee4e1620..1a5b82ee6 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -110,7 +110,7 @@ fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize { min_cost } -fn max_initial(_db: &dyn Database, _node: Node) -> usize { +fn max_initial(_db: &dyn Database, _id: salsa::Id, _node: Node) -> usize { usize::MAX } @@ -246,7 +246,11 @@ fn create_tracked_in_cycle<'db>( } } -fn initial_with_structs(_db: &dyn Database, _input: 
GraphInput) -> Vec> { +fn initial_with_structs( + _db: &dyn Database, + _id: salsa::Id, + _input: GraphInput, +) -> Vec> { vec![] } diff --git a/tests/cycle_tracked_own_input.rs b/tests/cycle_tracked_own_input.rs index 79035bab5..0359c2df2 100644 --- a/tests/cycle_tracked_own_input.rs +++ b/tests/cycle_tracked_own_input.rs @@ -81,7 +81,7 @@ fn infer_type_param<'db>(db: &'db dyn salsa::Database, node: TypeParamNode) -> T } } -fn infer_class_initial(_db: &'_ dyn Database, _node: ClassNode) -> Type<'_> { +fn infer_class_initial(_db: &'_ dyn Database, _id: salsa::Id, _node: ClassNode) -> Type<'_> { Type::Unknown } diff --git a/tests/dataflow.rs b/tests/dataflow.rs index 793870322..69c91d513 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -71,7 +71,7 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { +fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { Type::Bottom } @@ -86,7 +86,7 @@ fn def_cycle_recover( cycle_recover(value, count) } -fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { +fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { Type::Bottom } diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs index 6a434099e..95b2a3d28 100644 --- a/tests/parallel/cycle_a_t1_b_t2.rs +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -45,7 +45,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0 + 1).min(MAX) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_a_t1_b_t2_fallback.rs b/tests/parallel/cycle_a_t1_b_t2_fallback.rs index b2d6631cc..b49fa0448 100644 --- a/tests/parallel/cycle_a_t1_b_t2_fallback.rs +++ b/tests/parallel/cycle_a_t1_b_t2_fallback.rs @@ -41,11 +41,11 @@ fn query_b(db: &dyn KnobsDatabase) -> u32 { query_a(db) | OFFSET_B } -fn cycle_result_a(_db: &dyn 
KnobsDatabase) -> u32 { +fn cycle_result_a(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { FALLBACK_A } -fn cycle_result_b(_db: &dyn KnobsDatabase) -> u32 { +fn cycle_result_b(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { FALLBACK_B } diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs index 8ed2b4fb6..c61f3c6ae 100644 --- a/tests/parallel/cycle_ab_peeping_c.rs +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -30,7 +30,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { b_value } -fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_iteration_mismatch.rs b/tests/parallel/cycle_iteration_mismatch.rs index 61d1da01d..fa84bfb0d 100644 --- a/tests/parallel/cycle_iteration_mismatch.rs +++ b/tests/parallel/cycle_iteration_mismatch.rs @@ -62,7 +62,7 @@ fn query_f(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(b.0.max(e.0)) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep.rs b/tests/parallel/cycle_nested_deep.rs index 3d46bbbc5..72d4ebf74 100644 --- a/tests/parallel/cycle_nested_deep.rs +++ b/tests/parallel/cycle_nested_deep.rs @@ -46,7 +46,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_conditional.rs b/tests/parallel/cycle_nested_deep_conditional.rs index 544342e07..bf9a600b3 100644 --- a/tests/parallel/cycle_nested_deep_conditional.rs +++ b/tests/parallel/cycle_nested_deep_conditional.rs @@ -55,7 +55,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { 
MIN } diff --git a/tests/parallel/cycle_nested_deep_conditional_changed.rs b/tests/parallel/cycle_nested_deep_conditional_changed.rs index 03423b09a..95122bebd 100644 --- a/tests/parallel/cycle_nested_deep_conditional_changed.rs +++ b/tests/parallel/cycle_nested_deep_conditional_changed.rs @@ -61,7 +61,7 @@ fn query_e(db: &dyn salsa::Database, input: Input) -> CycleValue { query_c(db, input) } -fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { +fn initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_panic.rs b/tests/parallel/cycle_nested_deep_panic.rs index 4356489c3..92d192be5 100644 --- a/tests/parallel/cycle_nested_deep_panic.rs +++ b/tests/parallel/cycle_nested_deep_panic.rs @@ -49,7 +49,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs index 728fc3e70..d56dfd22a 100644 --- a/tests/parallel/cycle_nested_three_threads.rs +++ b/tests/parallel/cycle_nested_three_threads.rs @@ -55,7 +55,7 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0.max(b_value.0)) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads_changed.rs b/tests/parallel/cycle_nested_three_threads_changed.rs index 626b3ef90..b9677ccc4 100644 --- a/tests/parallel/cycle_nested_three_threads_changed.rs +++ b/tests/parallel/cycle_nested_three_threads_changed.rs @@ -54,7 +54,7 @@ fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { CycleValue(a_value.0.max(b_value.0)) } -fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { +fn initial(_db: &dyn salsa::Database, _id: salsa::Id, 
_input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index 13c988f8f..ba05291a5 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -28,7 +28,7 @@ fn cycle_fn( panic!("cancel!") } -fn initial(_db: &dyn KnobsDatabase) -> u32 { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { 0 } diff --git a/tests/parallel/cycle_provisional_depending_on_itself.rs b/tests/parallel/cycle_provisional_depending_on_itself.rs index bb615210e..2c27becb3 100644 --- a/tests/parallel/cycle_provisional_depending_on_itself.rs +++ b/tests/parallel/cycle_provisional_depending_on_itself.rs @@ -67,7 +67,7 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { b } -fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } From 76e65b1890c68b75f4d41db17c067b4489e843ac Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Wed, 29 Oct 2025 22:33:51 +0900 Subject: [PATCH 11/21] doc: Explain the motivation for breaking API changes made in #1012 and #1015 (#1016) * doc: Explain the motivation for breaking API changes made in #1012 and #1015 * Update book/src/cycles.md --------- Co-authored-by: Micha Reiser --- book/src/cycles.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/book/src/cycles.md b/book/src/cycles.md index 2e2c6e7b8..bd0675bdc 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -7,23 +7,23 @@ Salsa also supports recovering from query cycles via fixed-point iteration. Fixe In order to support fixed-point iteration for a query, provide the `cycle_fn` and `cycle_initial` arguments to `salsa::tracked`: ```rust -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial_fn)] +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] fn query(db: &dyn salsa::Database) -> u32 { // ... 
} -fn cycle_fn(_db: &dyn KnobsDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { +fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { salsa::CycleRecoveryAction::Iterate } -fn initial(_db: &dyn KnobsDatabase) -> u32 { +fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { 0 } ``` The `cycle_fn` is optional. The default implementation always returns `Iterate`. -If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should resume starting with the given value (which should be a value that will converge quickly). 
+If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should continue with the given value (which should be a value that will converge quickly). The cycle will iterate until it converges: that is, until two successive iterations produce the same result. @@ -39,6 +39,11 @@ Consider a two-query cycle where `query_a` calls `query_b`, and `query_b` calls Fixed-point iteration is a powerful tool, but is also easy to misuse, potentially resulting in infinite iteration. To avoid this, ensure that all queries participating in fixpoint iteration are deterministic and monotone. +To guarantee convergence, you can leverage the `last_provisional_value` (3rd parameter) received by `cycle_fn`. +When the `cycle_fn` recalculates a value, you can implement a strategy that references the last provisional value to "join" values or "widen" it and return a fallback value.
This ensures monotonicity of the calculation and suppresses infinite oscillation of values between cycles. + +Also, in fixed-point iteration, it is advantageous to be able to identify which cycle head seeded a value. By embedding a `salsa::Id` (2nd parameter) in the initial value as a "cycle marker", the recovery function can detect self-originated recursion. + ## Calling Salsa queries from within `cycle_fn` or `cycle_initial` It is permitted to call other Salsa queries from within the `cycle_fn` and `cycle_initial` functions. However, if these functions re-enter the same cycle, this can lead to unpredictable results. Take care which queries are called from within cycle-recovery functions, and avoid triggering further cycles. From 671c3dcba6ee94794876fd904606cd45a7b71599 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 29 Oct 2025 20:29:46 +0100 Subject: [PATCH 12/21] Only use provisional values from the same revision (#1019) --- src/function/fetch.rs | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 14d7a93d7..f1c58eda1 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -195,26 +195,24 @@ where // existing provisional memo if it exists let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = memo_guard { - if memo.value.is_some() && memo.revisions.cycle_heads().contains(&database_key_index) { - let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); - if can_shallow_update.yes() { - self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); - - if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint { - memo.revisions - .cycle_heads() - .remove_all_except(database_key_index); - } + if memo.verified_at.load() == zalsa.current_revision() + && memo.value.is_some() + && memo.revisions.cycle_heads().contains(&database_key_index) + { + if C::CYCLE_STRATEGY == 
CycleRecoveryStrategy::Fixpoint { + memo.revisions + .cycle_heads() + .remove_all_except(database_key_index); + } - crate::tracing::debug!( - "hit cycle at {database_key_index:#?}, \ + crate::tracing::debug!( + "hit cycle at {database_key_index:#?}, \ returning last provisional value: {:#?}", - memo.revisions - ); + memo.revisions + ); - // SAFETY: memo is present in memo_map. - return unsafe { self.extend_memo_lifetime(memo) }; - } + // SAFETY: memo is present in memo_map. + return unsafe { self.extend_memo_lifetime(memo) }; } } From 46aa2cfadc91c798b3a1d5fefc2fe19a5ba379bc Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 31 Oct 2025 22:55:41 +0100 Subject: [PATCH 13/21] Update compile fail snapshots to match new rust stable output (#1020) * Update expected test output to match new rust stable output * Remove 1.90 constraint from compile_fail tests * Discard changes to tests/persistence.rs * Update snapshot on unix * Update with the correct rust version * Discard changes to tests/compile_fail.rs --- .../incomplete_persistence.stderr | 32 +++++++++++++++---- tests/compile-fail/span-tracked-getter.stderr | 2 +- ...ot-work-if-the-key-is-a-salsa-input.stderr | 9 ++++-- ...work-if-the-key-is-a-salsa-interned.stderr | 9 ++++-- .../tracked_method_incompatibles.stderr | 2 +- 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/tests/compile-fail/incomplete_persistence.stderr b/tests/compile-fail/incomplete_persistence.stderr index f7082ecca..a7a65b94c 100644 --- a/tests/compile-fail/incomplete_persistence.stderr +++ b/tests/compile-fail/incomplete_persistence.stderr @@ -4,9 +4,14 @@ error[E0277]: the trait bound `NotPersistable<'_>: serde::Serialize` is not sati 1 | #[salsa::tracked(persist)] | ^^^^^^^^^^^^^^^^^^^^^^^^^^ | | - | the trait `Serialize` is not implemented for `NotPersistable<'_>` + | unsatisfied trait bound | required by a bound introduced by this call | +help: the trait `Serialize` is not implemented for `NotPersistable<'_>` + --> 
tests/compile-fail/incomplete_persistence.rs:6:1 + | +6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Serialize)]` to your `NotPersistable<'_>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Serialize`: @@ -26,8 +31,13 @@ error[E0277]: the trait bound `NotPersistable<'_>: serde::Deserialize<'de>` is n --> tests/compile-fail/incomplete_persistence.rs:1:1 | 1 | #[salsa::tracked(persist)] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Deserialize<'_>` is not implemented for `NotPersistable<'_>` + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ unsatisfied trait bound + | +help: the trait `Deserialize<'_>` is not implemented for `NotPersistable<'_>` + --> tests/compile-fail/incomplete_persistence.rs:6:1 | +6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `NotPersistable<'_>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Deserialize<'de>`: @@ -47,8 +57,13 @@ error[E0277]: the trait bound `NotPersistable<'db>: serde::Serialize` is not sat --> tests/compile-fail/incomplete_persistence.rs:12:45 | 12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} - | ^^^^^^^^^^^^^^^^^^ the trait `Serialize` is not implemented for `NotPersistable<'db>` + | ^^^^^^^^^^^^^^^^^^ unsatisfied trait bound | +help: the trait `Serialize` is not implemented for `NotPersistable<'db>` + --> tests/compile-fail/incomplete_persistence.rs:6:1 + | + 6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Serialize)]` to your `NotPersistable<'db>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Serialize`: @@ -69,14 +84,19 @@ note: required by a 
bound in `query_input_is_persistable` | | | required by a bound in this function | required by this bound in `query_input_is_persistable` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the macro `salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `NotPersistable<'db>: serde::Deserialize<'de>` is not satisfied --> tests/compile-fail/incomplete_persistence.rs:12:45 | 12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} - | ^^^^^^^^^^^^^^^^^^ the trait `for<'de> Deserialize<'de>` is not implemented for `NotPersistable<'db>` + | ^^^^^^^^^^^^^^^^^^ unsatisfied trait bound + | +help: the trait `for<'de> Deserialize<'de>` is not implemented for `NotPersistable<'db>` + --> tests/compile-fail/incomplete_persistence.rs:6:1 | + 6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `NotPersistable<'db>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Deserialize<'de>`: @@ -97,4 +117,4 @@ note: required by a bound in `query_input_is_persistable` | | | required by a bound in this function | required by this bound in `query_input_is_persistable` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the macro `salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z 
macro-backtrace for more info) diff --git a/tests/compile-fail/span-tracked-getter.stderr b/tests/compile-fail/span-tracked-getter.stderr index fcf546c72..bc304a5c6 100644 --- a/tests/compile-fail/span-tracked-getter.stderr +++ b/tests/compile-fail/span-tracked-getter.stderr @@ -29,4 +29,4 @@ warning: variable does not need to be mutable | | | help: remove this `mut` | - = note: `#[warn(unused_mut)]` on by default + = note: `#[warn(unused_mut)]` (part of `#[warn(unused)]`) on by default diff --git a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr index 580ea67bf..5c6420632 100644 --- a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr +++ b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr @@ -2,8 +2,13 @@ error[E0277]: the trait bound `MyInput: TrackedStructInDb` is not satisfied --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs:15:1 | 15 | #[salsa::tracked(specify)] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `TrackedStructInDb` is not implemented for `MyInput` + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ unsatisfied trait bound | +help: the trait `TrackedStructInDb` is not implemented for `MyInput` + --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs:5:1 + | + 5 | #[salsa::input] + | ^^^^^^^^^^^^^^^ = help: the trait `TrackedStructInDb` is implemented for `MyTracked<'_>` note: required by a bound in `salsa::function::specify::>::specify_and_record` --> src/function/specify.rs @@ -13,4 +18,4 @@ note: required by a bound in `salsa::function::specify::: TrackedStructInDb, | ^^^^^^^^^^^^^^^^^ required by this bound in `salsa::function::specify::>::specify_and_record` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) 
+ = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::input` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr index 01a4b8f60..6c6ba51e0 100644 --- a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr +++ b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr @@ -2,8 +2,13 @@ error[E0277]: the trait bound `MyInterned<'_>: TrackedStructInDb` is not satisfi --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs:15:1 | 15 | #[salsa::tracked(specify)] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `TrackedStructInDb` is not implemented for `MyInterned<'_>` + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ unsatisfied trait bound | +help: the trait `TrackedStructInDb` is not implemented for `MyInterned<'_>` + --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs:5:1 + | + 5 | #[salsa::interned] + | ^^^^^^^^^^^^^^^^^^ = help: the trait `TrackedStructInDb` is implemented for `MyTracked<'_>` note: required by a bound in `salsa::function::specify::>::specify_and_record` --> src/function/specify.rs @@ -13,4 +18,4 @@ note: required by a bound in `salsa::function::specify::: TrackedStructInDb, | ^^^^^^^^^^^^^^^^^ required by this bound in `salsa::function::specify::>::specify_and_record` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::interned` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git 
a/tests/compile-fail/tracked_method_incompatibles.stderr b/tests/compile-fail/tracked_method_incompatibles.stderr index 72a27a33b..5700eb556 100644 --- a/tests/compile-fail/tracked_method_incompatibles.stderr +++ b/tests/compile-fail/tracked_method_incompatibles.stderr @@ -52,7 +52,7 @@ warning: unused variable: `db` 9 | fn ref_self(&self, db: &dyn salsa::Database) {} | ^^ help: if this is intentional, prefix it with an underscore: `_db` | - = note: `#[warn(unused_variables)]` on by default + = note: `#[warn(unused_variables)]` (part of `#[warn(unused)]`) on by default warning: unused variable: `db` --> tests/compile-fail/tracked_method_incompatibles.rs:15:32 From c762869fd590855e444a957afbea355dec7f6028 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sat, 1 Nov 2025 01:36:55 +0100 Subject: [PATCH 14/21] Always increment iteration count (#1017) --- src/function/execute.rs | 39 ++++++++++++++++++++++------- src/function/fetch.rs | 18 ++++++++++++- src/function/maybe_changed_after.rs | 23 ----------------- src/function/memo.rs | 20 ++++++--------- src/zalsa_local.rs | 27 ++++++++++++++------ 5 files changed, 73 insertions(+), 54 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index d299b0966..9d6758730 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -172,11 +172,7 @@ where let mut iteration_count = IterationCount::initial(); if let Some(old_memo) = opt_old_memo { - if old_memo.verified_at.load() == zalsa.current_revision() - && old_memo.cycle_heads().contains(&database_key_index) - { - let memo_iteration_count = old_memo.revisions.iteration(); - + if old_memo.verified_at.load() == zalsa.current_revision() { // The `DependencyGraph` locking propagates panics when another thread is blocked on a panicking query. // However, the locking doesn't handle the case where a thread fetches the result of a panicking // cycle head query **after** all locks were released. That's what we do here. 
@@ -189,8 +185,14 @@ where tracing::warn!("Propagating panic for cycle head that panicked in an earlier execution in that revision"); Cancelled::PropagatedPanic.throw(); } - last_provisional_memo = Some(old_memo); - iteration_count = memo_iteration_count; + + // Only use the last provisional memo if it was a cycle head in the last iteration. This is to + // force at least two executions. + if old_memo.cycle_heads().contains(&database_key_index) { + last_provisional_memo = Some(old_memo); + } + + iteration_count = old_memo.revisions.iteration(); } } @@ -216,6 +218,14 @@ where // If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty) let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else { + iteration_count = iteration_count.increment().unwrap_or_else(|| { + tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); + panic!("{database_key_index:?}: execute: too many cycle iterations") + }); + completed_query + .revisions + .update_iteration_count_mut(database_key_index, iteration_count); + claim_guard.set_release_mode(ReleaseMode::SelfOnly); break (new_value, completed_query); }; @@ -289,6 +299,15 @@ where } completed_query.revisions.set_cycle_heads(cycle_heads); + + iteration_count = iteration_count.increment().unwrap_or_else(|| { + tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); + panic!("{database_key_index:?}: execute: too many cycle iterations") + }); + completed_query + .revisions + .update_iteration_count_mut(database_key_index, iteration_count); + break (new_value, completed_query); } @@ -555,8 +574,10 @@ impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> { impl Drop for PoisonProvisionalIfPanicking<'_, C> { fn drop(&mut self) { if thread::panicking() { - let revisions = - QueryRevisions::fixpoint_initial(self.ingredient.database_key_index(self.id)); + let revisions = QueryRevisions::fixpoint_initial( + 
self.ingredient.database_key_index(self.id), + IterationCount::initial(), + ); let memo = Memo::new(None, self.zalsa.current_revision(), revisions); self.ingredient diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f1c58eda1..588b08bb1 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -195,6 +195,9 @@ where // existing provisional memo if it exists let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = memo_guard { + // Ideally, we'd use the last provisional memo even if it wasn't a cycle head in the last iteration + // but that would require inserting itself as a cycle head, which either requires clone + // on the value OR a concurrent `Vec` for cycle heads. if memo.verified_at.load() == zalsa.current_revision() && memo.value.is_some() && memo.revisions.cycle_heads().contains(&database_key_index) @@ -233,7 +236,20 @@ where "hit cycle at {database_key_index:#?}, \ inserting and returning fixpoint initial value" ); - let revisions = QueryRevisions::fixpoint_initial(database_key_index); + + let iteration = memo_guard + .and_then(|old_memo| { + if old_memo.verified_at.load() == zalsa.current_revision() + && old_memo.value.is_some() + { + Some(old_memo.revisions.iteration()) + } else { + None + } + }) + .unwrap_or(IterationCount::initial()); + let revisions = QueryRevisions::fixpoint_initial(database_key_index, iteration); + let initial_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); self.insert_memo( zalsa, diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 4198631b9..20440883e 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -509,8 +509,6 @@ where head_iteration_count, memo_iteration_count: current_iteration_count, verified_at: head_verified_at, - cycle_heads, - database_key_index: head_database_key, } => { if head_verified_at != memo_verified_at { return false; @@ -519,27 +517,6 @@ where if 
head_iteration_count != current_iteration_count { return false; } - - // Check if the memo is still a cycle head and hasn't changed - // to a normal cycle participant. This is to force re-execution in - // a scenario like this: - // - // * There's a nested cycle with the outermost query A - // * B participates in the cycle and is a cycle head in the first few iterations - // * B becomes a non-cycle head in a later iteration - // * There's a query `C` that has `B` as its cycle head - // - // The crucial point is that `B` switches from being a cycle head to being a regular cycle participant. - // The issue with that is that `A` doesn't update `B`'s `iteration_count `when the iteration completes - // because it only does that for cycle heads (and collecting all queries participating in a query would be sort of expensive?). - // - // When we now pull `C` in a later iteration, `validate_same_iteration` iterates over all its cycle heads (`B`), - // and check if the iteration count still matches. Which is the case because `A` didn't update `B`'s iteration count. - // - // That's why we also check if `B` is still a cycle head in the current iteration. - if !cycle_heads.contains(&head_database_key) { - return false; - } } _ => { return false; diff --git a/src/function/memo.rs b/src/function/memo.rs index fd830ced3..d8faf3e0b 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -409,11 +409,9 @@ mod persistence { pub(super) enum TryClaimHeadsResult<'me> { /// Claiming the cycle head results in a cycle. Cycle { - database_key_index: DatabaseKeyIndex, head_iteration_count: IterationCount, memo_iteration_count: IterationCount, verified_at: Revision, - cycle_heads: &'me CycleHeads, }, /// The cycle head is not finalized, but it can be claimed. 
@@ -460,28 +458,24 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { let provisional_status = ingredient .provisional_status(self.zalsa, head_key_index) .expect("cycle head memo to exist"); - let (current_iteration_count, verified_at, cycle_heads) = match provisional_status { + let (current_iteration_count, verified_at) = match provisional_status { ProvisionalStatus::Provisional { iteration, verified_at, - cycle_heads, - } => (iteration, verified_at, cycle_heads), + cycle_heads: _, + } => (iteration, verified_at), ProvisionalStatus::Final { iteration, verified_at, - } => (iteration, verified_at, empty_cycle_heads()), - ProvisionalStatus::FallbackImmediate => ( - IterationCount::initial(), - self.zalsa.current_revision(), - empty_cycle_heads(), - ), + } => (iteration, verified_at), + ProvisionalStatus::FallbackImmediate => { + (IterationCount::initial(), self.zalsa.current_revision()) + } }; Some(TryClaimHeadsResult::Cycle { - database_key_index: head_database_key, memo_iteration_count: current_iteration_count, head_iteration_count: head.iteration_count.load(), - cycle_heads, verified_at, }) } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 7b0399178..f43eb78eb 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -639,7 +639,7 @@ const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; if cfg!(feature = "accumulator") { 7 } else { 3 }]>()]; impl QueryRevisions { - pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, iteration: IterationCount) -> Self { Self { changed_at: Revision::start(), durability: Durability::MAX, @@ -651,8 +651,8 @@ impl QueryRevisions { #[cfg(feature = "accumulator")] AccumulatedMap::default(), ThinVec::default(), - CycleHeads::initial(query, IterationCount::initial()), - IterationCount::initial(), + CycleHeads::initial(query, iteration), + iteration, ), } } @@ -743,12 +743,23 @@ impl QueryRevisions { cycle_head_index: 
DatabaseKeyIndex, iteration_count: IterationCount, ) { - if let Some(extra) = &mut self.extra.0 { - extra.iteration.store_mut(iteration_count); + match &mut self.extra.0 { + None => { + self.extra = QueryRevisionsExtra::new( + #[cfg(feature = "accumulator")] + AccumulatedMap::default(), + ThinVec::default(), + empty_cycle_heads().clone(), + iteration_count, + ); + } + Some(extra) => { + extra.iteration.store_mut(iteration_count); - extra - .cycle_heads - .update_iteration_count_mut(cycle_head_index, iteration_count); + extra + .cycle_heads + .update_iteration_count_mut(cycle_head_index, iteration_count); + } } } From 664750a6e588ed23a0d2d9105a02cb5993c8e178 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 3 Nov 2025 21:44:22 +0100 Subject: [PATCH 15/21] Track cycle function dependencies as part of the cyclic query (#1018) * Track cycle function dependenciees as part of the cyclic query * Add regression test * Discard changes to src/function/backdate.rs * Update comment * Fix merge error * Refine comment --- src/active_query.rs | 4 ++ src/cycle.rs | 4 ++ src/function/execute.rs | 61 +++++++++++++++------ src/function/maybe_changed_after.rs | 5 +- src/zalsa_local.rs | 12 ++++ tests/cycle_recovery_dependencies.rs | 82 ++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 19 deletions(-) create mode 100644 tests/cycle_recovery_dependencies.rs diff --git a/src/active_query.rs b/src/active_query.rs index bb5987fcd..c80cded3b 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -91,6 +91,10 @@ impl ActiveQuery { .mark_all_active(active_tracked_ids.iter().copied()); } + pub(super) fn take_cycle_heads(&mut self) -> CycleHeads { + std::mem::take(&mut self.cycle_heads) + } + pub(super) fn add_read( &mut self, input: DatabaseKeyIndex, diff --git a/src/cycle.rs b/src/cycle.rs index fcbadf891..3f6f70aa0 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -490,4 +490,8 @@ impl<'db> ProvisionalStatus<'db> { _ => empty_cycle_heads(), } } + + pub(crate) const 
fn is_provisional(&self) -> bool { + matches!(self, ProvisionalStatus::Provisional { .. }) + } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 9d6758730..53bc640a2 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -56,20 +56,25 @@ where }); let (new_value, mut completed_query) = match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Panic => Self::execute_query( - db, - zalsa, - zalsa_local.push_query(database_key_index, IterationCount::initial()), - opt_old_memo, - ), + CycleRecoveryStrategy::Panic => { + let (new_value, active_query) = Self::execute_query( + db, + zalsa, + zalsa_local.push_query(database_key_index, IterationCount::initial()), + opt_old_memo, + ); + (new_value, active_query.pop()) + } CycleRecoveryStrategy::FallbackImmediate => { - let (mut new_value, mut completed_query) = Self::execute_query( + let (mut new_value, active_query) = Self::execute_query( db, zalsa, zalsa_local.push_query(database_key_index, IterationCount::initial()), opt_old_memo, ); + let mut completed_query = active_query.pop(); + if let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() { // Did the new result we got depend on our own provisional value, in a cycle? if cycle_heads.contains(&database_key_index) { @@ -198,9 +203,10 @@ where let _poison_guard = PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); - let mut active_query = zalsa_local.push_query(database_key_index, iteration_count); let (new_value, completed_query) = loop { + let active_query = zalsa_local.push_query(database_key_index, iteration_count); + // Tracked struct ids that existed in the previous revision // but weren't recreated in the last iteration. It's important that we seed the next // query with these ids because the query might re-create them as part of the next iteration. @@ -209,29 +215,32 @@ where // if they aren't recreated when reaching the final iteration. 
active_query.seed_tracked_struct_ids(&last_stale_tracked_ids); - let (mut new_value, mut completed_query) = Self::execute_query( + let (mut new_value, mut active_query) = Self::execute_query( db, zalsa, active_query, last_provisional_memo.or(opt_old_memo), ); - // If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty) - let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else { + // Take the cycle heads to not-fight-rust's-borrow-checker. + let mut cycle_heads = active_query.take_cycle_heads(); + + // If there are no cycle heads, break out of the loop. + if cycle_heads.is_empty() { iteration_count = iteration_count.increment().unwrap_or_else(|| { tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); panic!("{database_key_index:?}: execute: too many cycle iterations") }); + + let mut completed_query = active_query.pop(); completed_query .revisions .update_iteration_count_mut(database_key_index, iteration_count); claim_guard.set_release_mode(ReleaseMode::SelfOnly); break (new_value, completed_query); - }; + } - // Take the cycle heads to not-fight-rust's-borrow-checker. - let mut cycle_heads = std::mem::take(cycle_heads); let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> = SmallVec::new_const(); let mut max_iteration_count = iteration_count; @@ -262,6 +271,11 @@ where .provisional_status(zalsa, head.database_key_index.key_index()) .expect("cycle head memo must have been created during the execution"); + // A query should only ever depend on other heads that are provisional. + // If this invariant is violated, it means that this query participates in a cycle, + // but it wasn't executed in the last iteration of said cycle. 
+ assert!(provisional_status.is_provisional()); + for nested_head in provisional_status.cycle_heads() { let nested_as_tuple = ( nested_head.database_key_index, @@ -298,6 +312,8 @@ where claim_guard.set_release_mode(ReleaseMode::SelfOnly); } + let mut completed_query = active_query.pop(); + *completed_query.revisions.verified_final.get_mut() = false; completed_query.revisions.set_cycle_heads(cycle_heads); iteration_count = iteration_count.increment().unwrap_or_else(|| { @@ -378,8 +394,17 @@ where this_converged = C::values_equal(&new_value, last_provisional_value); } } + + let new_cycle_heads = active_query.take_cycle_heads(); + for head in new_cycle_heads { + if !cycle_heads.contains(&head.database_key_index) { + panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index); + } + } } + let mut completed_query = active_query.pop(); + if let Some(outer_cycle) = outer_cycle { tracing::info!( "Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}" @@ -390,6 +415,7 @@ where completed_query .revisions .set_cycle_converged(this_converged); + *completed_query.revisions.verified_final.get_mut() = false; // Transfer ownership of this query to the outer cycle, so that it can claim it // and other threads don't compete for the same lock. @@ -428,9 +454,9 @@ where } *completed_query.revisions.verified_final.get_mut() = true; - break (new_value, completed_query); } + *completed_query.revisions.verified_final.get_mut() = false; // The fixpoint iteration hasn't converged. Iterate again... 
iteration_count = iteration_count.increment().unwrap_or_else(|| { @@ -484,7 +510,6 @@ where last_provisional_memo = Some(new_memo); last_stale_tracked_ids = completed_query.stale_tracked_structs; - active_query = zalsa_local.push_query(database_key_index, iteration_count); continue; }; @@ -503,7 +528,7 @@ where zalsa: &'db Zalsa, active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo<'db, C>>, - ) -> (C::Output<'db>, CompletedQuery) { + ) -> (C::Output<'db>, ActiveQueryGuard<'db>) { if let Some(old_memo) = opt_old_memo { // If we already executed this query once, then use the tracked-struct ids from the // previous execution as the starting point for the new one. @@ -528,7 +553,7 @@ where C::id_to_input(zalsa, active_query.database_key_index.key_index()), ); - (new_value, active_query.pop()) + (new_value, active_query) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 20440883e..165a3fb02 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -592,7 +592,10 @@ where cycle_heads.append_heads(&mut child_cycle_heads); match input_result { - VerifyResult::Changed => return VerifyResult::changed(), + VerifyResult::Changed => { + cycle_heads.remove_head(database_key_index); + return VerifyResult::changed(); + } #[cfg(feature = "accumulator")] VerifyResult::Unchanged { accumulated } => { inputs |= accumulated; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index f43eb78eb..bde3b6b24 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1213,6 +1213,18 @@ impl ActiveQueryGuard<'_> { } } + pub(crate) fn take_cycle_heads(&mut self) -> CycleHeads { + // SAFETY: We do not access the query stack reentrantly. 
+ unsafe { + self.local_state.with_query_stack_unchecked_mut(|stack| { + #[cfg(debug_assertions)] + assert_eq!(stack.len(), self.push_len); + let frame = stack.last_mut().unwrap(); + frame.take_cycle_heads() + }) + } + } + /// Invoked when the query has successfully completed execution. fn complete(self) -> CompletedQuery { // SAFETY: We do not access the query stack reentrantly. diff --git a/tests/cycle_recovery_dependencies.rs b/tests/cycle_recovery_dependencies.rs new file mode 100644 index 000000000..b26ce973b --- /dev/null +++ b/tests/cycle_recovery_dependencies.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "inventory")] + +//! Queries or inputs read within the cycle recovery function +//! are tracked on the cycle function and don't "leak" into the +//! function calling the query with cycle handling. + +use expect_test::expect; +use salsa::Setter as _; + +use crate::common::LogDatabase; + +mod common; + +#[salsa::input] +struct Input { + value: u32, +} + +#[salsa::tracked] +fn entry(db: &dyn salsa::Database, input: Input) -> u32 { + query(db, input) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database, input: Input) -> u32 { + let val = query(db, input); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u32 { + 0 +} + +fn cycle_fn( + db: &dyn salsa::Database, + _id: salsa::Id, + _last_provisional_value: &u32, + _value: &u32, + _count: u32, + input: Input, +) -> salsa::CycleRecoveryAction { + let _input = input.value(db); + salsa::CycleRecoveryAction::Iterate +} + +#[test_log::test] +fn the_test() { + let mut db = common::EventLoggerDatabase::default(); + + let input = Input::new(&db, 1); + assert_eq!(entry(&db, input), 5); + + db.assert_logs_len(15); + + input.set_value(&mut db).to(2); + + assert_eq!(entry(&db, input), 5); + db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillCheckCancellation", 
+ "WillCheckCancellation", + "WillExecute { database_key: query(Id(0)) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(1) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(2) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(3) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(4) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(5) }", + "WillCheckCancellation", + "DidValidateMemoizedValue { database_key: entry(Id(0)) }", + ]"#]]); +} From 05a9af7f554b64b8aadc2eeb6f2caf73d0408d09 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 5 Nov 2025 13:37:23 +0100 Subject: [PATCH 16/21] Call `cycle_fn` for every iteration (#1021) * Call `cycle_fn` for every iteration * Update documentation * Clippy * Remove `CycleRecoveryAction` --- benches/dataflow.rs | 22 +++++----- book/src/cycles.md | 14 +++--- .../salsa-macro-rules/src/setup_tracked_fn.rs | 4 +- .../src/unexpected_cycle_recovery.rs | 4 +- src/cycle.rs | 31 +++---------- src/function.rs | 25 +++++++---- src/function/execute.rs | 43 +++++++------------ src/function/memo.rs | 8 ++-- src/lib.rs | 4 +- tests/cycle.rs | 20 +++++---- tests/cycle_accumulate.rs | 6 +-- tests/cycle_recovery_call_back_into_cycle.rs | 12 ++++-- tests/cycle_recovery_call_query.rs | 6 +-- tests/cycle_recovery_dependencies.rs | 6 +-- tests/dataflow.rs | 38 +++++++++------- tests/parallel/cycle_panic.rs | 4 +- 16 files changed, 116 insertions(+), 131 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index a548c806a..cf20140f6 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -6,7 +6,7 @@ use std::collections::BTreeSet; use std::iter::IntoIterator; use codspeed_criterion_compat::{criterion_group, 
criterion_main, BatchSize, Criterion}; -use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use salsa::{Database as Db, Setter}; /// A Use of a symbol. #[salsa::input] @@ -78,10 +78,10 @@ fn def_cycle_recover( _db: &dyn Db, _id: salsa::Id, _last_provisional_value: &Type, - value: &Type, + value: Type, count: u32, _def: Definition, -) -> CycleRecoveryAction { +) -> Type { cycle_recover(value, count) } @@ -93,24 +93,24 @@ fn use_cycle_recover( _db: &dyn Db, _id: salsa::Id, _last_provisional_value: &Type, - value: &Type, + value: Type, count: u32, _use: Use, -) -> CycleRecoveryAction { +) -> Type { cycle_recover(value, count) } -fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { - match value { - Type::Bottom => CycleRecoveryAction::Iterate, +fn cycle_recover(value: Type, count: u32) -> Type { + match &value { + Type::Bottom => value, Type::Values(_) => { if count > 4 { - CycleRecoveryAction::Fallback(Type::Top) + Type::Top } else { - CycleRecoveryAction::Iterate + value } } - Type::Top => CycleRecoveryAction::Iterate, + Type::Top => value, } } diff --git a/book/src/cycles.md b/book/src/cycles.md index bd0675bdc..023f5bb79 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -12,8 +12,8 @@ fn query(db: &dyn salsa::Database) -> u32 { // ... } -fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate +fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, value: u32, _count: u32) -> u32 { + value } fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { @@ -21,13 +21,11 @@ fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { } ``` -The `cycle_fn` is optional. The default implementation always returns `Iterate`. +The `cycle_fn` is optional. The default implementation always returns the computed `value`. 
-If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should continue with the given value (which should be a value that will converge quickly). +If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. 
When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the last provisional value, the newly computed value, and the number of iterations that have occurred so far). The `cycle_fn` can return the `value` parameter to continue iterating with the computed value, or return a different value (a fallback value) to continue iteration with that value instead. -The cycle will iterate until it converges: that is, until two successive iterations produce the same result. - -If the `cycle_fn` returns `Fallback`, the cycle will still continue to iterate (using the given value as a new starting point), in order to verify that the fallback value results in a stable converged cycle. It is not permitted to use a fallback value that does not converge, because this would leave the cycle in an unpredictable state, depending on the order of query execution. +The cycle will iterate until it converges: that is, until the value returned by `cycle_fn` equals the value from the previous iteration. If a cycle iterates more than 200 times, Salsa will panic rather than iterate forever. @@ -40,7 +38,7 @@ Consider a two-query cycle where `query_a` calls `query_b`, and `query_b` calls Fixed-point iteration is a powerful tool, but is also easy to misuse, potentially resulting in infinite iteration. To avoid this, ensure that all queries participating in fixpoint iteration are deterministic and monotone. To guarantee convergence, you can leverage the `last_provisional_value` (3rd parameter) received by `cycle_fn`. -When the `cycle_fn` recalculates a value, you can implement a strategy that references the last provisional value to "join" values ​​or "widen" it and return a fallback value. This ensures monotonicity of the calculation and suppresses infinite oscillation of values ​​between cycles. 
+When the `cycle_fn` receives a newly computed value, you can implement a strategy that references the last provisional value to "join" values or "widen" it and return a fallback value. This ensures monotonicity of the calculation and suppresses infinite oscillation of values between cycles. For example: Also, in fixed-point iteration, it is advantageous to be able to identify which cycle head seeded a value. By embedding a `salsa::Id` (2nd parameter) in the initial value as a "cycle marker", the recovery function can detect self-originated recursion. diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 8ea4e5e33..1c3312372 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -310,10 +310,10 @@ macro_rules! setup_tracked_fn { db: &$db_lt dyn $Db, id: salsa::Id, last_provisional_value: &Self::Output<$db_lt>, - value: &Self::Output<$db_lt>, + value: Self::Output<$db_lt>, iteration_count: u32, ($($input_id),*): ($($interned_input_ty),*) - ) -> $zalsa::CycleRecoveryAction> { + ) -> Self::Output<$db_lt> { $($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*) } diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index ff03c02a2..fe002fa4e 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -4,9 +4,9 @@ #[macro_export] macro_rules! 
unexpected_cycle_recovery { ($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{ - let (_db, _id, _last_provisional_value, _new_value, _count) = ($db, $id, $last_provisional_value, $new_value, $count); + let (_db, _id, _last_provisional_value, _count) = ($db, $id, $last_provisional_value, $count); std::mem::drop(($($other_inputs,)*)); - salsa::CycleRecoveryAction::Iterate + $new_value }}; } diff --git a/src/cycle.rs b/src/cycle.rs index 3f6f70aa0..0f12472b4 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -23,14 +23,12 @@ //! //! When a query observes that it has just computed a result which contains itself as a cycle head, //! it recognizes that it is responsible for resolving this cycle and calls its `cycle_fn` to -//! decide how to do so. The `cycle_fn` function is passed the provisional value just computed for -//! that query and the count of iterations so far, and must return either -//! `CycleRecoveryAction::Iterate` (which signals that the cycle head should re-iterate the cycle), -//! or `CycleRecoveryAction::Fallback` (which signals that the cycle head should replace its -//! computed value with the given fallback value). +//! decide what value to use. The `cycle_fn` function is passed the provisional value just computed +//! for that query and the count of iterations so far, and returns the value to use for this +//! iteration. This can be the computed value itself, or a different value (e.g., a fallback value). //! -//! If the cycle head ever observes that the provisional value it just recomputed is the same as -//! the provisional value from the previous iteration, the cycle has converged. The cycle head will +//! If the cycle head ever observes that the value returned by `cycle_fn` is the same as the +//! provisional value from the previous iteration, this cycle has converged. The cycle head will //! mark that value as final (by removing itself as cycle head) and return it. //! //! 
Other queries in the cycle will still have provisional values recorded, but those values should @@ -39,11 +37,6 @@ //! of its cycle heads have a final result, in which case it, too, can be marked final. (This is //! implemented in `shallow_verify_memo` and `validate_provisional`.) //! -//! If the `cycle_fn` returns a fallback value, the cycle head will replace its provisional value -//! with that fallback, and then iterate the cycle one more time. A fallback value is expected to -//! result in a stable, converged cycle. If it does not (that is, if the result of another -//! iteration of the cycle is not the same as the fallback value), we'll panic. -//! //! In nested cycle cases, the inner cycles are iterated as part of the outer cycle iteration. This helps //! to significantly reduce the number of iterations needed to reach a fixpoint. For nested cycles, //! the inner cycles head will transfer their lock ownership to the outer cycle. This ensures @@ -64,20 +57,6 @@ use crate::Revision; /// Should only be relevant in case of a badly configured cycle recovery. pub const MAX_ITERATIONS: IterationCount = IterationCount(200); -/// Return value from a cycle recovery function. -#[derive(Debug)] -pub enum CycleRecoveryAction { - /// Iterate the cycle again to look for a fixpoint. - Iterate, - - /// Use the given value as the result for the current iteration instead - /// of the value computed by the query function. - /// - /// Returning `Fallback` doesn't stop the fixpoint iteration. It only - /// allows the iterate function to return a different value. - Fallback(T), -} - /// Cycle recovery strategy: Is this query capable of recovering from /// a cycle that results from executing the function? If so, how? 
#[derive(Copy, Clone, Debug, PartialEq, Eq)] diff --git a/src/function.rs b/src/function.rs index 045825e19..b9878bc41 100644 --- a/src/function.rs +++ b/src/function.rs @@ -7,7 +7,7 @@ use std::ptr::NonNull; use std::sync::atomic::Ordering; use std::sync::OnceLock; -use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; +use crate::cycle::{CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::hash::{FxHashSet, FxIndexSet}; @@ -91,9 +91,11 @@ pub trait Configuration: Any { input: Self::Input<'db>, ) -> Self::Output<'db>; - /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return - /// value from the latest iteration of this cycle. `count` is the number of cycle iterations - /// completed so far. + /// Decide what value to use for this cycle iteration. Takes ownership of the new value + /// and returns an owned value to use. + /// + /// The function is called for every iteration of the cycle head, regardless of whether the cycle + /// has converged (the values are equal). /// /// # Id /// @@ -112,17 +114,22 @@ pub trait Configuration: Any { /// * **Initial value**: `iteration` may be non-zero on the first call for a given query if that /// query becomes the outermost cycle head after a nested cycle complete a few iterations. In this case, /// `iteration` continues from the nested cycle's iteration count rather than resetting to zero. - /// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle - /// and the value for this query remains unchanged for one iteration. But the outer cycle might - /// keep iterating because other heads keep changing. + /// * **Non-contiguous values**: The iteration count can be non-contigious for cycle heads + /// that are only conditionally part of a cycle. 
+ /// + /// # Return value + /// + /// The function should return the value to use for this iteration. This can be the `value` + /// that was computed, or a different value (e.g., a fallback value). This cycle will continue + /// iterating until the returned value equals the previous iteration's value. fn recover_from_cycle<'db>( db: &'db Self::DbView, id: Id, last_provisional_value: &Self::Output<'db>, - new_value: &Self::Output<'db>, + value: Self::Output<'db>, iteration: u32, input: Self::Input<'db>, - ) -> CycleRecoveryAction>; + ) -> Self::Output<'db>; /// Serialize the output type using `serde`. /// diff --git a/src/function/execute.rs b/src/function/execute.rs index 53bc640a2..d07bb45f6 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -357,8 +357,6 @@ where I am a cycle head, comparing last provisional value with new value" ); - let mut this_converged = C::values_equal(&new_value, last_provisional_value); - // If this is the outermost cycle, use the maximum iteration count of all cycles. // This is important for when later iterations introduce new cycle heads (that then // become the outermost cycle). 
We want to ensure that the iteration count keeps increasing @@ -373,36 +371,25 @@ where iteration_count }; - if !this_converged { - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle( - db, - id, - last_provisional_value, - &new_value, - iteration_count.as_u32(), - C::id_to_input(zalsa, id), - ) { - crate::CycleRecoveryAction::Iterate => {} - crate::CycleRecoveryAction::Fallback(fallback_value) => { - tracing::debug!( - "{database_key_index:?}: execute: user cycle_fn says to fall back" - ); - new_value = fallback_value; - - this_converged = C::values_equal(&new_value, last_provisional_value); - } - } + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do (it may return the same value or a different one): + new_value = C::recover_from_cycle( + db, + id, + last_provisional_value, + new_value, + iteration_count.as_u32(), + C::id_to_input(zalsa, id), + ); - let new_cycle_heads = active_query.take_cycle_heads(); - for head in new_cycle_heads { - if !cycle_heads.contains(&head.database_key_index) { - panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index); - } + let new_cycle_heads = active_query.take_cycle_heads(); + for head in new_cycle_heads { + if !cycle_heads.contains(&head.database_key_index) { + panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. 
This is not allowed.", head.database_key_index); } } + let this_converged = C::values_equal(&new_value, last_provisional_value); let mut completed_query = active_query.pop(); if let Some(outer_cycle) = outer_cycle { diff --git a/src/function/memo.rs b/src/function/memo.rs index d8faf3e0b..f22af65fe 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -496,7 +496,7 @@ mod _memory_usage { use crate::plumbing::{self, IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb}; use crate::table::memo::MemoTableWithTypes; use crate::zalsa::Zalsa; - use crate::{CycleRecoveryAction, Database, Id, Revision}; + use crate::{Database, Id, Revision}; use std::any::TypeId; use std::num::NonZeroUsize; @@ -564,11 +564,11 @@ mod _memory_usage { _: &'db Self::DbView, _: Id, _: &Self::Output<'db>, - _: &Self::Output<'db>, + value: Self::Output<'db>, _: u32, _: Self::Input<'db>, - ) -> CycleRecoveryAction> { - unimplemented!() + ) -> Self::Output<'db> { + value } fn serialize(_: &Self::Output<'_>, _: S) -> Result diff --git a/src/lib.rs b/src/lib.rs index 8c50c9052..d4409c4a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,7 +47,7 @@ pub use self::database::IngredientInfo; pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; -pub use self::cycle::CycleRecoveryAction; + pub use self::database::Database; pub use self::database_impl::DatabaseImpl; pub use self::durability::Durability; @@ -92,7 +92,7 @@ pub mod plumbing { #[cfg(feature = "accumulator")] pub use crate::accumulator::Accumulator; pub use crate::attach::{attach, with_attached_database}; - pub use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy}; + pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::{current_revision, Database}; pub use crate::durability::Durability; pub use crate::id::{AsId, FromId, FromIdWithDb, Id}; diff --git a/tests/cycle.rs b/tests/cycle.rs index dbe0bdc19..dd476ab76 100644 --- a/tests/cycle.rs +++ 
b/tests/cycle.rs @@ -7,7 +7,7 @@ mod common; use common::{ExecuteValidateLoggerDatabase, LogDatabase}; use expect_test::expect; -use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; +use salsa::{Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; #[cfg(not(miri))] use test_log::test; @@ -122,24 +122,26 @@ const MAX_ITERATIONS: u32 = 3; /// Recover from a cycle by falling back to `Value::OutOfBounds` if the value is out of bounds, /// `Value::TooManyIterations` if we've iterated more than `MAX_ITERATIONS` times, or else -/// iterating again. +/// returning the computed value to continue iterating. fn cycle_recover( _db: &dyn Db, _id: salsa::Id, - _last_provisional_value: &Value, - value: &Value, + last_provisional_value: &Value, + value: Value, count: u32, _inputs: Inputs, -) -> CycleRecoveryAction { - if value +) -> Value { + if &value == last_provisional_value { + value + } else if value .to_value() .is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE) { - CycleRecoveryAction::Fallback(Value::OutOfBounds) + Value::OutOfBounds } else if count > MAX_ITERATIONS { - CycleRecoveryAction::Fallback(Value::TooManyIterations) + Value::TooManyIterations } else { - CycleRecoveryAction::Iterate + value } } diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index 49f1d06d9..6377805b8 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -52,11 +52,11 @@ fn cycle_fn( _db: &dyn LogDatabase, _id: salsa::Id, _last_provisional_value: &[u32], - _value: &[u32], + value: Vec, _count: u32, _file: File, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate +) -> Vec { + value } #[test] diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 4ab236565..77f7378e4 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -28,11 +28,15 @@ fn cycle_initial(_db: &dyn 
ValueDatabase, _id: salsa::Id) -> u32 { fn cycle_fn( db: &dyn ValueDatabase, _id: salsa::Id, - _last_provisional_value: &u32, - _value: &u32, + last_provisional_value: &u32, + value: u32, _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +) -> u32 { + if &value == last_provisional_value { + value + } else { + fallback_value(db) + } } #[test] diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index a227d6122..dae4203d7 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -25,10 +25,10 @@ fn cycle_fn( db: &dyn salsa::Database, _id: salsa::Id, _last_provisional_value: &u32, - _value: &u32, + _value: u32, _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +) -> u32 { + fallback_value(db) } #[test_log::test] diff --git a/tests/cycle_recovery_dependencies.rs b/tests/cycle_recovery_dependencies.rs index b26ce973b..fe93428e5 100644 --- a/tests/cycle_recovery_dependencies.rs +++ b/tests/cycle_recovery_dependencies.rs @@ -39,12 +39,12 @@ fn cycle_fn( db: &dyn salsa::Database, _id: salsa::Id, _last_provisional_value: &u32, - _value: &u32, + value: u32, _count: u32, input: Input, -) -> salsa::CycleRecoveryAction { +) -> u32 { let _input = input.value(db); - salsa::CycleRecoveryAction::Iterate + value } #[test_log::test] diff --git a/tests/dataflow.rs b/tests/dataflow.rs index 69c91d513..f91123ef0 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -7,7 +7,7 @@ use std::collections::BTreeSet; use std::iter::IntoIterator; -use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use salsa::{Database as Db, Setter}; /// A Use of a symbol. 
#[salsa::input] @@ -78,12 +78,16 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, _id: salsa::Id, - _last_provisional_value: &Type, - value: &Type, + last_provisional_value: &Type, + value: Type, count: u32, _def: Definition, -) -> CycleRecoveryAction { - cycle_recover(value, count) +) -> Type { + if &value == last_provisional_value { + value + } else { + cycle_recover(value, count) + } } fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { @@ -93,25 +97,29 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, _id: salsa::Id, - _last_provisional_value: &Type, - value: &Type, + last_provisional_value: &Type, + value: Type, count: u32, _use: Use, -) -> CycleRecoveryAction { - cycle_recover(value, count) +) -> Type { + if &value == last_provisional_value { + value + } else { + cycle_recover(value, count) + } } -fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { - match value { - Type::Bottom => CycleRecoveryAction::Iterate, +fn cycle_recover(value: Type, count: u32) -> Type { + match &value { + Type::Bottom => value, Type::Values(_) => { if count > 4 { - CycleRecoveryAction::Fallback(Type::Top) + Type::Top } else { - CycleRecoveryAction::Iterate + value } } - Type::Top => CycleRecoveryAction::Iterate, + Type::Top => value, } } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index ba05291a5..34cbb7ed2 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -22,9 +22,9 @@ fn cycle_fn( _db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, - _value: &u32, + _value: u32, _count: u32, -) -> salsa::CycleRecoveryAction { +) -> u32 { panic!("cancel!") } From a885bb4c4c192741b8a17418fef81a71e33d111e Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 13 Nov 2025 10:17:44 +0100 Subject: [PATCH 17/21] Fix cycle head durability (#1024) --- 
src/function/execute.rs | 36 ++++++++---- src/zalsa_local.rs | 4 ++ tests/cycle_input_different_cycle_head.rs | 72 +++++++++++++++++++++++ 3 files changed, 101 insertions(+), 11 deletions(-) create mode 100644 tests/cycle_input_different_cycle_head.rs diff --git a/src/function/execute.rs b/src/function/execute.rs index d07bb45f6..558ace738 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -170,7 +170,7 @@ where // Our provisional value from the previous iteration, when doing fixpoint iteration. // This is different from `opt_old_memo` which might be from a different revision. - let mut last_provisional_memo: Option<&Memo<'db, C>> = None; + let mut last_provisional_memo_opt: Option<&Memo<'db, C>> = None; // TODO: Can we seed those somehow? let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); @@ -194,7 +194,7 @@ where // Only use the last provisional memo if it was a cycle head in the last iteration. This is to // force at least two executions. if old_memo.cycle_heads().contains(&database_key_index) { - last_provisional_memo = Some(old_memo); + last_provisional_memo_opt = Some(old_memo); } iteration_count = old_memo.revisions.iteration(); @@ -219,7 +219,7 @@ where db, zalsa, active_query, - last_provisional_memo.or(opt_old_memo), + last_provisional_memo_opt.or(opt_old_memo), ); // Take the cycle heads to not-fight-rust's-borrow-checker. @@ -329,10 +329,7 @@ where // Get the last provisional value for this query so that we can compare it with the new value // to test if the cycle converged. - let last_provisional_value = if let Some(last_provisional) = last_provisional_memo { - // We have a last provisional value from our previous time around the loop. 
- last_provisional.value.as_ref() - } else { + let last_provisional_memo = last_provisional_memo_opt.unwrap_or_else(|| { // This is our first time around the loop; a provisional value must have been // inserted into the memo table when the cycle was hit, so let's pull our // initial provisional value from there. @@ -346,8 +343,10 @@ where }); debug_assert!(memo.may_be_provisional()); - memo.value.as_ref() - }; + memo + }); + + let last_provisional_value = last_provisional_memo.value.as_ref(); let last_provisional_value = last_provisional_value.expect( "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", @@ -389,9 +388,24 @@ where } } - let this_converged = C::values_equal(&new_value, last_provisional_value); let mut completed_query = active_query.pop(); + let value_converged = C::values_equal(&new_value, last_provisional_value); + + // It's important to force a re-execution of the cycle if `changed_at` or `durability` has changed + // to ensure the reduced durability and changed propagates to all queries depending on this head. 
+ let metadata_converged = last_provisional_memo.revisions.durability + == completed_query.revisions.durability + && last_provisional_memo.revisions.changed_at + == completed_query.revisions.changed_at + && last_provisional_memo + .revisions + .origin + .is_derived_untracked() + == completed_query.revisions.origin.is_derived_untracked(); + + let this_converged = value_converged && metadata_converged; + if let Some(outer_cycle) = outer_cycle { tracing::info!( "Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}" @@ -494,7 +508,7 @@ where memo_ingredient_index, ); - last_provisional_memo = Some(new_memo); + last_provisional_memo_opt = Some(new_memo); last_stale_tracked_ids = completed_query.stale_tracked_structs; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index bde3b6b24..8f0239e56 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -934,6 +934,10 @@ impl QueryOrigin { } } + pub fn is_derived_untracked(&self) -> bool { + matches!(self.kind, QueryOriginKind::DerivedUntracked) + } + /// Create a query origin of type `QueryOriginKind::Derived`, with the given edges. pub fn derived(input_outputs: Box<[QueryEdge]>) -> QueryOrigin { // Exceeding `u32::MAX` query edges should never happen in real-world usage. diff --git a/tests/cycle_input_different_cycle_head.rs b/tests/cycle_input_different_cycle_head.rs new file mode 100644 index 000000000..d7f75143c --- /dev/null +++ b/tests/cycle_input_different_cycle_head.rs @@ -0,0 +1,72 @@ +#![cfg(feature = "inventory")] + +//! Tests that the durability correctly propagates +//! to all cycle heads. 
+ +use salsa::Setter as _; + +#[test_log::test] +fn low_durability_cycle_enter_from_different_head() { + let mut db = MyDbImpl::default(); + // Start with 0, the same as returned by cycle initial + let input = Input::builder(0).new(&db); + db.input = Some(input); + + assert_eq!(query_a(&db), 0); // Prime the Db + + input.set_value(&mut db).to(10); + + assert_eq!(query_b(&db), 10); +} + +#[salsa::input] +struct Input { + value: u32, +} + +#[salsa::db] +trait MyDb: salsa::Database { + fn input(&self) -> Input; +} + +#[salsa::db] +#[derive(Clone, Default)] +struct MyDbImpl { + storage: salsa::Storage, + input: Option, +} + +#[salsa::db] +impl salsa::Database for MyDbImpl {} + +#[salsa::db] +impl MyDb for MyDbImpl { + fn input(&self) -> Input { + self.input.unwrap() + } +} + +#[salsa::tracked(cycle_initial=cycle_initial)] +fn query_a(db: &dyn MyDb) -> u32 { + query_b(db); + db.input().value(db) +} + +fn cycle_initial(_db: &dyn MyDb, _id: salsa::Id) -> u32 { + 0 +} + +#[salsa::interned] +struct Interned { + value: u32, +} + +#[salsa::tracked(cycle_initial=cycle_initial)] +fn query_b<'db>(db: &'db dyn MyDb) -> u32 { + query_c(db) +} + +#[salsa::tracked] +fn query_c(db: &dyn MyDb) -> u32 { + query_a(db) +} From 17bc55d699565e5a1cb1bd42363b905af2f9f3e7 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Fri, 21 Nov 2025 16:39:30 +0900 Subject: [PATCH 18/21] pass `Cycle` to the cycle recovery function (#1028) * pass `CycleHeads` to the cycle recovery function * remove the second parameter `Id` of `cycle_fn` * Update cycle.rs * Revert "Update cycle.rs" This reverts commit cc35b82a99a8e45270cda21f4d7bf19a093455ed. 
* partially revert changes in #1021 There was actually no need to run `recover_from_cycle` if the query is converged * Expose `Cycle` instead of `CycleHeads` * add `Cycle::map` * Separate `previous_value` from `Cycle` This is more ergonomic when sharing `Cycle` * `Cycle` should be passed by ref * add `Cycle::id` * Update execute.rs * Update memo.rs * defer `head_ids` creation * Revert "defer `head_ids` creation" This reverts commit 23b4ba79e79d6a8ffb8e48315d37396074469869. * make all `Cycle` fields private and provide public accessor methods --- benches/dataflow.rs | 10 ++-- .../salsa-macro-rules/src/setup_tracked_fn.rs | 5 +- .../src/unexpected_cycle_recovery.rs | 4 +- src/cycle.rs | 48 ++++++++++++++++++- src/function.rs | 5 +- src/function/execute.rs | 10 ++-- src/function/memo.rs | 3 +- src/lib.rs | 1 + tests/cycle.rs | 5 +- tests/cycle_accumulate.rs | 3 +- tests/cycle_recovery_call_back_into_cycle.rs | 3 +- tests/cycle_recovery_call_query.rs | 3 +- tests/cycle_recovery_dependencies.rs | 3 +- tests/dataflow.rs | 10 ++-- tests/parallel/cycle_panic.rs | 3 +- 15 files changed, 77 insertions(+), 39 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index cf20140f6..4d18a2532 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -76,13 +76,12 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, _last_provisional_value: &Type, value: Type, - count: u32, _def: Definition, ) -> Type { - cycle_recover(value, count) + cycle_recover(value, cycle.iteration()) } fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { @@ -91,13 +90,12 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, _last_provisional_value: &Type, value: Type, - count: u32, _use: Use, ) -> Type { - cycle_recover(value, count) + cycle_recover(value, 
cycle.iteration()) } fn cycle_recover(value: Type, count: u32) -> Type { diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 1c3312372..9cb311fc5 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -308,13 +308,12 @@ macro_rules! setup_tracked_fn { fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Self::Output<$db_lt>, value: Self::Output<$db_lt>, - iteration_count: u32, ($($input_id),*): ($($interned_input_ty),*) ) -> Self::Output<$db_lt> { - $($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*) + $($cycle_recovery_fn)*(db, cycle, last_provisional_value, value, $($input_id),*) } fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index fe002fa4e..e22875311 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -3,8 +3,8 @@ // a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! 
unexpected_cycle_recovery { - ($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{ - let (_db, _id, _last_provisional_value, _count) = ($db, $id, $last_provisional_value, $count); + ($db:ident, $cycle:ident, $last_provisional_value:ident, $new_value:ident, $($other_inputs:ident),*) => {{ + let (_db, _cycle, _last_provisional_value) = ($db, $cycle, $last_provisional_value); std::mem::drop(($($other_inputs,)*)); $new_value }}; diff --git a/src/cycle.rs b/src/cycle.rs index 0f12472b4..8ab5dabcd 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -50,7 +50,7 @@ use thin_vec::{thin_vec, ThinVec}; use crate::key::DatabaseKeyIndex; use crate::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use crate::sync::OnceLock; -use crate::Revision; +use crate::{Id, Revision}; /// The maximum number of times we'll fixpoint-iterate before panicking. /// @@ -238,6 +238,10 @@ impl CycleHeads { } } + pub(crate) fn ids(&self) -> CycleHeadIdsIterator<'_> { + CycleHeadIdsIterator { inner: self.iter() } + } + /// Iterates over all cycle heads that aren't equal to `own`. pub(crate) fn iter_not_eq( &self, @@ -392,6 +396,7 @@ impl IntoIterator for CycleHeads { } } +#[derive(Clone)] pub struct CycleHeadsIterator<'a> { inner: std::slice::Iter<'a, CycleHead>, } @@ -448,6 +453,47 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads { EMPTY_CYCLE_HEADS.get_or_init(|| CycleHeads(ThinVec::new())) } +#[derive(Clone)] +pub struct CycleHeadIdsIterator<'a> { + inner: CycleHeadsIterator<'a>, +} + +impl Iterator for CycleHeadIdsIterator<'_> { + type Item = crate::Id; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|head| head.database_key_index.key_index()) + } +} + +/// The context that the cycle recovery function receives when a query cycle occurs. 
+pub struct Cycle<'a> { + pub(crate) head_ids: CycleHeadIdsIterator<'a>, + pub(crate) id: Id, + pub(crate) iteration: u32, +} + +impl Cycle<'_> { + /// An iterator that outputs the [`Id`]s of the current cycle heads. + /// This always contains the [`Id`] of the current query but it can contain additional cycle head [`Id`]s + /// if this query is nested in an outer cycle or if it has nested cycles. + pub fn head_ids(&self) -> CycleHeadIdsIterator<'_> { + self.head_ids.clone() + } + + /// The [`Id`] of the query that the current cycle recovery function is processing. + pub fn id(&self) -> Id { + self.id + } + + /// The counter of the current fixed point iteration. + pub fn iteration(&self) -> u32 { + self.iteration + } +} + #[derive(Debug)] pub enum ProvisionalStatus<'db> { Provisional { diff --git a/src/function.rs b/src/function.rs index b9878bc41..f7f302727 100644 --- a/src/function.rs +++ b/src/function.rs @@ -21,7 +21,7 @@ use crate::table::Table; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryEdge, QueryOriginRef}; -use crate::{Id, Revision}; +use crate::{Cycle, Id, Revision}; #[cfg(feature = "accumulator")] mod accumulated; @@ -124,10 +124,9 @@ pub trait Configuration: Any { /// iterating until the returned value equals the previous iteration's value. 
fn recover_from_cycle<'db>( db: &'db Self::DbView, - id: Id, + cycle: &Cycle, last_provisional_value: &Self::Output<'db>, value: Self::Output<'db>, - iteration: u32, input: Self::Input<'db>, ) -> Self::Output<'db>; diff --git a/src/function/execute.rs b/src/function/execute.rs index 558ace738..b0b8b8609 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -12,7 +12,7 @@ use crate::sync::thread; use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; -use crate::{tracing, Cancelled}; +use crate::{tracing, Cancelled, Cycle}; use crate::{DatabaseKeyIndex, Event, EventKind, Id}; impl IngredientImpl @@ -370,14 +370,18 @@ where iteration_count }; + let cycle = Cycle { + head_ids: cycle_heads.ids(), + id, + iteration: iteration_count.as_u32(), + }; // We are in a cycle that hasn't converged; ask the user's // cycle-recovery function what to do (it may return the same value or a different one): new_value = C::recover_from_cycle( db, - id, + &cycle, last_provisional_value, new_value, - iteration_count.as_u32(), C::id_to_input(zalsa, id), ); diff --git a/src/function/memo.rs b/src/function/memo.rs index f22af65fe..234829cb1 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -562,10 +562,9 @@ mod _memory_usage { fn recover_from_cycle<'db>( _: &'db Self::DbView, - _: Id, + _: &crate::Cycle, _: &Self::Output<'db>, value: Self::Output<'db>, - _: u32, _: Self::Input<'db>, ) -> Self::Output<'db> { value diff --git a/src/lib.rs b/src/lib.rs index d4409c4a9..f90fce338 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,7 @@ pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; +pub use self::cycle::Cycle; pub use self::database::Database; pub use self::database_impl::DatabaseImpl; pub use self::durability::Durability; diff --git a/tests/cycle.rs b/tests/cycle.rs index dd476ab76..ba407226e 100644 --- 
a/tests/cycle.rs +++ b/tests/cycle.rs @@ -125,10 +125,9 @@ const MAX_ITERATIONS: u32 = 3; /// returning the computed value to continue iterating. fn cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Value, value: Value, - count: u32, _inputs: Inputs, ) -> Value { if &value == last_provisional_value { @@ -138,7 +137,7 @@ fn cycle_recover( .is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE) { Value::OutOfBounds - } else if count > MAX_ITERATIONS { + } else if cycle.iteration() > MAX_ITERATIONS { Value::TooManyIterations } else { value diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index 6377805b8..63325ec13 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -50,10 +50,9 @@ fn cycle_initial(_db: &dyn LogDatabase, _id: salsa::Id, _file: File) -> Vec fn cycle_fn( _db: &dyn LogDatabase, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &[u32], value: Vec, - _count: u32, _file: File, ) -> Vec { value diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 77f7378e4..c0bbf00c1 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -27,10 +27,9 @@ fn cycle_initial(_db: &dyn ValueDatabase, _id: salsa::Id) -> u32 { fn cycle_fn( db: &dyn ValueDatabase, - _id: salsa::Id, + _cycle: &salsa::Cycle, last_provisional_value: &u32, value: u32, - _count: u32, ) -> u32 { if &value == last_provisional_value { value diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index dae4203d7..57c1915ab 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -23,10 +23,9 @@ fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id) -> u32 { fn cycle_fn( db: &dyn salsa::Database, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &u32, _value: u32, - _count: u32, ) -> u32 { fallback_value(db) } 
diff --git a/tests/cycle_recovery_dependencies.rs b/tests/cycle_recovery_dependencies.rs index fe93428e5..fd9a5f956 100644 --- a/tests/cycle_recovery_dependencies.rs +++ b/tests/cycle_recovery_dependencies.rs @@ -37,10 +37,9 @@ fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u3 fn cycle_fn( db: &dyn salsa::Database, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &u32, value: u32, - _count: u32, input: Input, ) -> u32 { let _input = input.value(db); diff --git a/tests/dataflow.rs b/tests/dataflow.rs index f91123ef0..a0d50834f 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -77,16 +77,15 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Type, value: Type, - count: u32, _def: Definition, ) -> Type { if &value == last_provisional_value { value } else { - cycle_recover(value, count) + cycle_recover(value, cycle.iteration()) } } @@ -96,16 +95,15 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Type, value: Type, - count: u32, _use: Use, ) -> Type { if &value == last_provisional_value { value } else { - cycle_recover(value, count) + cycle_recover(value, cycle.iteration()) } } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index 34cbb7ed2..4afc375d5 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -20,10 +20,9 @@ fn query_b(db: &dyn KnobsDatabase) -> u32 { fn cycle_fn( _db: &dyn KnobsDatabase, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &u32, _value: u32, - _count: u32, ) -> u32 { panic!("cancel!") } From 59aa1075e837f5deb0d6ffb24b68fedc0f4bc5e0 Mon Sep 17 00:00:00 2001 From: Andrew Lilley Brinker Date: Tue, 25 Nov 2025 23:38:31 -0800 Subject: [PATCH 19/21] Fully qualify 
std Result type (#1025) * Fully qualify std Result type The prior macro expansion could produce errors if the macros were called in a context where `Result` is redefined, for example in a crate with its own `Result` type which pre-fills the error type. This replaces existing `Result` uses with `std::result::Result` to avoid the compilation error in that case. Signed-off-by: Andrew Lilley Brinker * Use qualified names --------- Signed-off-by: Andrew Lilley Brinker Co-authored-by: Micha Reiser --- .../src/setup_input_struct.rs | 32 +++++++-------- .../src/setup_interned_struct.rs | 40 +++++++++---------- .../salsa-macro-rules/src/setup_tracked_fn.rs | 32 +++++++-------- .../src/setup_tracked_struct.rs | 28 ++++++------- .../redefine-result-input-struct-derive.rs | 15 +++++++ tests/compile_pass.rs | 8 ++++ 6 files changed, 89 insertions(+), 66 deletions(-) create mode 100644 tests/compile-pass/redefine-result-input-struct-derive.rs create mode 100644 tests/compile_pass.rs diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 741f9393e..d3c897045 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -81,7 +81,7 @@ macro_rules! setup_input_struct { #[allow(clippy::all)] #[allow(dead_code)] const _: () = { - use salsa::plumbing as $zalsa; + use ::salsa::plumbing as $zalsa; use $zalsa::input as $zalsa_struct; type $Configuration = $Struct; @@ -123,7 +123,7 @@ macro_rules! setup_input_struct { fn serialize( fields: &Self::Fields, serializer: S, - ) -> Result { + ) -> ::std::result::Result { $zalsa::macro_if! { if $persist { $($serialize_fn(fields, serializer))? @@ -135,7 +135,7 @@ macro_rules! setup_input_struct { fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( deserializer: D, - ) -> Result { + ) -> ::std::result::Result { $zalsa::macro_if! { if $persist { $($deserialize_fn(deserializer))? 
@@ -198,8 +198,8 @@ macro_rules! setup_input_struct { } $zalsa::macro_if! { $generate_debug_impl => - impl std::fmt::Debug for $Struct { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl ::std::fmt::Debug for $Struct { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { Self::default_debug_fmt(*self, f) } } @@ -241,7 +241,7 @@ macro_rules! setup_input_struct { $zalsa::macro_if! { $persist => impl $zalsa::serde::Serialize for $Struct { - fn serialize(&self, serializer: S) -> Result + fn serialize(&self, serializer: S) -> ::std::result::Result where S: $zalsa::serde::Serializer, { @@ -250,7 +250,7 @@ macro_rules! setup_input_struct { } impl<'de> $zalsa::serde::Deserialize<'de> for $Struct { - fn deserialize(deserializer: D) -> Result + fn deserialize(deserializer: D) -> ::std::result::Result where D: $zalsa::serde::Deserializer<'de>, { @@ -310,7 +310,7 @@ macro_rules! setup_input_struct { self, $field_index, ingredient, - |fields, f| std::mem::replace(&mut fields.$field_index, f), + |fields, f| ::std::mem::replace(&mut fields.$field_index, f), ) } )* @@ -336,11 +336,11 @@ macro_rules! setup_input_struct { } /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) - pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + pub fn default_debug_fmt(this: Self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result where // rustc rejects trivial bounds, but it cannot see through higher-ranked bounds // with its check :^) - $(for<'__trivial_bounds> $field_ty: std::fmt::Debug),* + $(for<'__trivial_bounds> $field_ty: ::std::fmt::Debug),* { $zalsa::with_attached_database(|db| { let zalsa = db.zalsa(); @@ -371,7 +371,7 @@ macro_rules! 
setup_input_struct { pub fn new<$Db>(self, db: &$Db) -> $Struct where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` - $Db: ?Sized + salsa::Database + $Db: ?Sized + ::salsa::Database { let (zalsa, zalsa_local) = db.zalsas(); let current_revision = zalsa.current_revision(); @@ -384,7 +384,7 @@ macro_rules! setup_input_struct { mod builder { use super::*; - use salsa::plumbing as $zalsa; + use ::salsa::plumbing as $zalsa; use $zalsa::input as $zalsa_struct; // These are standalone functions instead of methods on `Builder` to prevent @@ -392,7 +392,7 @@ macro_rules! setup_input_struct { pub(super) fn new_builder($($field_id: $field_ty),*) -> $Builder { $Builder { fields: ($($field_id,)*), - durabilities: [salsa::Durability::default(); $N], + durabilities: [::salsa::Durability::default(); $N], } } @@ -406,14 +406,14 @@ macro_rules! setup_input_struct { fields: ($($field_ty,)*), /// The durabilities per field. - durabilities: [salsa::Durability; $N], + durabilities: [::salsa::Durability; $N], } impl $Builder { /// Sets the durability of all fields. /// /// Overrides any previously set durabilities. - pub fn durability(mut self, durability: salsa::Durability) -> Self { + pub fn durability(mut self, durability: ::salsa::Durability) -> Self { self.durabilities = [durability; $N]; self } @@ -431,7 +431,7 @@ macro_rules! setup_input_struct { $( /// Sets the durability for the field `$field_id`. #[must_use] - pub fn $field_durability_id(mut self, durability: salsa::Durability) -> Self + pub fn $field_durability_id(mut self, durability: ::salsa::Durability) -> Self { self.durabilities[$field_index] = durability; self diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 1d27a33a2..b069bcac9 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -99,7 +99,7 @@ macro_rules! 
setup_interned_struct { #[allow(clippy::all)] #[allow(dead_code)] const _: () = { - use salsa::plumbing as $zalsa; + use ::salsa::plumbing as $zalsa; use $zalsa::interned as $zalsa_struct; type $Configuration = $StructWithStatic; @@ -120,7 +120,7 @@ macro_rules! setup_interned_struct { #[derive(Hash)] struct StructKey<$db_lt, $($indexed_ty),*>( $($indexed_ty,)* - std::marker::PhantomData<&$db_lt ()>, + ::std::marker::PhantomData<&$db_lt ()>, ); impl<$db_lt, $($indexed_ty,)*> $zalsa::interned::HashEqLike> @@ -129,7 +129,7 @@ macro_rules! setup_interned_struct { $($field_ty: $zalsa::interned::HashEqLike<$indexed_ty>),* { - fn hash(&self, h: &mut H) { + fn hash(&self, h: &mut H) { $($zalsa::interned::HashEqLike::<$indexed_ty>::hash(&self.$field_index, &mut *h);)* } @@ -147,7 +147,7 @@ macro_rules! setup_interned_struct { } } - impl salsa::plumbing::interned::Configuration for $StructWithStatic { + impl $zalsa::interned::Configuration for $StructWithStatic { const LOCATION: $zalsa::Location = $zalsa::Location { file: file!(), line: line!(), @@ -171,7 +171,7 @@ macro_rules! setup_interned_struct { fn serialize( fields: &Self::Fields<'_>, serializer: S, - ) -> Result { + ) -> ::std::result::Result { $zalsa::macro_if! { if $persist { $($serialize_fn(fields, serializer))? @@ -183,7 +183,7 @@ macro_rules! setup_interned_struct { fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( deserializer: D, - ) -> Result, D::Error> { + ) -> ::std::result::Result, D::Error> { $zalsa::macro_if! { if $persist { $($deserialize_fn(deserializer))? @@ -210,14 +210,14 @@ macro_rules! setup_interned_struct { } impl< $($db_lt_arg)? > $zalsa::AsId for $Struct< $($db_lt_arg)? > { - fn as_id(&self) -> salsa::Id { + fn as_id(&self) -> ::salsa::Id { self.0.as_id() } } impl< $($db_lt_arg)? > $zalsa::FromId for $Struct< $($db_lt_arg)? 
> { - fn from_id(id: salsa::Id) -> Self { - Self(<$Id>::from_id(id), std::marker::PhantomData) + fn from_id(id: ::salsa::Id) -> Self { + Self(<$Id>::from_id(id), ::std::marker::PhantomData) } } @@ -226,8 +226,8 @@ macro_rules! setup_interned_struct { unsafe impl< $($db_lt_arg)? > Sync for $Struct< $($db_lt_arg)? > {} $zalsa::macro_if! { $generate_debug_impl => - impl< $($db_lt_arg)? > std::fmt::Debug for $Struct< $($db_lt_arg)? > { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl< $($db_lt_arg)? > ::std::fmt::Debug for $Struct< $($db_lt_arg)? > { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { Self::default_debug_fmt(*self, f) } } @@ -269,7 +269,7 @@ macro_rules! setup_interned_struct { $zalsa::macro_if! { $persist => impl<$($db_lt_arg)?> $zalsa::serde::Serialize for $Struct<$($db_lt_arg)?> { - fn serialize(&self, serializer: S) -> Result + fn serialize(&self, serializer: S) -> ::std::result::Result where S: $zalsa::serde::Serializer, { @@ -278,7 +278,7 @@ macro_rules! setup_interned_struct { } impl<'de, $($db_lt_arg)?> $zalsa::serde::Deserialize<'de> for $Struct<$($db_lt_arg)?> { - fn deserialize(deserializer: D) -> Result + fn deserialize(deserializer: D) -> ::std::result::Result where D: $zalsa::serde::Deserializer<'de>, { @@ -301,17 +301,17 @@ macro_rules! setup_interned_struct { } impl<$db_lt> $Struct< $($db_lt_arg)? 
> { - pub fn $new_fn<$Db, $($indexed_ty: $zalsa::interned::Lookup<$field_ty> + std::hash::Hash,)*>(db: &$db_lt $Db, $($field_id: $indexed_ty),*) -> Self + pub fn $new_fn<$Db, $($indexed_ty: $zalsa::interned::Lookup<$field_ty> + ::std::hash::Hash,)*>(db: &$db_lt $Db, $($field_id: $indexed_ty),*) -> Self where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` - $Db: ?Sized + salsa::Database, + $Db: ?Sized + ::salsa::Database, $( $field_ty: $zalsa::interned::HashEqLike<$indexed_ty>, )* { let (zalsa, zalsa_local) = db.zalsas(); $Configuration::ingredient(zalsa).intern(zalsa, zalsa_local, - StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default()), |_, data| ($($zalsa::interned::Lookup::into_owned(data.$field_index),)*)) + StructKey::<$db_lt>($($field_id,)* ::std::marker::PhantomData::default()), |_, data| ($($zalsa::interned::Lookup::into_owned(data.$field_index),)*)) } $( @@ -337,11 +337,11 @@ macro_rules! setup_interned_struct { iftt ($($db_lt_arg)?) { impl $Struct<'_> { /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) - pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + pub fn default_debug_fmt(this: Self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result where // rustc rejects trivial bounds, but it cannot see through higher-ranked bounds // with its check :^) - $(for<$db_lt> $field_ty: std::fmt::Debug),* + $(for<$db_lt> $field_ty: ::std::fmt::Debug),* { $zalsa::with_attached_database(|db| { let zalsa = db.zalsa(); @@ -361,11 +361,11 @@ macro_rules! 
setup_interned_struct { } else { impl $Struct { /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) - pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + pub fn default_debug_fmt(this: Self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result where // rustc rejects trivial bounds, but it cannot see through higher-ranked bounds // with its check :^) - $(for<$db_lt> $field_ty: std::fmt::Debug),* + $(for<$db_lt> $field_ty: ::std::fmt::Debug),* { $zalsa::with_attached_database(|db| { let zalsa = db.zalsa(); diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 9cb311fc5..e1619095f 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -89,8 +89,8 @@ macro_rules! setup_tracked_fn { $vis fn $fn_name<$db_lt>( $db: &$db_lt dyn $Db, $($input_id: $input_ty,)* - ) -> salsa::plumbing::return_mode_ty!(($return_mode, __, __), $db_lt, $output_ty) { - use salsa::plumbing as $zalsa; + ) -> ::salsa::plumbing::return_mode_ty!(($return_mode, __, __), $db_lt, $output_ty) { + use ::salsa::plumbing as $zalsa; struct $Configuration; @@ -111,8 +111,8 @@ macro_rules! setup_tracked_fn { if $needs_interner { #[derive(Copy, Clone)] struct $InternedData<$db_lt>( - salsa::Id, - std::marker::PhantomData &$db_lt ()>, + ::salsa::Id, + ::std::marker::PhantomData &$db_lt ()>, ); static $INTERN_CACHE: $zalsa::IngredientCache<$zalsa::interned::IngredientImpl<$Configuration>> = @@ -154,14 +154,14 @@ macro_rules! setup_tracked_fn { impl $zalsa::AsId for $InternedData<'_> { #[inline] - fn as_id(&self) -> salsa::Id { + fn as_id(&self) -> ::salsa::Id { self.0 } } impl $zalsa::FromId for $InternedData<'_> { #[inline] - fn from_id(id: salsa::Id) -> Self { + fn from_id(id: ::salsa::Id) -> Self { Self(id, ::core::marker::PhantomData) } } @@ -181,7 +181,7 @@ macro_rules! 
setup_tracked_fn { fn serialize( fields: &Self::Fields<'_>, serializer: S, - ) -> Result { + ) -> ::std::result::Result { $zalsa::macro_if! { if $persist { $zalsa::serde::Serialize::serialize(fields, serializer) @@ -193,7 +193,7 @@ macro_rules! setup_tracked_fn { fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( deserializer: D, - ) -> Result, D::Error> { + ) -> ::std::result::Result, D::Error> { $zalsa::macro_if! { if $persist { $zalsa::serde::Deserialize::deserialize(deserializer) @@ -302,13 +302,13 @@ macro_rules! setup_tracked_fn { $inner($db, $($input_id),*) } - fn cycle_initial<$db_lt>(db: &$db_lt Self::DbView, id: salsa::Id, ($($input_id),*): ($($interned_input_ty),*)) -> Self::Output<$db_lt> { + fn cycle_initial<$db_lt>(db: &$db_lt Self::DbView, id: ::salsa::Id, ($($input_id),*): ($($interned_input_ty),*)) -> Self::Output<$db_lt> { $($cycle_recovery_initial)*(db, id, $($input_id),*) } fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - cycle: &salsa::Cycle, + cycle: &::salsa::Cycle, last_provisional_value: &Self::Output<$db_lt>, value: Self::Output<$db_lt>, ($($input_id),*): ($($interned_input_ty),*) @@ -316,7 +316,7 @@ macro_rules! setup_tracked_fn { $($cycle_recovery_fn)*(db, cycle, last_provisional_value, value, $($input_id),*) } - fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> { + fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: ::salsa::Id) -> Self::Input<$db_lt> { $zalsa::macro_if! { if $needs_interner { $Configuration::intern_ingredient_(zalsa).data(zalsa, key).clone() @@ -329,7 +329,7 @@ macro_rules! setup_tracked_fn { fn serialize( value: &Self::Output<'_>, serializer: S, - ) -> Result { + ) -> ::std::result::Result { $zalsa::macro_if! { if $persist { $zalsa::serde::Serialize::serialize(value, serializer) @@ -341,7 +341,7 @@ macro_rules! 
setup_tracked_fn { fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( deserializer: D, - ) -> Result, D::Error> { + ) -> ::std::result::Result, D::Error> { $zalsa::macro_if! { if $persist { $zalsa::serde::Deserialize::deserialize(deserializer) @@ -419,11 +419,11 @@ macro_rules! setup_tracked_fn { #[allow(non_local_definitions)] impl $fn_name { $zalsa::gate_accumulated! { - pub fn accumulated<$db_lt, A: salsa::Accumulator>( + pub fn accumulated<$db_lt, A: ::salsa::Accumulator>( $db: &$db_lt dyn $Db, $($input_id: $interned_input_ty,)* ) -> Vec<&$db_lt A> { - use salsa::plumbing as $zalsa; + use ::salsa::plumbing as $zalsa; let key = $zalsa::macro_if! { if $needs_interner {{ let (zalsa, zalsa_local) = $db.zalsas(); @@ -491,7 +491,7 @@ macro_rules! setup_tracked_fn { #[doc(hidden)] #[allow(non_camel_case_types)] $vis struct $fn_name { - _priv: std::convert::Infallible, + _priv: ::std::convert::Infallible, } }; } diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 92dc25974..970cb31d6 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -115,14 +115,14 @@ macro_rules! setup_tracked_struct { $(#[$attr])* #[derive(Copy, Clone, PartialEq, Eq, Hash)] $vis struct $Struct<$db_lt>( - salsa::Id, - std::marker::PhantomData &$db_lt ()> + ::salsa::Id, + ::std::marker::PhantomData &$db_lt ()> ); #[allow(dead_code)] #[allow(clippy::all)] const _: () = { - use salsa::plumbing as $zalsa; + use ::salsa::plumbing as $zalsa; use $zalsa::tracked_struct as $zalsa_struct; use $zalsa::Revision as $Revision; @@ -160,7 +160,7 @@ macro_rules! 
setup_tracked_struct { type Struct<$db_lt> = $Struct<$db_lt>; - fn untracked_fields(fields: &Self::Fields<'_>) -> impl std::hash::Hash { + fn untracked_fields(fields: &Self::Fields<'_>) -> impl ::std::hash::Hash { ( $( &fields.$absolute_untracked_index ),* ) } @@ -209,7 +209,7 @@ macro_rules! setup_tracked_struct { fn serialize( fields: &Self::Fields<'_>, serializer: S, - ) -> Result { + ) -> ::std::result::Result { $zalsa::macro_if! { if $persist { $($serialize_fn(fields, serializer))? @@ -221,7 +221,7 @@ macro_rules! setup_tracked_struct { fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( deserializer: D, - ) -> Result, D::Error> { + ) -> ::std::result::Result, D::Error> { $zalsa::macro_if! { if $persist { $($deserialize_fn(deserializer))? @@ -253,8 +253,8 @@ macro_rules! setup_tracked_struct { impl<$db_lt> $zalsa::FromId for $Struct<$db_lt> { #[inline] - fn from_id(id: salsa::Id) -> Self { - $Struct(id, std::marker::PhantomData) + fn from_id(id: ::salsa::Id) -> Self { + $Struct(id, ::std::marker::PhantomData) } } @@ -307,7 +307,7 @@ macro_rules! setup_tracked_struct { $zalsa::macro_if! { $persist => impl $zalsa::serde::Serialize for $Struct<'_> { - fn serialize(&self, serializer: S) -> Result + fn serialize(&self, serializer: S) -> ::std::result::Result where S: $zalsa::serde::Serializer, { @@ -316,7 +316,7 @@ macro_rules! setup_tracked_struct { } impl<'de> $zalsa::serde::Deserialize<'de> for $Struct<'_> { - fn deserialize(deserializer: D) -> Result + fn deserialize(deserializer: D) -> ::std::result::Result where D: $zalsa::serde::Deserializer<'de>, { @@ -332,8 +332,8 @@ macro_rules! setup_tracked_struct { unsafe impl Sync for $Struct<'_> {} $zalsa::macro_if! 
{ $generate_debug_impl => - impl std::fmt::Debug for $Struct<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl ::std::fmt::Debug for $Struct<'_> { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { Self::default_debug_fmt(*self, f) } } @@ -401,13 +401,13 @@ macro_rules! setup_tracked_struct { #[allow(unused_lifetimes)] impl<'_db> $Struct<'_db> { /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) - pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + pub fn default_debug_fmt(this: Self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result where // `zalsa::with_attached_database` has a local lifetime for the database // so we need this function to be higher-ranked over the db lifetime // Thus the actual lifetime of `Self` does not matter here so we discard // it with the `'_db` lifetime name as we cannot shadow lifetimes. - $(for<$db_lt> $field_ty: std::fmt::Debug),* + $(for<$db_lt> $field_ty: ::std::fmt::Debug),* { $zalsa::with_attached_database(|db| { let zalsa = db.zalsa(); diff --git a/tests/compile-pass/redefine-result-input-struct-derive.rs b/tests/compile-pass/redefine-result-input-struct-derive.rs new file mode 100644 index 000000000..7615cb286 --- /dev/null +++ b/tests/compile-pass/redefine-result-input-struct-derive.rs @@ -0,0 +1,15 @@ +// Ensure the `salsa::tracked` attribute macro doesn't conflict with local +// redefinition of the `Result` type. 
+//
+// See: https://github.com/salsa-rs/salsa/pull/1025

+type Result<T> = std::result::Result<T, ()>;
+
+#[salsa::tracked]
+fn example_query(_db: &dyn salsa::Database) -> Result<()> {
+    Ok(())
+}
+
+fn main() {
+    println!("Hello, world!");
+}
diff --git a/tests/compile_pass.rs b/tests/compile_pass.rs
new file mode 100644
index 000000000..d6a5265a3
--- /dev/null
+++ b/tests/compile_pass.rs
@@ -0,0 +1,8 @@
+#![cfg(all(feature = "inventory", feature = "persistence"))]
+
+#[rustversion::all(stable, since(1.90))]
+#[test]
+fn compile_pass() {
+    let t = trybuild::TestCases::new();
+    t.pass("tests/compile-pass/*.rs");
+}

From 60d029a4d9e305ff6cf3237c6edb65740024e822 Mon Sep 17 00:00:00 2001
From: Jack O'Connor
Date: Wed, 3 Dec 2025 23:14:18 -0800
Subject: [PATCH 20/21] implement `Update` for `OrderMap` and `OrderSet` (#1033)

---
 Cargo.toml    | 1 +
 src/update.rs | 21 +++++++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/Cargo.toml b/Cargo.toml
index 9c419e339..62deac321 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,7 @@ hashbrown = "0.15"
 hashlink = "0.10"
 indexmap = "2"
 intrusive-collections = "0.9.7"
+ordermap = "1"
 parking_lot = "0.12"
 portable-atomic = "1"
 rustc-hash = "2"
diff --git a/src/update.rs b/src/update.rs
index d95bd13a9..65a10df07 100644
--- a/src/update.rs
+++ b/src/update.rs
@@ -340,6 +340,27 @@ where
     }
 }

+unsafe impl<K, V, S> Update for ordermap::OrderMap<K, V, S>
+where
+    K: Update + Eq + Hash,
+    V: Update,
+    S: BuildHasher,
+{
+    unsafe fn maybe_update(old_pointer: *mut Self, new_map: Self) -> bool {
+        maybe_update_map!(old_pointer, new_map)
+    }
+}
+
+unsafe impl<K, S> Update for ordermap::OrderSet<K, S>
+where
+    K: Update + Eq + Hash,
+    S: BuildHasher,
+{
+    unsafe fn maybe_update(old_pointer: *mut Self, new_set: Self) -> bool {
+        maybe_update_set!(old_pointer, new_set)
+    }
+}
+
 unsafe impl<K, V> Update for BTreeMap<K, V>
 where
     K: Update + Eq + Ord,

From 55e5e7d32fa3fc189276f35bb04c9438f9aedbd1 Mon Sep 17 00:00:00 2001
From: Jack O'Connor
Date: Thu, 4 Dec 2025 09:52:00 -0800
Subject: [PATCH 21/21] Make `ordermap` an optional feature (#1034)

---
 Cargo.toml    | 3 ++-
 src/update.rs | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/Cargo.toml b/Cargo.toml
index 62deac321..ebca6071f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,7 +19,6 @@ hashbrown = "0.15"
 hashlink = "0.10"
 indexmap = "2"
 intrusive-collections = "0.9.7"
-ordermap = "1"
 parking_lot = "0.12"
 portable-atomic = "1"
 rustc-hash = "2"
@@ -42,6 +41,8 @@ shuttle = { version = "0.8.1", optional = true }
 erased-serde = { version = "0.4.6", optional = true }
 serde = { version = "1.0.219", features = ["derive"], optional = true }

+ordermap = { version = "1.0.0", optional = true }
+
 [features]
 default = ["salsa_unstable", "rayon", "macros", "inventory", "accumulator"]
 inventory = ["dep:inventory"]
diff --git a/src/update.rs b/src/update.rs
index 65a10df07..126c813b3 100644
--- a/src/update.rs
+++ b/src/update.rs
@@ -340,6 +340,7 @@ where
     }
 }

+#[cfg(feature = "ordermap")]
 unsafe impl<K, V, S> Update for ordermap::OrderMap<K, V, S>
 where
     K: Update + Eq + Hash,
@@ -351,6 +352,7 @@ where
     }
 }

+#[cfg(feature = "ordermap")]
 unsafe impl<K, S> Update for ordermap::OrderSet<K, S>
 where
     K: Update + Eq + Hash,