axtask/task.rs

use alloc::{boxed::Box, string::String, sync::Arc};
use core::ops::Deref;
use core::sync::atomic::{AtomicBool, AtomicI32, AtomicU8, AtomicU64, Ordering};
use core::{alloc::Layout, cell::UnsafeCell, fmt, ptr::NonNull};

#[cfg(feature = "preempt")]
use core::sync::atomic::AtomicUsize;

use kspin::SpinNoIrq;
use memory_addr::{VirtAddr, align_up_4k};

use axhal::arch::TaskContext;
#[cfg(feature = "tls")]
use axhal::tls::TlsArea;

use crate::task_ext::AxTaskExt;
use crate::{AxCpuMask, AxTask, AxTaskRef, WaitQueue};

/// A unique identifier for a thread.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct TaskId(u64);

/// The possible states of a task.
#[repr(u8)]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum TaskState {
    /// Task is running on some CPU.
    Running = 1,
    /// Task is ready to run on some scheduler's ready queue.
    Ready = 2,
    /// Task is blocked (in a wait queue or the timer list) and has finished its
    /// scheduling process; it can be safely woken up by `notify()` on any run queue.
    Blocked = 3,
    /// Task has exited and is waiting to be dropped.
    Exited = 4,
}

/// The inner task structure.
pub struct TaskInner {
    id: TaskId,
    name: UnsafeCell<String>,
    is_idle: bool,
    is_init: bool,

    entry: Option<*mut dyn FnOnce()>,
    state: AtomicU8,

    /// CPU affinity mask.
    cpumask: SpinNoIrq<AxCpuMask>,

    /// Mark whether the task is in the wait queue.
    in_wait_queue: AtomicBool,

    /// Used to indicate whether the task is running on a CPU.
    #[cfg(feature = "smp")]
    on_cpu: AtomicBool,

    /// A ticket ID used to identify the timer event.
    /// Set by `set_timer_ticket()` when a timer event is created in `set_alarm_wakeup()`,
    /// and expired by setting it to zero in `timer_ticket_expired()`, which is called by `cancel_events()`.
    #[cfg(feature = "irq")]
    timer_ticket_id: AtomicU64,

    #[cfg(feature = "preempt")]
    need_resched: AtomicBool,
    #[cfg(feature = "preempt")]
    preempt_disable_count: AtomicUsize,

    exit_code: AtomicI32,
    wait_for_exit: WaitQueue,

    kstack: Option<TaskStack>,
    ctx: UnsafeCell<TaskContext>,
    task_ext: AxTaskExt,

    #[cfg(feature = "tls")]
    tls: TlsArea,
}

impl TaskId {
    fn new() -> Self {
        static ID_COUNTER: AtomicU64 = AtomicU64::new(1);
        Self(ID_COUNTER.fetch_add(1, Ordering::Relaxed))
    }

    /// Convert the task ID to a `u64`.
    pub const fn as_u64(&self) -> u64 {
        self.0
    }
}

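// The raw `u8` stored in `TaskInner::state` is converted back through this impl;
// the match arms must stay in sync with the discriminants declared on `TaskState`.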
impl From<u8> for TaskState {
    #[inline]
    fn from(state: u8) -> Self {
        match state {
            1 => Self::Running,
            2 => Self::Ready,
            3 => Self::Blocked,
            4 => Self::Exited,
            _ => unreachable!(),
        }
    }
}

unsafe impl Send for TaskInner {}
unsafe impl Sync for TaskInner {}

impl TaskInner {
    /// Create a new task with the given entry function and stack size.
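    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled; in practice the crate's spawn APIs build
    /// the task and hand it to a run queue for you):
    ///
    /// ```ignore
    /// let task = TaskInner::new(|| debug!("hello from a new task"), "worker".into(), 0x4000);
    /// assert_eq!(task.name(), "worker");
    /// ```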
    pub fn new<F>(entry: F, name: String, stack_size: usize) -> Self
    where
        F: FnOnce() + Send + 'static,
    {
        let mut t = Self::new_common(TaskId::new(), name);
        debug!("new task: {}", t.id_name());
        let kstack = TaskStack::alloc(align_up_4k(stack_size));

        #[cfg(feature = "tls")]
        let tls = VirtAddr::from(t.tls.tls_ptr() as usize);
        #[cfg(not(feature = "tls"))]
        let tls = VirtAddr::from(0);

        t.entry = Some(Box::into_raw(Box::new(entry)));
        t.ctx_mut().init(task_entry as usize, kstack.top(), tls);
        t.kstack = Some(kstack);
        if t.name() == "idle" {
            t.is_idle = true;
        }
        t
    }

    /// Gets the ID of the task.
    pub const fn id(&self) -> TaskId {
        self.id
    }

    /// Gets the name of the task.
    pub fn name(&self) -> &str {
        unsafe { (*self.name.get()).as_str() }
    }

    /// Set the name of the task.
    pub fn set_name(&self, name: &str) {
        unsafe {
            *self.name.get() = String::from(name);
        }
    }

    /// Get a combined string of the task ID and name.
    pub fn id_name(&self) -> alloc::string::String {
        alloc::format!("Task({}, {:?})", self.id.as_u64(), self.name())
    }

    /// Wait for the task to exit, and return the exit code.
    ///
    /// It will return immediately if the task has already exited (but not dropped).
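    ///
    /// # Example
    ///
    /// A usage sketch (not compiled; `task_ref` stands for a spawned task handle
    /// such as `AxTaskRef`, which dereferences to `TaskInner`):
    ///
    /// ```ignore
    /// let exit_code = task_ref.join().unwrap_or(-1);
    /// ```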
    pub fn join(&self) -> Option<i32> {
        self.wait_for_exit
            .wait_until(|| self.state() == TaskState::Exited);
        Some(self.exit_code.load(Ordering::Acquire))
    }

    /// Returns the pointer to the user-defined task extended data.
    ///
    /// # Safety
    ///
    /// The caller should not access the pointer directly; use [`TaskExtRef::task_ext`]
    /// or [`TaskExtMut::task_ext_mut`] instead.
    ///
    /// [`TaskExtRef::task_ext`]: crate::task_ext::TaskExtRef::task_ext
    /// [`TaskExtMut::task_ext_mut`]: crate::task_ext::TaskExtMut::task_ext_mut
    pub unsafe fn task_ext_ptr(&self) -> *mut u8 {
        self.task_ext.as_ptr()
    }

    /// Initialize the user-defined task extended data.
    ///
    /// Returns a reference to the task extended data if it has not been
    /// initialized yet (empty), otherwise returns [`None`].
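    ///
    /// # Example
    ///
    /// A sketch (not compiled; assumes `MyTaskExt` has been registered through
    /// the crate's task-extension mechanism):
    ///
    /// ```ignore
    /// struct MyTaskExt { syscall_count: usize }
    ///
    /// let first = task.init_task_ext(MyTaskExt { syscall_count: 0 });
    /// assert!(first.is_some());
    /// // A second call leaves the data untouched and returns `None`.
    /// assert!(task.init_task_ext(MyTaskExt { syscall_count: 1 }).is_none());
    /// ```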
    pub fn init_task_ext<T: Sized>(&mut self, data: T) -> Option<&T> {
        if self.task_ext.is_empty() {
            self.task_ext.write(data).map(|data| &*data)
        } else {
            None
        }
    }

    /// Returns a mutable reference to the task context.
    #[inline]
    pub const fn ctx_mut(&mut self) -> &mut TaskContext {
        self.ctx.get_mut()
    }

    /// Returns the top address of the kernel stack.
    #[inline]
    pub const fn kernel_stack_top(&self) -> Option<VirtAddr> {
        match &self.kstack {
            Some(s) => Some(s.top()),
            None => None,
        }
    }

    /// Gets the cpu affinity mask of the task.
    ///
    /// Returns the cpu affinity mask of the task in type [`AxCpuMask`].
    #[inline]
    pub fn cpumask(&self) -> AxCpuMask {
        *self.cpumask.lock()
    }

    /// Sets the cpu affinity mask of the task.
    ///
    /// # Arguments
    /// * `cpumask` - The cpu affinity mask to be set in type [`AxCpuMask`].
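    ///
    /// # Example
    ///
    /// A sketch (not compiled; `one_shot` stands in for whatever single-CPU
    /// mask constructor the `AxCpuMask` type provides):
    ///
    /// ```ignore
    /// // Pin the task to CPU 0 only.
    /// task.set_cpumask(AxCpuMask::one_shot(0));
    /// ```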
    #[inline]
    pub fn set_cpumask(&self, cpumask: AxCpuMask) {
        *self.cpumask.lock() = cpumask
    }

    /// Read the top address of the kernel stack for the task.
    #[inline]
    pub fn get_kernel_stack_top(&self) -> Option<usize> {
        if let Some(kstack) = &self.kstack {
            return Some(kstack.top().as_usize());
        }
        None
    }

    /// Returns the exit code of the task.
    pub fn exit_code(&self) -> i32 {
        self.exit_code.load(Ordering::Acquire)
    }
}

// private methods
impl TaskInner {
    fn new_common(id: TaskId, name: String) -> Self {
        Self {
            id,
            name: UnsafeCell::new(name),
            is_idle: false,
            is_init: false,
            entry: None,
            state: AtomicU8::new(TaskState::Ready as u8),
            // By default, the task is allowed to run on all CPUs.
            cpumask: SpinNoIrq::new(AxCpuMask::full()),
            in_wait_queue: AtomicBool::new(false),
            #[cfg(feature = "irq")]
            timer_ticket_id: AtomicU64::new(0),
            #[cfg(feature = "smp")]
            on_cpu: AtomicBool::new(false),
            #[cfg(feature = "preempt")]
            need_resched: AtomicBool::new(false),
            #[cfg(feature = "preempt")]
            preempt_disable_count: AtomicUsize::new(0),
            exit_code: AtomicI32::new(0),
            wait_for_exit: WaitQueue::new(),
            kstack: None,
            ctx: UnsafeCell::new(TaskContext::new()),
            task_ext: AxTaskExt::empty(),
            #[cfg(feature = "tls")]
            tls: TlsArea::alloc(),
        }
    }

    /// Creates an "init task" using the current CPU states, to be used as the
    /// current task.
    ///
    /// As it is the current task, no other task can switch to it until it
    /// switches out.
    ///
    /// There is no need to set the `entry`, `kstack` or `tls` fields, as
    /// they will be filled automatically when the task is switched out.
    pub(crate) fn new_init(name: String) -> Self {
        let mut t = Self::new_common(TaskId::new(), name);
        t.is_init = true;
        #[cfg(feature = "smp")]
        t.set_on_cpu(true);
        if t.name() == "idle" {
            t.is_idle = true;
        }
        t
    }

    pub(crate) fn into_arc(self) -> AxTaskRef {
        Arc::new(AxTask::new(self))
    }

    /// Returns the task's current state.
    #[inline]
    pub fn state(&self) -> TaskState {
        self.state.load(Ordering::Acquire).into()
    }

    /// Set the task's state.
    #[inline]
    pub fn set_state(&self, state: TaskState) {
        self.state.store(state as u8, Ordering::Release)
    }

    /// Transitions the task state from `current_state` to `new_state`.
    ///
    /// Returns `true` if the current state is `current_state` and the state is
    /// successfully changed to `new_state`, otherwise returns `false`.
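    ///
    /// A usage sketch of the typical wakeup path (not compiled):
    ///
    /// ```ignore
    /// // Only the caller that wins the `Blocked` -> `Ready` transition may push
    /// // the task onto a run queue, so the task is enqueued exactly once.
    /// if task.transition_state(TaskState::Blocked, TaskState::Ready) {
    ///     // enqueue `task` on a run queue here
    /// }
    /// ```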
    #[inline]
    pub(crate) fn transition_state(&self, current_state: TaskState, new_state: TaskState) -> bool {
        self.state
            .compare_exchange(
                current_state as u8,
                new_state as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok()
    }

    #[inline]
    pub(crate) fn is_running(&self) -> bool {
        matches!(self.state(), TaskState::Running)
    }

    #[inline]
    pub(crate) fn is_ready(&self) -> bool {
        matches!(self.state(), TaskState::Ready)
    }

    #[inline]
    pub(crate) const fn is_init(&self) -> bool {
        self.is_init
    }

    #[inline]
    pub(crate) const fn is_idle(&self) -> bool {
        self.is_idle
    }

    #[inline]
    pub(crate) fn in_wait_queue(&self) -> bool {
        self.in_wait_queue.load(Ordering::Acquire)
    }

    #[inline]
    pub(crate) fn set_in_wait_queue(&self, in_wait_queue: bool) {
        self.in_wait_queue.store(in_wait_queue, Ordering::Release);
    }

    /// Returns the task's current timer ticket ID.
    #[inline]
    #[cfg(feature = "irq")]
    pub(crate) fn timer_ticket(&self) -> u64 {
        self.timer_ticket_id.load(Ordering::Acquire)
    }

    /// Set the timer ticket ID.
    #[inline]
    #[cfg(feature = "irq")]
    pub(crate) fn set_timer_ticket(&self, timer_ticket_id: u64) {
        // The ticket ID must not be set to 0,
        // because 0 is used to indicate that the timer event has expired.
        assert!(timer_ticket_id != 0);
        self.timer_ticket_id
            .store(timer_ticket_id, Ordering::Release);
    }

    /// Expire the timer ticket ID by setting it to 0,
    /// which indicates that the timer event has been triggered or has expired.
    #[inline]
    #[cfg(feature = "irq")]
    pub(crate) fn timer_ticket_expired(&self) {
        self.timer_ticket_id.store(0, Ordering::Release);
    }
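
    // Timer-ticket lifecycle sketch (from the notes above; the code that consumes
    // the ticket lives in the timer/run-queue modules, not in this file):
    //
    //   set_alarm_wakeup()  -> set_timer_ticket(non_zero_id)   // arm a timer event
    //   cancel_events()     -> timer_ticket_expired()          // ticket becomes 0
    //   a pending timer event whose ticket no longer matches `timer_ticket()`
    //   can thus be recognized as stale and ignored.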

    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn set_preempt_pending(&self, pending: bool) {
        self.need_resched.store(pending, Ordering::Release)
    }

    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn can_preempt(&self, current_disable_count: usize) -> bool {
        self.preempt_disable_count.load(Ordering::Acquire) == current_disable_count
    }

    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn disable_preempt(&self) {
        self.preempt_disable_count.fetch_add(1, Ordering::Relaxed);
    }

    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn enable_preempt(&self, resched: bool) {
        if self.preempt_disable_count.fetch_sub(1, Ordering::Relaxed) == 1 && resched {
            // If the current task has a pending preemption, do the rescheduling.
            Self::current_check_preempt_pending();
        }
    }
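
    // Pairing sketch (these hooks are normally driven by `kernel_guard`-style
    // guards rather than called directly):
    //
    //   task.disable_preempt();      // enter a preemption-free region
    //   /* ... critical section ... */
    //   task.enable_preempt(true);   // leaving the outermost region may reschedule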

    #[cfg(feature = "preempt")]
    fn current_check_preempt_pending() {
        use kernel_guard::NoPreemptIrqSave;
        let curr = crate::current();
        if curr.need_resched.load(Ordering::Acquire) && curr.can_preempt(0) {
            // Note: if we want to print log messages during `preempt_resched`, we have to
            // disable preemption here, because axlog itself may cause preemption.
            let mut rq = crate::current_run_queue::<NoPreemptIrqSave>();
            if curr.need_resched.load(Ordering::Acquire) {
                rq.preempt_resched()
            }
        }
    }

    /// Notify all tasks that join on this task.
    pub(crate) fn notify_exit(&self, exit_code: i32) {
        self.exit_code.store(exit_code, Ordering::Release);
        self.wait_for_exit.notify_all(false);
    }

    #[inline]
    pub(crate) const unsafe fn ctx_mut_ptr(&self) -> *mut TaskContext {
        self.ctx.get()
    }

    /// Returns whether the task is running on a CPU.
    ///
    /// It is used to protect the task from being moved to a different run queue
    /// while it has not yet finished its scheduling process.
    /// The `on_cpu` field is set to `true` when the task is preparing to run on a CPU,
    /// and set to `false` when the task has finished its scheduling process in `clear_prev_task_on_cpu()`.
    #[cfg(feature = "smp")]
    #[inline]
    pub(crate) fn on_cpu(&self) -> bool {
        self.on_cpu.load(Ordering::Acquire)
    }

    /// Sets whether the task is running on a CPU.
    #[cfg(feature = "smp")]
    #[inline]
    pub(crate) fn set_on_cpu(&self, on_cpu: bool) {
        self.on_cpu.store(on_cpu, Ordering::Release)
    }
}

impl fmt::Debug for TaskInner {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("TaskInner")
            .field("id", &self.id)
            .field("name", &self.name)
            .field("state", &self.state())
            .finish()
    }
}

impl Drop for TaskInner {
    fn drop(&mut self) {
        debug!("task drop: {}", self.id_name());
    }
}

struct TaskStack {
    ptr: NonNull<u8>,
    layout: Layout,
}

impl TaskStack {
    pub fn alloc(size: usize) -> Self {
        let layout = Layout::from_size_align(size, 16).unwrap();
        Self {
            ptr: NonNull::new(unsafe { alloc::alloc::alloc(layout) }).unwrap(),
            layout,
        }
    }

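    /// Returns the top of the stack (one byte past the end of the allocation),
    /// which is used as the initial stack pointer of a new task.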
    pub const fn top(&self) -> VirtAddr {
        unsafe { core::mem::transmute(self.ptr.as_ptr().add(self.layout.size())) }
    }
}

impl Drop for TaskStack {
    fn drop(&mut self) {
        unsafe { alloc::alloc::dealloc(self.ptr.as_ptr(), self.layout) }
    }
}

use core::mem::ManuallyDrop;

/// A wrapper of [`AxTaskRef`] as the current task.
///
/// It won't change the reference count of the task when created or dropped.
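///
/// # Example
///
/// A usage sketch (not compiled; `crate::current()` is the usual way to obtain
/// a `CurrentTask`, which dereferences to [`TaskInner`]):
///
/// ```ignore
/// let curr = crate::current();
/// debug!("running {}", curr.id_name());
/// ```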
pub struct CurrentTask(ManuallyDrop<AxTaskRef>);

impl CurrentTask {
    pub(crate) fn try_get() -> Option<Self> {
        let ptr: *const super::AxTask = axhal::cpu::current_task_ptr();
        if !ptr.is_null() {
            Some(Self(unsafe { ManuallyDrop::new(AxTaskRef::from_raw(ptr)) }))
        } else {
            None
        }
    }

    pub(crate) fn get() -> Self {
        Self::try_get().expect("current task is uninitialized")
    }

    /// Converts [`CurrentTask`] to [`AxTaskRef`].
    pub fn as_task_ref(&self) -> &AxTaskRef {
        &self.0
    }

    pub(crate) fn clone(&self) -> AxTaskRef {
        self.0.deref().clone()
    }

    pub(crate) fn ptr_eq(&self, other: &AxTaskRef) -> bool {
        Arc::ptr_eq(&self.0, other)
    }

    pub(crate) unsafe fn init_current(init_task: AxTaskRef) {
        assert!(init_task.is_init());
        #[cfg(feature = "tls")]
        unsafe {
            axhal::arch::write_thread_pointer(init_task.tls.tls_ptr() as usize);
        }
        let ptr = Arc::into_raw(init_task);
        unsafe {
            axhal::cpu::set_current_task_ptr(ptr);
        }
    }

    pub(crate) unsafe fn set_current(prev: Self, next: AxTaskRef) {
        let Self(arc) = prev;
        ManuallyDrop::into_inner(arc); // Drop the `Arc` to decrease the previous task's reference count.
        let ptr = Arc::into_raw(next);
        unsafe {
            axhal::cpu::set_current_task_ptr(ptr);
        }
    }
}

impl Deref for CurrentTask {
    type Target = TaskInner;
    fn deref(&self) -> &Self::Target {
        self.0.deref()
    }
}

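/// The common entry point of every task created by [`TaskInner::new`]: its
/// address is written into the initial [`TaskContext`], so the first context
/// switch into a new task begins executing here.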
extern "C" fn task_entry() -> ! {
    #[cfg(feature = "smp")]
    unsafe {
        // Clear the prev task on CPU before running the task entry function.
        crate::run_queue::clear_prev_task_on_cpu();
    }
    // Enable irq (if feature "irq" is enabled) before running the task entry function.
    #[cfg(feature = "irq")]
    axhal::arch::enable_irqs();
    let task = crate::current();
    if let Some(entry) = task.entry {
        unsafe { Box::from_raw(entry)() };
    }
    crate::exit(0);
}