1use core::{arch::naked_asm, fmt};
2use memory_addr::VirtAddr;
/// Register state captured on the kernel stack when a trap is taken, and
/// restored when returning from it.
///
/// The field order is part of the ABI: it must match the push/pop sequence of
/// the trap entry/exit assembly. For example, `UspaceContext::enter_uspace`
/// pops `rax`..`r15` in declaration order, skips the next four words and then
/// hands the last five fields to `iretq`.
#[allow(missing_docs)]
#[repr(C)]
#[derive(Debug, Default, Clone, Copy)]
pub struct TrapFrame {
    // General-purpose registers, in the order the exit path pops them.
    pub rax: u64,
    pub rcx: u64,
    pub rdx: u64,
    pub rbx: u64,
    pub rbp: u64,
    pub rsi: u64,
    pub rdi: u64,
    pub r8: u64,
    pub r9: u64,
    pub r10: u64,
    pub r11: u64,
    pub r12: u64,
    pub r13: u64,
    pub r14: u64,
    pub r15: u64,

    /// Saved FS base, exposed as the thread pointer via `tls` / `set_tls`.
    pub fs_base: u64,
    /// Padding word. With it the frame is 24 x 8 = 192 bytes, a multiple of 16
    /// (presumably to keep the kernel stack 16-byte aligned; confirm against
    /// the entry stub).
    pub __pad: u64,

    /// Trap vector number (NOTE(review): presumably pushed by the entry stub; confirm).
    pub vector: u64,
    /// Error code of the trap (NOTE(review): presumably a placeholder for vectors
    /// where the CPU pushes none; confirm against the entry stub).
    pub error_code: u64,

    // The remaining five fields form the frame consumed by `iretq`.
    /// Instruction pointer to resume at.
    pub rip: u64,
    /// Code segment selector; its low two bits are checked by `is_user`.
    pub cs: u64,
    /// Saved RFLAGS.
    pub rflags: u64,
    /// Stack pointer to resume with.
    pub rsp: u64,
    /// Stack segment selector.
    pub ss: u64,
}
39
40impl TrapFrame {
41    pub const fn arg0(&self) -> usize {
43        self.rdi as _
44    }
45
46    pub const fn set_arg0(&mut self, rdi: usize) {
48        self.rdi = rdi as _;
49    }
50
51    pub const fn arg1(&self) -> usize {
53        self.rsi as _
54    }
55
56    pub const fn set_arg1(&mut self, rsi: usize) {
58        self.rsi = rsi as _;
59    }
60
61    pub const fn arg2(&self) -> usize {
63        self.rdx as _
64    }
65
66    pub const fn set_arg2(&mut self, rdx: usize) {
68        self.rdx = rdx as _;
69    }
70
71    pub const fn arg3(&self) -> usize {
73        self.r10 as _
74    }
75
76    pub const fn set_arg3(&mut self, r10: usize) {
78        self.r10 = r10 as _;
79    }
80
81    pub const fn arg4(&self) -> usize {
83        self.r8 as _
84    }
85
86    pub const fn set_arg4(&mut self, r8: usize) {
88        self.r8 = r8 as _;
89    }
90
91    pub const fn arg5(&self) -> usize {
93        self.r9 as _
94    }
95
96    pub const fn set_arg5(&mut self, r9: usize) {
98        self.r9 = r9 as _;
99    }
100
101    pub const fn is_user(&self) -> bool {
103        self.cs & 0b11 == 3
104    }
105
106    pub const fn ip(&self) -> usize {
108        self.rip as _
109    }
110
111    pub const fn set_ip(&mut self, rip: usize) {
113        self.rip = rip as _;
114    }
115
116    pub const fn sp(&self) -> usize {
118        self.rsp as _
119    }
120
121    pub const fn set_sp(&mut self, rsp: usize) {
123        self.rsp = rsp as _;
124    }
125
126    pub const fn retval(&self) -> usize {
128        self.rax as _
129    }
130
131    pub const fn set_retval(&mut self, rax: usize) {
133        self.rax = rax as _;
134    }
135
136    pub fn push_ra(&mut self, addr: usize) {
142        self.rsp -= 8;
143        unsafe {
144            core::ptr::write(self.rsp as *mut usize, addr);
145        }
146    }
147
148    pub const fn tls(&self) -> usize {
150        self.fs_base as _
151    }
152
153    pub const fn set_tls(&mut self, tls_area: usize) {
155        self.fs_base = tls_area as _;
156    }
157}
158
/// A user-space register context, wrapping a [`TrapFrame`].
///
/// It is entered with `UspaceContext::enter_uspace`, which restores the frame's
/// registers and `iretq`s into user mode.
#[cfg(feature = "uspace")]
pub struct UspaceContext(TrapFrame);
162
#[cfg(feature = "uspace")]
impl UspaceContext {
    /// Creates a context whose saved registers are all zero.
    pub const fn empty() -> Self {
        // SAFETY: `UspaceContext` wraps a `TrapFrame`, which contains only `u64`
        // fields, so the all-zero bit pattern is a valid value.
        unsafe { core::mem::MaybeUninit::zeroed().assume_init() }
    }

    /// Creates a context that starts executing user code at `entry` when entered.
    ///
    /// The stack pointer is `ustack_top` and `arg0` is placed in `rdi`. `cs` and
    /// `ss` are set to the user code/data selectors from `GdtStruct`.
    ///
    /// With the `irq` feature, `rflags` has the interrupt flag set. Without it,
    /// `rflags` is 0. Every other register is zero (`TrapFrame::default()`).
    pub fn new(entry: usize, ustack_top: VirtAddr, arg0: usize) -> Self {
        use crate::arch::GdtStruct;
        // Only the `irq`-gated `rflags` initializer uses `RFlags`. Gating the
        // import avoids an unused-import warning when `irq` is disabled.
        #[cfg(feature = "irq")]
        use x86_64::registers::rflags::RFlags;
        Self(TrapFrame {
            rdi: arg0 as _,
            rip: entry as _,
            cs: GdtStruct::UCODE64_SELECTOR.0 as _,
            // The `cfg` applies to the `rflags` field only; `rsp` is always set.
            #[cfg(feature = "irq")]
            rflags: RFlags::INTERRUPT_FLAG.bits(),
            rsp: ustack_top.as_usize() as _,
            ss: GdtStruct::UDATA_SELECTOR.0 as _,
            ..Default::default()
        })
    }

    /// Creates a context from an existing trap frame.
    ///
    /// Every register is copied, but `cs`/`ss` are forced to the user
    /// code/data selectors.
    pub const fn from(tf: &TrapFrame) -> Self {
        use crate::arch::GdtStruct;
        let mut tf = *tf;
        tf.cs = GdtStruct::UCODE64_SELECTOR.0 as _;
        tf.ss = GdtStruct::UDATA_SELECTOR.0 as _;
        Self(tf)
    }

    /// Enters user space with the registers in this context; never returns.
    ///
    /// The sequence is:
    /// 1. Disable IRQs.
    /// 2. Check that the TSS `rsp0` equals `kstack_top`, so the next trap from
    ///    user space lands on that kernel stack.
    /// 3. Install the user FS base.
    /// 4. Point `rsp` at the frame and pop all general-purpose registers from it.
    /// 5. Skip `fs_base`, `__pad`, `vector` and `error_code`.
    /// 6. Execute `swapgs` and `iretq`.
    ///
    /// # Panics
    ///
    /// Panics if the TSS `rsp0` is not `kstack_top`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the following is inferred from the code; confirm.
    /// - The context must describe a valid user-mode state (mapped `rip`/`rsp`,
    ///   user selectors).
    /// - GS must currently hold the kernel base, so that `swapgs` installs the
    ///   user one.
    pub unsafe fn enter_uspace(&self, kstack_top: VirtAddr) -> ! {
        super::disable_irqs();
        assert_eq!(super::tss_get_rsp0(), kstack_top);
        super::tls::switch_to_user_fs_base(&self.0);
        unsafe {
            core::arch::asm!("
                mov     rsp, {tf}
                pop     rax
                pop     rcx
                pop     rdx
                pop     rbx
                pop     rbp
                pop     rsi
                pop     rdi
                pop     r8
                pop     r9
                pop     r10
                pop     r11
                pop     r12
                pop     r13
                pop     r14
                pop     r15
                add     rsp, 32     // skip fs_base, __pad, vector, error_code
                swapgs
                iretq",
                tf = in(reg) &self.0,
                options(noreturn),
            )
        }
    }
}
240
#[cfg(feature = "uspace")]
impl core::ops::Deref for UspaceContext {
    type Target = TrapFrame;

    /// Gives read access to the wrapped [`TrapFrame`].
    fn deref(&self) -> &TrapFrame {
        let Self(frame) = self;
        frame
    }
}
249
#[cfg(feature = "uspace")]
impl core::ops::DerefMut for UspaceContext {
    /// Gives write access to the wrapped [`TrapFrame`].
    fn deref_mut(&mut self) -> &mut TrapFrame {
        let Self(frame) = self;
        frame
    }
}
256
/// Stack image saved and restored by `context_switch`.
///
/// It holds the callee-saved registers followed by the return address, from
/// the lowest address upward. The field order must match the `pop` sequence in
/// `context_switch` (`r15`, `r14`, `r13`, `r12`, `rbx`, `rbp`, then `ret` pops
/// `rip`).
#[repr(C)]
#[derive(Debug, Default)]
struct ContextSwitchFrame {
    r15: u64,
    r14: u64,
    r13: u64,
    r12: u64,
    rbx: u64,
    rbp: u64,
    // Return address consumed by `ret`; for a fresh task, `TaskContext::init`
    // sets it to the entry point.
    rip: u64,
}
268
/// Memory image used by `FXSAVE64`/`FXRSTOR64` for the x87, MMX and SSE state.
///
/// The layout mirrors the 512-byte hardware FXSAVE format; the size is checked
/// by the assertion below. `align(16)` is required because `FXSAVE`/`FXRSTOR`
/// demand a 16-byte-aligned operand.
#[allow(missing_docs)]
#[repr(C, align(16))]
#[derive(Debug)]
pub struct FxsaveArea {
    /// x87 control word.
    pub fcw: u16,
    /// x87 status word.
    pub fsw: u16,
    /// x87 tag word in the FXSAVE *abridged* form. Only the low byte is used,
    /// and a set bit means the register is valid (non-empty).
    pub ftw: u16,
    /// Last x87 opcode.
    pub fop: u16,
    /// Last x87 instruction pointer.
    pub fip: u64,
    /// Last x87 data (operand) pointer.
    pub fdp: u64,
    /// SSE control/status register.
    pub mxcsr: u32,
    /// Mask of MXCSR bits supported by the CPU (written by `FXSAVE`).
    pub mxcsr_mask: u32,
    /// ST0-ST7 / MM0-MM7, 16 bytes (two `u64`s) per register.
    pub st: [u64; 16],
    /// XMM0-XMM15, 16 bytes (two `u64`s) per register.
    pub xmm: [u64; 32],
    // Reserved/available tail of the 512-byte image.
    _padding: [u64; 12],
}

// The hardware FXSAVE image is exactly 512 bytes.
static_assertions::const_assert_eq!(core::mem::size_of::<FxsaveArea>(), 512);
291
/// Extended (FPU/SIMD) register state of a task.
///
/// With the `fp_simd` feature, `TaskContext::switch_to` saves the outgoing
/// task's state into it and restores the incoming task's state from it.
pub struct ExtendedState {
    /// The FXSAVE image holding the x87/MMX/SSE registers.
    pub fxsave_area: FxsaveArea,
}
297
#[cfg(feature = "fp_simd")]
impl ExtendedState {
    /// Saves the current x87/MMX/SSE state into this area with `FXSAVE64`.
    #[inline]
    fn save(&mut self) {
        // SAFETY: `fxsave_area` is a 512-byte, 16-byte-aligned `FxsaveArea`
        // (`repr(C, align(16))`, size asserted), as `FXSAVE64` requires.
        unsafe { core::arch::x86_64::_fxsave64(&mut self.fxsave_area as *mut _ as *mut u8) }
    }

    /// Loads the x87/MMX/SSE state from this area with `FXRSTOR64`.
    #[inline]
    fn restore(&self) {
        // SAFETY: same size/alignment guarantees as in `save`. The area holds
        // either a previous `FXSAVE64` image or the initial image from
        // `default`, whose MXCSR has no reserved bits set.
        unsafe { core::arch::x86_64::_fxrstor64(&self.fxsave_area as *const _ as *const u8) }
    }

    /// Returns the initial FPU/SIMD state for a new task.
    ///
    /// - `fcw` is 0x37f, the x87 `FNINIT` control word (all exceptions masked).
    /// - The tag word marks every x87 register empty.
    /// - `mxcsr` is 0x1f80, the MXCSR default (all SSE exceptions masked).
    /// - Everything else is zero.
    const fn default() -> Self {
        // SAFETY: `FxsaveArea` consists only of integers and integer arrays, so
        // the all-zero bit pattern is valid.
        let mut area: FxsaveArea = unsafe { core::mem::MaybeUninit::zeroed().assume_init() };
        area.fcw = 0x37f;
        // In the FXSAVE image the tag word is the *abridged* 8-bit form, where a
        // set bit means "register valid". "All empty" is therefore 0, not the
        // 0xffff of the full FNINIT tag word. 0xffff here would mark all eight
        // x87 registers as occupied, so the first `fld` in a new task would
        // overflow the x87 stack.
        area.ftw = 0;
        area.mxcsr = 0x1f80;
        Self { fxsave_area: area }
    }
}
318
319impl fmt::Debug for ExtendedState {
320    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
321        f.debug_struct("ExtendedState")
322            .field("fxsave_area", &self.fxsave_area)
323            .finish()
324    }
325}
326
/// Per-task CPU state that `TaskContext::switch_to` saves and restores.
#[derive(Debug)]
pub struct TaskContext {
    /// Top of the task's kernel stack; installed as TSS `rsp0` on switch (with `uspace`).
    pub kstack_top: VirtAddr,
    /// Saved kernel stack pointer, pointing at a `ContextSwitchFrame`.
    pub rsp: u64,
    /// Saved thread pointer (FS base); swapped on switch with the `tls` feature.
    pub fs_base: usize,
    /// Saved `IA32_KERNEL_GSBASE` MSR value (NOTE(review): presumably the user GS
    /// base while running in the kernel, since the entry path uses `swapgs`; confirm).
    #[cfg(feature = "uspace")]
    pub gs_base: usize,
    /// Saved FPU/SIMD state.
    #[cfg(feature = "fp_simd")]
    pub ext_state: ExtendedState,
    /// Page-table root loaded into CR3 when switching to this task.
    #[cfg(feature = "uspace")]
    pub cr3: memory_addr::PhysAddr,
}
366
367impl TaskContext {
368    pub fn new() -> Self {
376        Self {
377            kstack_top: va!(0),
378            rsp: 0,
379            fs_base: 0,
380            #[cfg(feature = "uspace")]
381            cr3: crate::paging::kernel_page_table_root(),
382            #[cfg(feature = "fp_simd")]
383            ext_state: ExtendedState::default(),
384            #[cfg(feature = "uspace")]
385            gs_base: 0,
386        }
387    }
388
389    pub fn init(&mut self, entry: usize, kstack_top: VirtAddr, tls_area: VirtAddr) {
392        unsafe {
393            let frame_ptr = (kstack_top.as_mut_ptr() as *mut u64).sub(1);
397            let frame_ptr = (frame_ptr as *mut ContextSwitchFrame).sub(1);
398            core::ptr::write(
399                frame_ptr,
400                ContextSwitchFrame {
401                    rip: entry as _,
402                    ..Default::default()
403                },
404            );
405            self.rsp = frame_ptr as u64;
406        }
407        self.kstack_top = kstack_top;
408        self.fs_base = tls_area.as_usize();
409    }
410
411    #[cfg(feature = "uspace")]
418    pub fn set_page_table_root(&mut self, cr3: memory_addr::PhysAddr) {
419        self.cr3 = cr3;
420    }
421
422    pub fn switch_to(&mut self, next_ctx: &Self) {
427        #[cfg(feature = "fp_simd")]
428        {
429            self.ext_state.save();
430            next_ctx.ext_state.restore();
431        }
432        #[cfg(feature = "tls")]
433        unsafe {
434            self.fs_base = super::read_thread_pointer();
435            super::write_thread_pointer(next_ctx.fs_base);
436        }
437        #[cfg(feature = "uspace")]
438        unsafe {
439            self.gs_base = x86::msr::rdmsr(x86::msr::IA32_KERNEL_GSBASE) as usize;
441            x86::msr::wrmsr(x86::msr::IA32_KERNEL_GSBASE, next_ctx.gs_base as u64);
442            super::tss_set_rsp0(next_ctx.kstack_top);
443            if next_ctx.cr3 != self.cr3 {
444                super::write_page_table_root(next_ctx.cr3);
445            }
446        }
447        unsafe { context_switch(&mut self.rsp, &next_ctx.rsp) }
448    }
449}
450
/// Saves the current task's callee-saved registers, swaps kernel stacks, and
/// resumes the next task.
///
/// The first argument arrives in `rdi` and the second in `rsi` (x86_64 System V
/// `extern "C"` convention, which the asm relies on). The steps are:
/// 1. Push `rbp`, `rbx`, `r12`-`r15` onto the current stack.
/// 2. Store the resulting `rsp` to `*_current_stack`.
/// 3. Load `rsp` from `*_next_stack`.
/// 4. Pop the same registers in reverse order (the `ContextSwitchFrame` layout).
/// 5. `ret` to the return address found on the next stack.
///
/// That return address is either the caller of an earlier `context_switch`, or
/// the `rip` that `TaskContext::init` planted for a fresh task.
///
/// # Safety
///
/// `*_next_stack` must point to a stack laid out as a `ContextSwitchFrame`: one
/// saved by this function, or built by `TaskContext::init`.
#[unsafe(naked)]
unsafe extern "C" fn context_switch(_current_stack: &mut u64, _next_stack: &u64) {
    naked_asm!(
        "
        .code64
        push    rbp
        push    rbx
        push    r12
        push    r13
        push    r14
        push    r15
        mov     [rdi], rsp

        mov     rsp, [rsi]
        pop     r15
        pop     r14
        pop     r13
        pop     r12
        pop     rbx
        pop     rbp
        ret",
    )
}