1use core::sync::atomic::{AtomicU64, Ordering};
4
5use axtask::{WaitQueue, current};
6
7pub struct RawMutex {
13 wq: WaitQueue,
14 owner_id: AtomicU64,
15}
16
17impl RawMutex {
18 #[inline(always)]
20 pub const fn new() -> Self {
21 Self {
22 wq: WaitQueue::new(),
23 owner_id: AtomicU64::new(0),
24 }
25 }
26}
27
28unsafe impl lock_api::RawMutex for RawMutex {
29 const INIT: Self = RawMutex::new();
30
31 type GuardMarker = lock_api::GuardSend;
32
33 fn lock(&self) {
34 let current_id = current().id().as_u64();
35 loop {
36 match self.owner_id.compare_exchange_weak(
39 0,
40 current_id,
41 Ordering::Acquire,
42 Ordering::Relaxed,
43 ) {
44 Ok(_) => break,
45 Err(owner_id) => {
46 assert_ne!(
47 owner_id,
48 current_id,
49 "{} tried to acquire mutex it already owns.",
50 current().id_name()
51 );
52 self.wq.wait_until(|| !self.is_locked());
54 }
55 }
56 }
57 }
58
59 fn try_lock(&self) -> bool {
60 let current_id = current().id().as_u64();
61 self.owner_id
64 .compare_exchange(0, current_id, Ordering::Acquire, Ordering::Relaxed)
65 .is_ok()
66 }
67
68 unsafe fn unlock(&self) {
69 let owner_id = self.owner_id.swap(0, Ordering::Release);
70 assert_eq!(
71 owner_id,
72 current().id().as_u64(),
73 "{} tried to release mutex it doesn't own",
74 current().id_name()
75 );
76 self.wq.notify_one(true);
77 }
78
79 fn is_locked(&self) -> bool {
80 self.owner_id.load(Ordering::Relaxed) != 0
81 }
82}
83
84pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
86pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
88
#[cfg(test)]
mod tests {
    use crate::Mutex;
    use axtask as thread;
    use std::sync::Once;

    /// Ensures the scheduler is initialised only once across tests.
    static INIT: Once = Once::new();

    /// Yields the current task roughly one time in three to shuffle the
    /// interleaving of the competing tasks.
    fn may_interrupt() {
        let roll = rand::random::<u32>();
        if roll % 3 == 0 {
            thread::yield_now();
        }
    }

    /// Many tasks increment a shared counter under the mutex; the final value
    /// must equal the sum of every increment.
    #[test]
    fn lots_and_lots() {
        INIT.call_once(thread::init_scheduler);

        const NUM_TASKS: u32 = 10;
        const NUM_ITERS: u32 = 10_000;
        // Each pair of spawned tasks contributes 1 + 2 per iteration.
        const EXPECTED: u32 = NUM_ITERS * NUM_TASKS * 3;
        static M: Mutex<u32> = Mutex::new(0);

        fn inc(delta: u32) {
            (0..NUM_ITERS).for_each(|_| {
                let mut guard = M.lock();
                *guard += delta;
                may_interrupt();
                drop(guard);
                may_interrupt();
            });
        }

        (0..NUM_TASKS).for_each(|_| {
            thread::spawn(|| inc(1));
            thread::spawn(|| inc(2));
        });

        println!("spawn OK");
        // Poll the counter until every spawned task has finished its work.
        loop {
            let guard = M.lock();
            if *guard == EXPECTED {
                break;
            }
            may_interrupt();
            drop(guard);
            may_interrupt();
        }

        assert_eq!(*M.lock(), EXPECTED);
        println!("Mutex test OK");
    }
}