axsync/
mutex.rs

1//! A naïve sleeping mutex.
2
3use core::sync::atomic::{AtomicU64, Ordering};
4
5use axtask::{WaitQueue, current};
6
7/// A [`lock_api::RawMutex`] implementation.
8///
9/// When the mutex is locked, the current task will block and be put into the
10/// wait queue. When the mutex is unlocked, all tasks waiting on the queue
11/// will be woken up.
12pub struct RawMutex {
13    wq: WaitQueue,
14    owner_id: AtomicU64,
15}
16
17impl RawMutex {
18    /// Creates a [`RawMutex`].
19    #[inline(always)]
20    pub const fn new() -> Self {
21        Self {
22            wq: WaitQueue::new(),
23            owner_id: AtomicU64::new(0),
24        }
25    }
26}
27
28unsafe impl lock_api::RawMutex for RawMutex {
29    const INIT: Self = RawMutex::new();
30
31    type GuardMarker = lock_api::GuardSend;
32
33    fn lock(&self) {
34        let current_id = current().id().as_u64();
35        loop {
36            // Can fail to lock even if the spinlock is not locked. May be more efficient than `try_lock`
37            // when called in a loop.
38            match self.owner_id.compare_exchange_weak(
39                0,
40                current_id,
41                Ordering::Acquire,
42                Ordering::Relaxed,
43            ) {
44                Ok(_) => break,
45                Err(owner_id) => {
46                    assert_ne!(
47                        owner_id,
48                        current_id,
49                        "{} tried to acquire mutex it already owns.",
50                        current().id_name()
51                    );
52                    // Wait until the lock looks unlocked before retrying
53                    self.wq.wait_until(|| !self.is_locked());
54                }
55            }
56        }
57    }
58
59    fn try_lock(&self) -> bool {
60        let current_id = current().id().as_u64();
61        // The reason for using a strong compare_exchange is explained here:
62        // https://github.com/Amanieu/parking_lot/pull/207#issuecomment-575869107
63        self.owner_id
64            .compare_exchange(0, current_id, Ordering::Acquire, Ordering::Relaxed)
65            .is_ok()
66    }
67
68    unsafe fn unlock(&self) {
69        let owner_id = self.owner_id.swap(0, Ordering::Release);
70        assert_eq!(
71            owner_id,
72            current().id().as_u64(),
73            "{} tried to release mutex it doesn't own",
74            current().id_name()
75        );
76        self.wq.notify_one(true);
77    }
78
79    fn is_locked(&self) -> bool {
80        self.owner_id.load(Ordering::Relaxed) != 0
81    }
82}
83
84/// An alias of [`lock_api::Mutex`].
85pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
86/// An alias of [`lock_api::MutexGuard`].
87pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
88
#[cfg(test)]
mod tests {
    use crate::Mutex;
    use axtask as thread;
    use std::sync::Once;

    /// Ensures the scheduler is initialised only once across tests.
    static INIT: Once = Once::new();

    /// Simulates an interrupt by yielding to the scheduler about one time in three.
    fn may_interrupt() {
        let should_yield = rand::random::<u32>() % 3 == 0;
        if should_yield {
            thread::yield_now();
        }
    }

    /// Stress test: many tasks bump a shared counter under the mutex while
    /// yielding at random points, and the main task polls for the final sum.
    #[test]
    fn lots_and_lots() {
        INIT.call_once(thread::init_scheduler);

        const NUM_TASKS: u32 = 10;
        const NUM_ITERS: u32 = 10_000;
        // Every round spawns one `+1` task and one `+2` task.
        const EXPECTED: u32 = NUM_ITERS * NUM_TASKS * 3;
        static M: Mutex<u32> = Mutex::new(0);

        /// Adds `delta` to the counter `NUM_ITERS` times, possibly yielding both
        /// while holding the lock and right after releasing it.
        fn inc(delta: u32) {
            (0..NUM_ITERS).for_each(|_| {
                let mut counter = M.lock();
                *counter += delta;
                may_interrupt();
                drop(counter);
                may_interrupt();
            });
        }

        for _ in 0..NUM_TASKS {
            for delta in [1, 2] {
                thread::spawn(move || inc(delta));
            }
        }

        println!("spawn OK");
        // Poll until every spawned task has finished its increments.
        loop {
            let counter = M.lock();
            if *counter == EXPECTED {
                break;
            }
            may_interrupt();
            drop(counter);
            may_interrupt();
        }

        assert_eq!(*M.lock(), EXPECTED);
        println!("Mutex test OK");
    }
}