starry_api/
socket.rs

//! Conversions between [`SocketAddr`] and the Linux [`sockaddr`] types, provided through the [`SocketAddrExt`] extension trait.
2
3use crate::ptr::{UserConstPtr, UserPtr};
4use axerrno::{LinuxError, LinuxResult};
5use core::{
6    mem::{MaybeUninit, size_of},
7    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
8};
9use linux_raw_sys::net::{
10    __kernel_sa_family_t, AF_INET, AF_INET6, in_addr, in6_addr, sockaddr, sockaddr_in,
11    sockaddr_in6, socklen_t,
12};
13
14/// Trait to extend [`SocketAddr`] and its variants with methods for reading from and writing to user space.
15///
16pub trait SocketAddrExt: Sized {
17    /// This method attempts to interpret the data pointed to by `addr` with the
18    /// given `addrlen` as a valid socket address of the implementing type.
19    fn read_from_user(addr: UserConstPtr<sockaddr>, addrlen: socklen_t) -> LinuxResult<Self>;
20
21    /// This method serializes the current socket address instance into the
22    /// [`sockaddr`] structure pointed to by `addr` in user space.
23    fn write_to_user(&self, addr: UserPtr<sockaddr>) -> LinuxResult<socklen_t>;
24
25    /// Gets the address family of the socket address.
26    fn family(&self) -> u16;
27
28    /// Gets the encoded length of the socket address.
29    fn addr_len(&self) -> socklen_t;
30}
31
32/// Copies a socket address from user space into a temporary kernel storage.
33///
34/// This function reads `addrlen` bytes from the user-space pointer `addr` and
35/// copies them into a `MaybeUninit<sockaddr>` in kernel memory.
36///
37#[inline]
38fn copy_sockaddr_from_user(
39    addr: UserConstPtr<sockaddr>,
40    addrlen: socklen_t,
41) -> LinuxResult<MaybeUninit<sockaddr>> {
42    let mut storage = MaybeUninit::<sockaddr>::uninit();
43    let sock_addr = addr.get_as_ref()?;
44    unsafe {
45        core::ptr::copy_nonoverlapping(
46            sock_addr as *const sockaddr as *const u8,
47            storage.as_mut_ptr() as *mut u8,
48            addrlen as usize,
49        )
50    };
51    Ok(storage)
52}
53
54impl SocketAddrExt for SocketAddr {
55    /// Reads a [`SocketAddr`] from user space.
56    ///
57    /// This implementation first performs basic length validation. Then, it copies
58    /// the raw [`sockaddr`] data from user space into a temporary kernel buffer.
59    /// Based on the address family ([`AF_INET`] or [`AF_INET6`]) extracted from the
60    /// copied data, it delegates the actual parsing to [`SocketAddrV4::read_from_user`]
61    /// or [`SocketAddrV6::read_from_user`].
62    fn read_from_user(addr: UserConstPtr<sockaddr>, addrlen: socklen_t) -> LinuxResult<Self> {
63        if size_of::<__kernel_sa_family_t>() > addrlen as usize
64            || addrlen as usize > size_of::<sockaddr>()
65        {
66            return Err(LinuxError::EINVAL);
67        }
68        let src_addr = addr.get_as_ref()?;
69        let family = unsafe {
70            src_addr
71                .__storage
72                .__bindgen_anon_1
73                .__bindgen_anon_1
74                .ss_family as u32
75        };
76        match family {
77            AF_INET => SocketAddrV4::read_from_user(addr, addrlen).map(SocketAddr::V4),
78            AF_INET6 => SocketAddrV6::read_from_user(addr, addrlen).map(SocketAddr::V6),
79            _ => Err(LinuxError::EAFNOSUPPORT),
80        }
81    }
82
83    /// Writes the [`SocketAddr`] to user space.
84    ///
85    /// This implementation checks for a null user-space pointer. Then, it delegates
86    /// the actual writing to the specific [`SocketAddrV4`] or [`SocketAddrV6`]
87    /// `write_to_user` implementation based on the variant of `self`.
88    fn write_to_user(&self, addr: UserPtr<sockaddr>) -> LinuxResult<socklen_t> {
89        if addr.is_null() {
90            return Err(LinuxError::EINVAL);
91        }
92
93        match self {
94            SocketAddr::V4(v4) => v4.write_to_user(addr),
95            SocketAddr::V6(v6) => v6.write_to_user(addr),
96        }
97    }
98
99    /// Gets the address family of the [`SocketAddr`].
100    ///
101    /// Returns `AF_INET` for IPv4 addresses or `AF_INET6` for IPv6 addresses.
102    fn family(&self) -> u16 {
103        match self {
104            SocketAddr::V4(v4) => v4.family(),
105            SocketAddr::V6(v6) => v6.family(),
106        }
107    }
108
109    /// Gets the encoded length of the [`SocketAddr`] instance.
110    ///
111    /// Returns the size in bytes that this [`SocketAddr`] would occupy when
112    /// encoded as a [`sockaddr_in`] (for IPv4) or [`sockaddr_in6`] (for IPv6) structure.
113    fn addr_len(&self) -> socklen_t {
114        match self {
115            SocketAddr::V4(v4) => v4.addr_len(),
116            SocketAddr::V6(v6) => v6.addr_len(),
117        }
118    }
119}
120
121impl SocketAddrExt for SocketAddrV4 {
122    /// Reads an [`SocketAddrV4`] from user space.
123    fn read_from_user(addr: UserConstPtr<sockaddr>, addrlen: socklen_t) -> LinuxResult<Self> {
124        if addrlen < size_of::<sockaddr_in>() as socklen_t {
125            return Err(LinuxError::EINVAL);
126        }
127        let storage = copy_sockaddr_from_user(addr, addrlen)?;
128        let addr_in = unsafe { &*(storage.as_ptr() as *const sockaddr_in) };
129        if addr_in.sin_family as u32 != AF_INET {
130            return Err(LinuxError::EAFNOSUPPORT);
131        }
132
133        Ok(SocketAddrV4::new(
134            Ipv4Addr::from_bits(u32::from_be(addr_in.sin_addr.s_addr)),
135            u16::from_be(addr_in.sin_port),
136        ))
137    }
138
139    /// Writes the `SocketAddrV4` to user space.
140    fn write_to_user(&self, addr: UserPtr<sockaddr>) -> LinuxResult<socklen_t> {
141        if addr.is_null() {
142            return Err(LinuxError::EINVAL);
143        }
144        let dst_addr = addr.get_as_mut()?;
145        let len = size_of::<sockaddr_in>() as socklen_t;
146        let sockin_addr = sockaddr_in {
147            sin_family: AF_INET as _,
148            sin_port: self.port().to_be(),
149            sin_addr: in_addr {
150                s_addr: u32::from_ne_bytes(self.ip().octets()),
151            },
152            __pad: [0_u8; 8],
153        };
154        unsafe {
155            core::ptr::copy_nonoverlapping(
156                &sockin_addr as *const sockaddr_in as *const u8,
157                dst_addr as *mut sockaddr as *mut u8,
158                len as usize,
159            )
160        };
161
162        Ok(len)
163    }
164
165    /// Gets the address family for [`SocketAddrV4`].
166    fn family(&self) -> u16 {
167        AF_INET as u16
168    }
169
170    /// Gets the encoded length of [`SocketAddrV4`].
171    fn addr_len(&self) -> socklen_t {
172        size_of::<sockaddr_in>() as socklen_t
173    }
174}
175
176impl SocketAddrExt for SocketAddrV6 {
177    /// Reads an [`SocketAddrV6`] from user space.
178    fn read_from_user(addr: UserConstPtr<sockaddr>, addrlen: socklen_t) -> LinuxResult<Self> {
179        if addrlen < size_of::<sockaddr_in6>() as socklen_t {
180            return Err(LinuxError::EINVAL);
181        }
182        let storage = copy_sockaddr_from_user(addr, addrlen)?;
183        let addr_in6 = unsafe { &*(storage.as_ptr() as *const sockaddr_in6) };
184        if addr_in6.sin6_family as u32 != AF_INET6 {
185            return Err(LinuxError::EAFNOSUPPORT);
186        }
187
188        Ok(SocketAddrV6::new(
189            Ipv6Addr::from(unsafe { addr_in6.sin6_addr.in6_u.u6_addr8 }),
190            u16::from_be(addr_in6.sin6_port),
191            u32::from_be(addr_in6.sin6_flowinfo),
192            addr_in6.sin6_scope_id,
193        ))
194    }
195    /// Writes the `SocketAddrV6` to user space.
196    fn write_to_user(&self, addr: UserPtr<sockaddr>) -> LinuxResult<socklen_t> {
197        if addr.is_null() {
198            return Err(LinuxError::EINVAL);
199        }
200        let dst_addr = addr.get_as_mut()?;
201        let len = size_of::<sockaddr_in6>() as socklen_t;
202        let sockin_addr = sockaddr_in6 {
203            sin6_family: AF_INET6 as _,
204            sin6_port: self.port().to_be(),
205            sin6_flowinfo: self.flowinfo().to_be(),
206            sin6_addr: in6_addr {
207                in6_u: linux_raw_sys::net::in6_addr__bindgen_ty_1 {
208                    u6_addr8: self.ip().octets(),
209                },
210            },
211            sin6_scope_id: self.scope_id(),
212        };
213
214        unsafe {
215            core::ptr::copy_nonoverlapping(
216                &sockin_addr as *const sockaddr_in6 as *const u8,
217                dst_addr as *mut sockaddr as *mut u8,
218                len as usize,
219            )
220        };
221
222        Ok(len)
223    }
224
225    /// Gets the address family for [`SocketAddrV6`].
226    fn family(&self) -> u16 {
227        AF_INET6 as u16
228    }
229
230    /// Gets the encoded length of [`SocketAddrV6`].
231    fn addr_len(&self) -> socklen_t {
232        size_of::<sockaddr_in6>() as socklen_t
233    }
234}