fast_stm/transaction/
mod.rs

1#[cfg(feature = "wait-on-retry")]
2pub mod control_block;
3pub mod log_var;
4
5use std::any::Any;
6use std::cell::Cell;
7// #[cfg(feature = "hash-registers")]
8// use std::collections::hash_map::Entry;
9// #[cfg(not(feature = "hash-registers"))]
10// use std::collections::{btree_map::Entry, BTreeMap};
11cfg_if::cfg_if! {
12    if #[cfg(feature = "hash-registers")] {
13        use std::collections::hash_map::Entry;
14    } else {
15        use std::collections::{btree_map::Entry, BTreeMap};
16    }
17}
18use std::mem;
19use std::sync::Arc;
20
21#[cfg(feature = "hash-registers")]
22use rustc_hash::FxHashMap;
23
24use crate::{TransactionClosureResult, TransactionError, TransactionResult};
25
26#[cfg(feature = "wait-on-retry")]
27use self::control_block::ControlBlock;
28use self::log_var::LogVar;
29use super::result::{StmClosureResult, StmError};
30use super::tvar::{TVar, VarControlBlock};
31
thread_local! {
    /// Per-thread flag that is `true` while a transaction is running on this thread.
    static TRANSACTION_RUNNING: Cell<bool> = const { Cell::new(false) };
}

/// Guard that rejects nested STM calls on the same thread.
///
/// Creating the guard marks the thread as running a transaction; dropping it
/// (also while unwinding from a panic) clears the mark again.
struct TransactionGuard;

impl TransactionGuard {
    /// Mark the current thread as running a transaction.
    ///
    /// # Panics
    ///
    /// Panics with `"STM: Nested Transaction"` if the thread is already marked.
    pub fn new() -> TransactionGuard {
        // Swap the flag to `true` and inspect what was there before.
        let already_running = TRANSACTION_RUNNING.with(|flag| flag.replace(true));
        assert!(!already_running, "STM: Nested Transaction");
        TransactionGuard
    }
}

impl Drop for TransactionGuard {
    /// Mark the current thread as no longer running a transaction.
    fn drop(&mut self) {
        TRANSACTION_RUNNING.with(|flag| flag.set(false));
    }
}
56
/// Decision returned by a control function after the transaction closure reported an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionControl {
    /// Keep going: the transaction loop runs the closure again.
    Retry,
    /// Give up: the transaction loop returns without committing.
    Abort,
}
62
/// Transaction tracks all the read and written variables.
///
/// It is used for checking vars, to ensure atomicity.
pub struct Transaction {
    /// Map of all accessed vars, from the `VarControlBlock` of a var to its `LogVar`.
    /// The `VarControlBlock` is unique because it uses its address for comparing.
    ///
    /// The logs need to be accessed in a fixed order to prevent deadlocks on locking.
    /// Without the `hash-registers` feature the ordered map provides that order directly.
    #[cfg(not(feature = "hash-registers"))]
    vars: BTreeMap<Arc<VarControlBlock>, LogVar>,
    /// With the `hash-registers` feature the map is keyed by the raw control-block
    /// pointer; `commit` sorts the keys before taking any lock.
    #[cfg(feature = "hash-registers")]
    vars: FxHashMap<*const VarControlBlock, LogVar>,
}
76
77impl Transaction {
78    /// Create a new log.
79    ///
80    /// Normally you don't need to call this directly.
81    /// Use `atomically` instead.
82    fn new() -> Transaction {
83        Transaction {
84            #[cfg(not(feature = "hash-registers"))]
85            vars: BTreeMap::new(),
86            #[cfg(feature = "hash-registers")]
87            vars: FxHashMap::default(),
88        }
89    }
90
91    /// Run a function with a transaction.
92    ///
93    /// It is equivalent to `atomically`.
94    pub fn with<T, F>(f: F) -> T
95    where
96        F: Fn(&mut Transaction) -> StmClosureResult<T>,
97    {
98        match Transaction::with_control(|_| TransactionControl::Retry, f) {
99            Some(t) => t,
100            None => unreachable!(),
101        }
102    }
103
104    /// Run a function with a transaction.
105    ///
106    /// `with_control` takes another control function, that
107    /// can steer the control flow and possible terminate early.
108    ///
109    /// `control` can react to counters, timeouts or external inputs.
110    ///
111    /// It allows the user to fall back to another strategy, like a global lock
112    /// in the case of too much contention.
113    ///
114    /// Please not, that the transaction may still infinitely wait for changes when `retry` is
115    /// called and `control` does not abort.
116    /// If you need a timeout, another thread should signal this through a [`TVar`].
117    pub fn with_control<T, F, C>(mut control: C, f: F) -> Option<T>
118    where
119        F: Fn(&mut Transaction) -> StmClosureResult<T>,
120        C: FnMut(StmError) -> TransactionControl,
121    {
122        let _guard = TransactionGuard::new();
123
124        // create a log guard for initializing and cleaning up
125        // the log
126        let mut transaction = Transaction::new();
127
128        // loop until success
129        loop {
130            // run the computation
131            match f(&mut transaction) {
132                // on success exit loop
133                Ok(t) => {
134                    if transaction.commit() {
135                        return Some(t);
136                    }
137                }
138
139                Err(e) => {
140                    // Check if the user wants to abort the transaction.
141                    if let TransactionControl::Abort = control(e) {
142                        return None;
143                    }
144
145                    // on retry wait for changes
146                    #[cfg(feature = "wait-on-retry")]
147                    if let StmError::Retry = e {
148                        transaction.wait_for_change();
149                    }
150                }
151            }
152
153            // clear log before retrying computation
154            transaction.clear();
155        }
156    }
157
158    /// Run a function with a transaction.
159    ///
160    /// The transaction will be retried until:
161    /// - it is validated, or
162    /// - it is explicitly aborted from the function, using the `TODO` function.
163    pub fn with_err<T, F, E>(f: F) -> Result<T, E>
164    where
165        F: Fn(&mut Transaction) -> TransactionClosureResult<T, E>,
166    {
167        let _guard = TransactionGuard::new();
168
169        // create a log guard for initializing and cleaning up
170        // the log
171        let mut transaction = Transaction::new();
172
173        // loop until success
174        loop {
175            // run the computation
176            match f(&mut transaction) {
177                // on success exit loop
178                Ok(t) => {
179                    if transaction.commit() {
180                        return Ok(t);
181                    }
182                }
183                // on error,
184                Err(e) => match e {
185                    // abort and return the error
186                    TransactionError::Abort(err) => return Err(err),
187                    // retry
188                    TransactionError::Stm(_) => {
189                        #[cfg(feature = "wait-on-retry")]
190                        transaction.wait_for_change();
191                    }
192                },
193            }
194
195            // clear log before retrying computation
196            transaction.clear();
197        }
198    }
199
200    /// Run a function with a transaction.
201    ///
202    /// `with_control` takes another control function, that
203    /// can steer the control flow and possible terminate early.
204    ///
205    /// `control` can react to counters, timeouts or external inputs.
206    ///
207    /// It allows the user to fall back to another strategy, like a global lock
208    /// in the case of too much contention.
209    ///
210    /// Please not, that the transaction may still infinitely wait for changes when `retry` is
211    /// called and `control` does not abort.
212    /// If you need a timeout, another thread should signal this through a [`TVar`].
213    pub fn with_control_and_err<T, F, C, E>(mut control: C, f: F) -> TransactionResult<T, E>
214    where
215        F: Fn(&mut Transaction) -> TransactionClosureResult<T, E>,
216        C: FnMut(StmError) -> TransactionControl,
217    {
218        let _guard = TransactionGuard::new();
219
220        // create a log guard for initializing and cleaning up
221        // the log
222        let mut transaction = Transaction::new();
223
224        // loop until success
225        loop {
226            // run the computation
227            match f(&mut transaction) {
228                // on success exit loop
229                Ok(t) => {
230                    if transaction.commit() {
231                        return TransactionResult::Validated(t);
232                    }
233                }
234
235                Err(e) => {
236                    match e {
237                        TransactionError::Abort(err) => {
238                            return TransactionResult::Cancelled(err);
239                        }
240                        TransactionError::Stm(err) => {
241                            // Check if the user wants to abort the transaction.
242                            if let TransactionControl::Abort = control(err) {
243                                return TransactionResult::Abandoned;
244                            }
245
246                            // on retry wait for changes
247                            #[cfg(feature = "wait-on-retry")]
248                            if let StmError::Retry = err {
249                                transaction.wait_for_change();
250                            }
251                        }
252                    }
253                }
254            }
255
256            // clear log before retrying computation
257            transaction.clear();
258        }
259    }
260
261    #[allow(clippy::needless_pass_by_value)]
262    /// Perform a downcast on a var.
263    fn downcast<T: Any + Clone>(var: Arc<dyn Any>) -> T {
264        match var.downcast_ref::<T>() {
265            Some(s) => s.clone(),
266            None => unreachable!("TVar has wrong type"),
267        }
268    }
269
270    /// Read a variable and return the value.
271    ///
272    /// The returned value is not always consistent with the current value of the var,
273    /// but may be an outdated or or not yet commited value.
274    ///
275    /// The used code should be capable of handling inconsistent states
276    /// without running into infinite loops.
277    /// Just the commit of wrong values is prevented by STM.
278    pub fn read<T: Send + Sync + Any + Clone>(&mut self, var: &TVar<T>) -> StmClosureResult<T> {
279        let ctrl = var.control_block().clone();
280        // Check if the same var was written before.
281        #[cfg(not(feature = "hash-registers"))]
282        let key = ctrl;
283        #[cfg(feature = "hash-registers")]
284        let key = Arc::as_ptr(&ctrl);
285        let value = match self.vars.entry(key) {
286            // If the variable has been accessed before, then load that value.
287            #[cfg(feature = "early-conflict-detection")]
288            Entry::Occupied(mut entry) => {
289                let log = entry.get_mut();
290                // if we previously read the var, check for value change
291                if let LogVar::Read(v) = log {
292                    let crt_v = var.read_ref_atomic();
293                    if !Arc::ptr_eq(v, &crt_v) {
294                        return Err(StmError::Failure);
295                    }
296                }
297                log.read()
298            }
299            #[cfg(not(feature = "early-conflict-detection"))]
300            Entry::Occupied(mut entry) => entry.get_mut().read(),
301
302            // Else load the variable statically.
303            Entry::Vacant(entry) => {
304                // Read the value from the var.
305                let value = var.read_ref_atomic();
306
307                // Store in in an entry.
308                entry.insert(LogVar::Read(value.clone()));
309                value
310            }
311        };
312
313        Ok(Transaction::downcast(value))
314    }
315
316    /// Write a variable.
317    ///
318    /// The write is not immediately visible to other threads,
319    /// but atomically commited at the end of the computation.
320    pub fn write<T: Any + Send + Sync + Clone>(
321        &mut self,
322        var: &TVar<T>,
323        value: T,
324    ) -> StmClosureResult<()> {
325        // box the value
326        let boxed = Arc::new(value);
327
328        // new control block
329        let ctrl = var.control_block().clone();
330        // update or create new entry
331        #[cfg(not(feature = "hash-registers"))]
332        let key = ctrl;
333        #[cfg(feature = "hash-registers")]
334        let key = Arc::as_ptr(&ctrl);
335        match self.vars.entry(key) {
336            Entry::Occupied(mut entry) => entry.get_mut().write(boxed),
337            Entry::Vacant(entry) => {
338                entry.insert(LogVar::Write(boxed));
339            }
340        }
341
342        // For now always succeeds, but that may change later.
343        Ok(())
344    }
345
346    /// Combine two calculations. When one blocks with `retry`,
347    /// run the other, but don't commit the changes in the first.
348    ///
349    /// If both block, `Transaction::or` still waits for `TVar`s in both functions.
350    /// Use `Transaction::or` instead of handling errors directly with the `Result::or`.
351    /// The later does not handle all the blocking correctly.
352    pub fn or<T, F1, F2>(&mut self, first: F1, second: F2) -> StmClosureResult<T>
353    where
354        F1: Fn(&mut Transaction) -> StmClosureResult<T>,
355        F2: Fn(&mut Transaction) -> StmClosureResult<T>,
356    {
357        // Create a backup of the log.
358        let mut copy = Transaction {
359            vars: self.vars.clone(),
360        };
361
362        // Run the first computation.
363        let f = first(self);
364
365        match f {
366            // Run other on manual retry call.
367            Err(StmError::Retry) => {
368                // swap, so that self is the current run
369                mem::swap(self, &mut copy);
370
371                // Run other action.
372                let s = second(self);
373
374                // If both called retry then exit.
375                match s {
376                    Err(StmError::Failure) => Err(StmError::Failure),
377                    s => {
378                        self.combine(copy);
379                        s
380                    }
381                }
382            }
383
384            // Return success and failure directly
385            x => x,
386        }
387    }
388
389    /// Combine two logs into a single log, to allow waiting for all reads.
390    fn combine(&mut self, other: Transaction) {
391        // combine reads
392        for (var, value) in other.vars {
393            // only insert new values
394            if let Some(value) = value.obsolete() {
395                self.vars.entry(var).or_insert(value);
396            }
397        }
398    }
399
    /// Remove every entry from the log.
    ///
    /// The transaction loops call this between attempts, so that each run of the
    /// computation starts from an empty log. It should not be used anywhere else.
    fn clear(&mut self) {
        self.vars.clear();
    }
407
408    /// Wait for any variable to change,
409    /// because the change may lead to a new calculation result.
410    #[cfg(feature = "wait-on-retry")]
411    fn wait_for_change(&mut self) {
412        // Create control block for waiting.
413        let ctrl = Arc::new(ControlBlock::new());
414
415        #[allow(clippy::mutable_key_type)]
416        let vars = std::mem::take(&mut self.vars);
417        let mut reads = Vec::with_capacity(self.vars.len());
418
419        let blocking = vars
420            .into_iter()
421            .filter_map(|(a, b)| b.into_read_value().map(|b| (a, b)))
422            // Check for consistency.
423            .all(|(var, value)| {
424                #[cfg(feature = "hash-registers")]
425                let var = unsafe { var.as_ref() }.expect("E: unreachabel");
426                var.wait(&ctrl);
427                let x = {
428                    // Take read lock and read value.
429                    let guard = var.value.read();
430                    Arc::ptr_eq(&value, &guard)
431                };
432                reads.push(var);
433                x
434            });
435
436        // If no var has changed, then block.
437        if blocking {
438            // Propably wait until one var has changed.
439            ctrl.wait();
440        }
441
442        // Let others know that ctrl is dead.
443        // It does not matter, if we set too many
444        // to dead since it may slightly reduce performance
445        // but not break the semantics.
446        for var in &reads {
447            var.set_dead();
448        }
449    }
450
    /// Write the log back to the variables.
    ///
    /// Locks are taken var by var in key order, read entries are validated against the
    /// current value of their var, and only if every check passes are the written
    /// values stored.
    ///
    /// Returns `true` on success and `false` if a read var has changed; in the latter
    /// case nothing is written.
    fn commit(&mut self) -> bool {
        // Use two phase locking for safely writing data back to the vars.

        // First phase: acquire locks.
        // Check for consistency of all the reads and perform
        // an early return if something is not consistent.
        // Returning early drops the vectors below and with them every lock taken so far.

        // Vector of read locks, held until all checks have passed.
        let mut read_vec = Vec::with_capacity(self.vars.len());

        // vector of tuple (value, lock)
        let mut write_vec = Vec::with_capacity(self.vars.len());

        // vector of written variables
        let mut written = Vec::with_capacity(self.vars.len());

        // The hash map has no inherent order, so sort the entries by key to lock
        // the vars in the same order in every transaction.
        #[cfg(feature = "hash-registers")]
        let records = {
            let mut recs: Vec<_> = self.vars.iter().collect();
            recs.sort_by(|(k1, _), (k2, _)| k1.cmp(&k2));
            recs
        };
        // The ordered map already iterates in key order.
        #[cfg(not(feature = "hash-registers"))]
        let records = &self.vars;

        for (var, value) in records {
            // lock the variable and read the value
            // SAFETY: NOTE(review) - this dereferences the raw key stored in the log and
            // relies on the control block still being alive. Nothing visible in this map
            // keeps it alive, so the originating `TVar` must outlive the transaction;
            // confirm against callers.
            #[cfg(feature = "hash-registers")]
            let var = unsafe { var.as_ref() }.expect("E: unreachabel");

            match *value {
                // We need to take a write lock.
                // No consistency check is done for these two kinds of entries.
                LogVar::Write(ref w) | LogVar::ReadObsoleteWrite(_, ref w) => {
                    // take write lock
                    let lock = var.value.write();
                    // add all data to the vector
                    write_vec.push((w, lock));
                    written.push(var);
                }

                // We need to check for consistency and
                // take a write lock.
                LogVar::ReadWrite(ref original, ref w) => {
                    // take write lock
                    let lock = var.value.write();

                    // The var must still hold the value that was read.
                    if !Arc::ptr_eq(&lock, original) {
                        return false;
                    }
                    // add all data to the vector
                    write_vec.push((w, lock));
                    written.push(var);
                }
                // Nothing to do. ReadObsolete is only needed for blocking, not
                // for consistency checks.
                LogVar::ReadObsolete(_) => {}
                // Take read lock and check for consistency.
                LogVar::Read(ref original) => {
                    // Take a read lock.
                    let lock = var.value.read();

                    // The var must still hold the value that was read.
                    if !Arc::ptr_eq(&lock, original) {
                        return false;
                    }

                    read_vec.push(lock);
                }
            }
        }

        // Second phase: write back and release

        // Release the reads first.
        // This allows other threads to continue quickly.
        drop(read_vec);

        for (value, mut lock) in write_vec {
            // Commit value; the write lock is released at the end of each iteration.
            *lock = value.clone();
        }

        #[cfg(feature = "wait-on-retry")]
        for var in written {
            // Unblock all threads waiting for it.
            var.wake_all();
        }

        // Commit succeeded.
        true
    }
545}
546
#[cfg(test)]
mod test {
    use super::*;

    /// A var can be read through a bare log.
    #[test]
    fn read() {
        let mut log = Transaction::new();
        let var = TVar::new(vec![1, 2, 3, 4]);

        // The variable can be read.
        assert_eq!(&*log.read(&var).unwrap(), &[1, 2, 3, 4]);
    }

    /// A write is visible to later reads in the same log, but not outside of it.
    #[test]
    fn write_read() {
        let mut log = Transaction::new();
        let var = TVar::new(vec![1, 2]);

        log.write(&var, vec![1, 2, 3, 4]).unwrap();

        // Consecutive reads get the updated version.
        assert_eq!(log.read(&var).unwrap(), [1, 2, 3, 4]);

        // The original value is still preserved.
        assert_eq!(var.read_atomic(), [1, 2]);
    }

    /// A transaction that touches no var returns the closure's value.
    #[test]
    fn transaction_simple() {
        let x = Transaction::with(|_| Ok(42));
        assert_eq!(x, 42);
    }

    /// A transaction can read a var.
    #[test]
    fn transaction_read() {
        let read = TVar::new(42);

        let x = Transaction::with(|trans| read.read(trans));

        assert_eq!(x, 42);
    }

    /// Run a transaction with a control function, that always aborts.
    /// The transaction still tries to run a single time and should successfully
    /// commit in this test.
    #[test]
    fn transaction_with_control_abort_on_single_run() {
        let read = TVar::new(42);

        let x = Transaction::with_control(|_| TransactionControl::Abort, |tx| read.read(tx));

        assert_eq!(x, Some(42));
    }

    /// Run a transaction with a control function, that always aborts.
    /// The transaction retries infinitely often. The control function will abort this loop.
    #[test]
    fn transaction_with_control_abort_on_retry() {
        let x: Option<i32> =
            Transaction::with_control(|_| TransactionControl::Abort, |_| Err(StmError::Retry));

        assert_eq!(x, None);
    }

    /// A committed write is visible outside of the transaction.
    #[test]
    fn transaction_write() {
        let write = TVar::new(42);

        Transaction::with(|trans| write.write(trans, 0));

        assert_eq!(write.read_atomic(), 0);
    }

    /// A value can be copied from one var to another within one transaction.
    #[test]
    fn transaction_copy() {
        let read = TVar::new(42);
        let write = TVar::new(0);

        Transaction::with(|trans| {
            let r = read.read(trans)?;
            write.write(trans, r)
        });

        assert_eq!(write.read_atomic(), 42);
    }

    /// The same copy, run through `with_control`: the closure never fails, so the
    /// control function must not be consulted and the write must be committed.
    #[test]
    fn transaction_control_stuff() {
        let read = TVar::new(42);
        let write = TVar::new(0);
        let mut control_calls = 0;

        let outcome = Transaction::with_control(
            |_| {
                control_calls += 1;
                TransactionControl::Abort
            },
            |trans| {
                let r = read.read(trans)?;
                write.write(trans, r)
            },
        );

        assert_eq!(outcome, Some(()));
        assert_eq!(control_calls, 0);
        assert_eq!(write.read_atomic(), 42);
    }

    /// Test if nested transactions are correctly detected.
    #[test]
    #[should_panic(expected = "STM: Nested Transaction")]
    fn transaction_nested_fail() {
        Transaction::with(|_| {
            Transaction::with(|_| Ok(42));
            Ok(1)
        });
    }
}