axnet/smoltcp_impl/
listen_table.rs1use alloc::{boxed::Box, collections::VecDeque};
2use core::ops::{Deref, DerefMut};
3
4use axerrno::{AxError, AxResult, ax_err};
5use axsync::Mutex;
6use smoltcp::iface::{SocketHandle, SocketSet};
7use smoltcp::socket::tcp::{self, State};
8use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint};
9
10use super::{LISTEN_QUEUE_SIZE, SOCKET_SET, SocketSetWrapper};
11
/// Number of slots in the listen table: one for every value a `u16` port can take.
const PORT_NUM: usize = u16::MAX as usize + 1;
13
/// State kept for a single port that is being listened on.
struct ListenTableEntry {
    /// Endpoint the listener was registered with. Its optional address is used
    /// to filter incoming destination addresses (see `can_accept`).
    listen_endpoint: IpListenEndpoint,
    /// Handles of the sockets created for incoming connections on this port
    /// that have not been handed out by `ListenTable::accept` yet.
    syn_queue: VecDeque<SocketHandle>,
}
18
19impl ListenTableEntry {
20 pub fn new(listen_endpoint: IpListenEndpoint) -> Self {
21 Self {
22 listen_endpoint,
23 syn_queue: VecDeque::with_capacity(LISTEN_QUEUE_SIZE),
24 }
25 }
26
27 #[inline]
28 fn can_accept(&self, dst: IpAddress) -> bool {
29 match self.listen_endpoint.addr {
30 Some(addr) => addr == dst,
31 None => true,
32 }
33 }
34}
35
36impl Drop for ListenTableEntry {
37 fn drop(&mut self) {
38 for &handle in &self.syn_queue {
39 SOCKET_SET.remove(handle);
40 }
41 }
42}
43
/// Table of listening TCP ports.
pub struct ListenTable {
    /// One mutex-protected slot per port, indexed by the port number.
    /// A slot holds `None` while nothing listens on that port.
    tcp: Box<[Mutex<Option<Box<ListenTableEntry>>>]>,
}
47
48impl ListenTable {
49 pub fn new() -> Self {
50 let tcp = unsafe {
51 let mut buf = Box::new_uninit_slice(PORT_NUM);
52 for i in 0..PORT_NUM {
53 buf[i].write(Mutex::new(None));
54 }
55 buf.assume_init()
56 };
57 Self { tcp }
58 }
59
60 pub fn can_listen(&self, port: u16) -> bool {
61 self.tcp[port as usize].lock().is_none()
62 }
63
64 pub fn listen(&self, listen_endpoint: IpListenEndpoint) -> AxResult {
65 let port = listen_endpoint.port;
66 assert_ne!(port, 0);
67 let mut entry = self.tcp[port as usize].lock();
68 if entry.is_none() {
69 *entry = Some(Box::new(ListenTableEntry::new(listen_endpoint)));
70 Ok(())
71 } else {
72 ax_err!(AddrInUse, "socket listen() failed")
73 }
74 }
75
76 pub fn unlisten(&self, port: u16) {
77 debug!("TCP socket unlisten on {}", port);
78 *self.tcp[port as usize].lock() = None;
79 }
80
81 pub fn can_accept(&self, port: u16) -> AxResult<bool> {
82 if let Some(entry) = self.tcp[port as usize].lock().deref() {
83 Ok(entry.syn_queue.iter().any(|&handle| is_connected(handle)))
84 } else {
85 ax_err!(InvalidInput, "socket accept() failed: not listen")
86 }
87 }
88
89 pub fn accept(&self, port: u16) -> AxResult<(SocketHandle, (IpEndpoint, IpEndpoint))> {
90 if let Some(entry) = self.tcp[port as usize].lock().deref_mut() {
91 let syn_queue = &mut entry.syn_queue;
92 let (idx, addr_tuple) = syn_queue
93 .iter()
94 .enumerate()
95 .find_map(|(idx, &handle)| {
96 is_connected(handle).then(|| (idx, get_addr_tuple(handle)))
97 })
98 .ok_or(AxError::WouldBlock)?; if idx > 0 {
100 warn!(
101 "slow SYN queue enumeration: index = {}, len = {}!",
102 idx,
103 syn_queue.len()
104 );
105 }
106 let handle = syn_queue.swap_remove_front(idx).unwrap();
107 Ok((handle, addr_tuple))
108 } else {
109 ax_err!(InvalidInput, "socket accept() failed: not listen")
110 }
111 }
112
113 pub fn incoming_tcp_packet(
114 &self,
115 src: IpEndpoint,
116 dst: IpEndpoint,
117 sockets: &mut SocketSet<'_>,
118 ) {
119 if let Some(entry) = self.tcp[dst.port as usize].lock().deref_mut() {
120 if !entry.can_accept(dst.addr) {
121 return;
123 }
124 if entry.syn_queue.len() >= LISTEN_QUEUE_SIZE {
125 warn!("SYN queue overflow!");
127 return;
128 }
129 let mut socket = SocketSetWrapper::new_tcp_socket();
130 if socket.listen(entry.listen_endpoint).is_ok() {
131 let handle = sockets.add(socket);
132 debug!(
133 "TCP socket {}: prepare for connection {} -> {}",
134 handle, src, entry.listen_endpoint
135 );
136 entry.syn_queue.push_back(handle);
137 }
138 }
139 }
140}
141
142fn is_connected(handle: SocketHandle) -> bool {
143 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
144 !matches!(socket.state(), State::Listen | State::SynReceived)
145 })
146}
147
148fn get_addr_tuple(handle: SocketHandle) -> (IpEndpoint, IpEndpoint) {
149 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
150 (
151 socket.local_endpoint().unwrap(),
152 socket.remote_endpoint().unwrap(),
153 )
154 })
155}