1use core::{alloc::Layout, ffi::c_char, mem::transmute, ptr, slice, str};
2
3use axerrno::{LinuxError, LinuxResult};
4use axhal::paging::MappingFlags;
5use axtask::{TaskExtRef, current};
6use memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr, VirtAddrRange};
7use starry_core::mm::access_user_memory;
8
9fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> LinuxResult<()> {
10 let align = layout.align();
11 if start.as_usize() & (align - 1) != 0 {
12 return Err(LinuxError::EFAULT);
13 }
14
15 let task = current();
16 let mut aspace = task.task_ext().process_data().aspace.lock();
17
18 if !aspace.check_region_access(
19 VirtAddrRange::from_start_size(start, layout.size()),
20 access_flags,
21 ) {
22 return Err(LinuxError::EFAULT);
23 }
24
25 let page_start = start.align_down_4k();
26 let page_end = (start + layout.size()).align_up_4k();
27 aspace.populate_area(page_start, page_end - page_start, access_flags)?;
28
29 Ok(())
30}
31
32fn check_null_terminated<T: PartialEq + Default>(
33 start: VirtAddr,
34 access_flags: MappingFlags,
35) -> LinuxResult<usize> {
36 let align = Layout::new::<T>().align();
37 if start.as_usize() & (align - 1) != 0 {
38 return Err(LinuxError::EFAULT);
39 }
40
41 let zero = T::default();
42
43 let mut page = start.align_down_4k();
44
45 let start = start.as_ptr_of::<T>();
46 let mut len = 0;
47
48 access_user_memory(|| {
49 loop {
50 let ptr = unsafe { start.add(len) };
53 while ptr as usize >= page.as_ptr() as usize {
54 let task = current();
62 let aspace = task.task_ext().process_data().aspace.lock();
63 if !aspace.check_region_access(
64 VirtAddrRange::from_start_size(page, PAGE_SIZE_4K),
65 access_flags,
66 ) {
67 return Err(LinuxError::EFAULT);
68 }
69
70 page += PAGE_SIZE_4K;
71 }
72
73 if unsafe { ptr.read_volatile() } == zero {
76 break;
77 }
78 len += 1;
79 }
80 Ok(())
81 })?;
82
83 Ok(len)
84}
85
/// A raw, mutable pointer into user space, stored as a plain address.
///
/// Constructing one (e.g. via `From<usize>`) performs no validation.
#[repr(transparent)]
#[derive(PartialEq)]
pub struct UserPtr<T>(*mut T);

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, although copying the wrapper only copies an
// address and is valid for every pointee type.
impl<T> Clone for UserPtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UserPtr<T> {}

impl<T> From<usize> for UserPtr<T> {
    /// Wraps a raw address (typically a syscall argument) without checking it.
    fn from(value: usize) -> Self {
        UserPtr(value as *mut _)
    }
}

impl<T> Default for UserPtr<T> {
    /// Returns the null user pointer.
    fn default() -> Self {
        Self(ptr::null_mut())
    }
}
102
103impl<T> UserPtr<T> {
104 const ACCESS_FLAGS: MappingFlags = MappingFlags::READ.union(MappingFlags::WRITE);
105
106 pub fn address(&self) -> VirtAddr {
107 VirtAddr::from_ptr_of(self.0)
108 }
109
110 pub fn is_null(&self) -> bool {
111 self.0.is_null()
112 }
113
114 pub fn get_as_mut(self) -> LinuxResult<&'static mut T> {
115 check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
116 Ok(unsafe { &mut *self.0 })
117 }
118
119 pub fn get_as_mut_slice(self, len: usize) -> LinuxResult<&'static mut [T]> {
120 check_region(
121 self.address(),
122 Layout::array::<T>(len).unwrap(),
123 Self::ACCESS_FLAGS,
124 )?;
125 Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
126 }
127
128 pub fn get_as_mut_null_terminated(self) -> LinuxResult<&'static mut [T]>
129 where
130 T: PartialEq + Default,
131 {
132 let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
133 Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
134 }
135}
136
/// A raw, read-only pointer into user space, stored as a plain address.
///
/// Constructing one (e.g. via `From<usize>`) performs no validation.
#[repr(transparent)]
#[derive(PartialEq)]
pub struct UserConstPtr<T>(*const T);

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, although copying the wrapper only copies an
// address and is valid for every pointee type.
impl<T> Clone for UserConstPtr<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for UserConstPtr<T> {}

impl<T> From<usize> for UserConstPtr<T> {
    /// Wraps a raw address (typically a syscall argument) without checking it.
    fn from(value: usize) -> Self {
        UserConstPtr(value as *const _)
    }
}

impl<T> Default for UserConstPtr<T> {
    /// Returns the null user pointer.
    fn default() -> Self {
        Self(ptr::null())
    }
}
153
154impl<T> UserConstPtr<T> {
155 const ACCESS_FLAGS: MappingFlags = MappingFlags::READ;
156
157 pub fn address(&self) -> VirtAddr {
158 VirtAddr::from_ptr_of(self.0)
159 }
160
161 pub fn is_null(&self) -> bool {
162 self.0.is_null()
163 }
164
165 pub fn get_as_ref(self) -> LinuxResult<&'static T> {
166 check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
167 Ok(unsafe { &*self.0 })
168 }
169
170 pub fn get_as_slice(self, len: usize) -> LinuxResult<&'static [T]> {
171 check_region(
172 self.address(),
173 Layout::array::<T>(len).unwrap(),
174 Self::ACCESS_FLAGS,
175 )?;
176 Ok(unsafe { slice::from_raw_parts(self.0, len) })
177 }
178
179 pub fn get_as_null_terminated(self) -> LinuxResult<&'static [T]>
180 where
181 T: PartialEq + Default,
182 {
183 let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
184 Ok(unsafe { slice::from_raw_parts(self.0, len) })
185 }
186}
187
188impl UserConstPtr<c_char> {
189 pub fn get_as_str(self) -> LinuxResult<&'static str> {
191 let slice = self.get_as_null_terminated()?;
192 let slice = unsafe { transmute::<&[c_char], &[u8]>(slice) };
194
195 str::from_utf8(slice).map_err(|_| LinuxError::EILSEQ)
196 }
197}
198
/// Invokes `$ptr.$func(args...)` only when `$ptr` is non-null.
///
/// A null pointer yields `Ok(None)`. Otherwise the method's `Result` is
/// returned with its success value wrapped in `Some`.
macro_rules! nullable {
    ($ptr:ident.$func:ident($($arg:expr),*)) => {
        match $ptr.is_null() {
            true => Ok(None),
            false => $ptr.$func($($arg),*).map(Some),
        }
    };
}
pub(crate) use nullable;