starry_api/
ptr.rs

1use core::{alloc::Layout, ffi::c_char, mem::transmute, ptr, slice, str};
2
3use axerrno::{LinuxError, LinuxResult};
4use axhal::paging::MappingFlags;
5use axtask::{TaskExtRef, current};
6use memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr, VirtAddrRange};
7use starry_core::mm::access_user_memory;
8
9fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> LinuxResult<()> {
10    let align = layout.align();
11    if start.as_usize() & (align - 1) != 0 {
12        return Err(LinuxError::EFAULT);
13    }
14
15    let task = current();
16    let mut aspace = task.task_ext().process_data().aspace.lock();
17
18    if !aspace.check_region_access(
19        VirtAddrRange::from_start_size(start, layout.size()),
20        access_flags,
21    ) {
22        return Err(LinuxError::EFAULT);
23    }
24
25    let page_start = start.align_down_4k();
26    let page_end = (start + layout.size()).align_up_4k();
27    aspace.populate_area(page_start, page_end - page_start, access_flags)?;
28
29    Ok(())
30}
31
32fn check_null_terminated<T: PartialEq + Default>(
33    start: VirtAddr,
34    access_flags: MappingFlags,
35) -> LinuxResult<usize> {
36    let align = Layout::new::<T>().align();
37    if start.as_usize() & (align - 1) != 0 {
38        return Err(LinuxError::EFAULT);
39    }
40
41    let zero = T::default();
42
43    let mut page = start.align_down_4k();
44
45    let start = start.as_ptr_of::<T>();
46    let mut len = 0;
47
48    access_user_memory(|| {
49        loop {
50            // SAFETY: This won't overflow the address space since we'll check
51            // it below.
52            let ptr = unsafe { start.add(len) };
53            while ptr as usize >= page.as_ptr() as usize {
54                // We cannot prepare `aspace` outside of the loop, since holding
55                // aspace requires a mutex which would be required on page
56                // fault, and page faults can trigger inside the loop.
57
58                // TODO: this is inefficient, but we have to do this instead of
59                // querying the page table since the page might has not been
60                // allocated yet.
61                let task = current();
62                let aspace = task.task_ext().process_data().aspace.lock();
63                if !aspace.check_region_access(
64                    VirtAddrRange::from_start_size(page, PAGE_SIZE_4K),
65                    access_flags,
66                ) {
67                    return Err(LinuxError::EFAULT);
68                }
69
70                page += PAGE_SIZE_4K;
71            }
72
73            // This might trigger a page fault
74            // SAFETY: The pointer is valid and points to a valid memory region.
75            if unsafe { ptr.read_volatile() } == zero {
76                break;
77            }
78            len += 1;
79        }
80        Ok(())
81    })?;
82
83    Ok(len)
84}
85
/// A pointer to user space memory.
///
/// Constructing one performs no validation. The `get_as_*` accessors check the
/// current address space before producing a reference.
#[repr(transparent)]
pub struct UserPtr<T>(*mut T);

// `Clone`, `Copy` and `PartialEq` are implemented by hand rather than derived.
// `#[derive]` would add `T: Clone` / `T: Copy` / `T: PartialEq` bounds, even
// though only the raw pointer is ever copied or compared.
impl<T> Clone for UserPtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UserPtr<T> {}

impl<T> PartialEq for UserPtr<T> {
    /// Two user pointers are equal when they hold the same address.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl<T> From<usize> for UserPtr<T> {
    /// Wraps a raw user-space address; no validation is performed.
    fn from(value: usize) -> Self {
        UserPtr(value as *mut _)
    }
}

impl<T> Default for UserPtr<T> {
    /// Returns the null user pointer.
    fn default() -> Self {
        Self(ptr::null_mut())
    }
}
102
103impl<T> UserPtr<T> {
104    const ACCESS_FLAGS: MappingFlags = MappingFlags::READ.union(MappingFlags::WRITE);
105
106    pub fn address(&self) -> VirtAddr {
107        VirtAddr::from_ptr_of(self.0)
108    }
109
110    pub fn is_null(&self) -> bool {
111        self.0.is_null()
112    }
113
114    pub fn get_as_mut(self) -> LinuxResult<&'static mut T> {
115        check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
116        Ok(unsafe { &mut *self.0 })
117    }
118
119    pub fn get_as_mut_slice(self, len: usize) -> LinuxResult<&'static mut [T]> {
120        check_region(
121            self.address(),
122            Layout::array::<T>(len).unwrap(),
123            Self::ACCESS_FLAGS,
124        )?;
125        Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
126    }
127
128    pub fn get_as_mut_null_terminated(self) -> LinuxResult<&'static mut [T]>
129    where
130        T: PartialEq + Default,
131    {
132        let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
133        Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
134    }
135}
136
/// An immutable pointer to user space memory.
///
/// Constructing one performs no validation. The `get_as_*` accessors check the
/// current address space before producing a reference.
#[repr(transparent)]
pub struct UserConstPtr<T>(*const T);

// `Clone`, `Copy` and `PartialEq` are implemented by hand rather than derived.
// `#[derive]` would add `T: Clone` / `T: Copy` / `T: PartialEq` bounds, even
// though only the raw pointer is ever copied or compared.
impl<T> Clone for UserConstPtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UserConstPtr<T> {}

impl<T> PartialEq for UserConstPtr<T> {
    /// Two user pointers are equal when they hold the same address.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl<T> From<usize> for UserConstPtr<T> {
    /// Wraps a raw user-space address; no validation is performed.
    fn from(value: usize) -> Self {
        UserConstPtr(value as *const _)
    }
}

impl<T> Default for UserConstPtr<T> {
    /// Returns the null user pointer.
    fn default() -> Self {
        Self(ptr::null())
    }
}
153
154impl<T> UserConstPtr<T> {
155    const ACCESS_FLAGS: MappingFlags = MappingFlags::READ;
156
157    pub fn address(&self) -> VirtAddr {
158        VirtAddr::from_ptr_of(self.0)
159    }
160
161    pub fn is_null(&self) -> bool {
162        self.0.is_null()
163    }
164
165    pub fn get_as_ref(self) -> LinuxResult<&'static T> {
166        check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
167        Ok(unsafe { &*self.0 })
168    }
169
170    pub fn get_as_slice(self, len: usize) -> LinuxResult<&'static [T]> {
171        check_region(
172            self.address(),
173            Layout::array::<T>(len).unwrap(),
174            Self::ACCESS_FLAGS,
175        )?;
176        Ok(unsafe { slice::from_raw_parts(self.0, len) })
177    }
178
179    pub fn get_as_null_terminated(self) -> LinuxResult<&'static [T]>
180    where
181        T: PartialEq + Default,
182    {
183        let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
184        Ok(unsafe { slice::from_raw_parts(self.0, len) })
185    }
186}
187
188impl UserConstPtr<c_char> {
189    /// Get the pointer as `&str`, validating the memory region.
190    pub fn get_as_str(self) -> LinuxResult<&'static str> {
191        let slice = self.get_as_null_terminated()?;
192        // SAFETY: c_char is u8
193        let slice = unsafe { transmute::<&[c_char], &[u8]>(slice) };
194
195        str::from_utf8(slice).map_err(|_| LinuxError::EILSEQ)
196    }
197}
198
/// Invokes `$ptr.$func(args...)` only when `$ptr` is non-null.
///
/// For a null pointer the expansion evaluates to `Ok(None)` without calling
/// `$func` or evaluating its arguments. Otherwise it evaluates to the call's
/// `Result` with the success value wrapped in `Some`, giving a
/// `Result<Option<_>, _>` in both cases.
macro_rules! nullable {
    ($ptr:ident.$func:ident($($arg:expr),*)) => {
        match $ptr.is_null() {
            true => Ok(None),
            false => $ptr.$func($($arg),*).map(Some),
        }
    };
}
pub(crate) use nullable;