axnet/smoltcp_impl/listen_table.rs

1use alloc::{boxed::Box, collections::VecDeque};
2use core::ops::{Deref, DerefMut};
3
4use axerrno::{AxError, AxResult, ax_err};
5use axsync::Mutex;
6use smoltcp::iface::{SocketHandle, SocketSet};
7use smoltcp::socket::tcp::{self, State};
8use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint};
9
10use super::{LISTEN_QUEUE_SIZE, SOCKET_SET, SocketSetWrapper};
11
/// Number of slots in the listen table: one per possible TCP port (`0..=u16::MAX`),
/// so indexing the table with `port as usize` can never go out of bounds.
const PORT_NUM: usize = u16::MAX as usize + 1;
13
14struct ListenTableEntry {
15    listen_endpoint: IpListenEndpoint,
16    syn_queue: VecDeque<SocketHandle>,
17}
18
19impl ListenTableEntry {
20    pub fn new(listen_endpoint: IpListenEndpoint) -> Self {
21        Self {
22            listen_endpoint,
23            syn_queue: VecDeque::with_capacity(LISTEN_QUEUE_SIZE),
24        }
25    }
26
27    #[inline]
28    fn can_accept(&self, dst: IpAddress) -> bool {
29        match self.listen_endpoint.addr {
30            Some(addr) => addr == dst,
31            None => true,
32        }
33    }
34}
35
36impl Drop for ListenTableEntry {
37    fn drop(&mut self) {
38        for &handle in &self.syn_queue {
39            SOCKET_SET.remove(handle);
40        }
41    }
42}
43
44pub struct ListenTable {
45    tcp: Box<[Mutex<Option<Box<ListenTableEntry>>>]>,
46}
47
48impl ListenTable {
49    pub fn new() -> Self {
50        let tcp = unsafe {
51            let mut buf = Box::new_uninit_slice(PORT_NUM);
52            for i in 0..PORT_NUM {
53                buf[i].write(Mutex::new(None));
54            }
55            buf.assume_init()
56        };
57        Self { tcp }
58    }
59
60    pub fn can_listen(&self, port: u16) -> bool {
61        self.tcp[port as usize].lock().is_none()
62    }
63
64    pub fn listen(&self, listen_endpoint: IpListenEndpoint) -> AxResult {
65        let port = listen_endpoint.port;
66        assert_ne!(port, 0);
67        let mut entry = self.tcp[port as usize].lock();
68        if entry.is_none() {
69            *entry = Some(Box::new(ListenTableEntry::new(listen_endpoint)));
70            Ok(())
71        } else {
72            ax_err!(AddrInUse, "socket listen() failed")
73        }
74    }
75
76    pub fn unlisten(&self, port: u16) {
77        debug!("TCP socket unlisten on {}", port);
78        *self.tcp[port as usize].lock() = None;
79    }
80
81    pub fn can_accept(&self, port: u16) -> AxResult<bool> {
82        if let Some(entry) = self.tcp[port as usize].lock().deref() {
83            Ok(entry.syn_queue.iter().any(|&handle| is_connected(handle)))
84        } else {
85            ax_err!(InvalidInput, "socket accept() failed: not listen")
86        }
87    }
88
89    pub fn accept(&self, port: u16) -> AxResult<(SocketHandle, (IpEndpoint, IpEndpoint))> {
90        if let Some(entry) = self.tcp[port as usize].lock().deref_mut() {
91            let syn_queue = &mut entry.syn_queue;
92            let (idx, addr_tuple) = syn_queue
93                .iter()
94                .enumerate()
95                .find_map(|(idx, &handle)| {
96                    is_connected(handle).then(|| (idx, get_addr_tuple(handle)))
97                })
98                .ok_or(AxError::WouldBlock)?; // wait for connection
99            if idx > 0 {
100                warn!(
101                    "slow SYN queue enumeration: index = {}, len = {}!",
102                    idx,
103                    syn_queue.len()
104                );
105            }
106            let handle = syn_queue.swap_remove_front(idx).unwrap();
107            Ok((handle, addr_tuple))
108        } else {
109            ax_err!(InvalidInput, "socket accept() failed: not listen")
110        }
111    }
112
113    pub fn incoming_tcp_packet(
114        &self,
115        src: IpEndpoint,
116        dst: IpEndpoint,
117        sockets: &mut SocketSet<'_>,
118    ) {
119        if let Some(entry) = self.tcp[dst.port as usize].lock().deref_mut() {
120            if !entry.can_accept(dst.addr) {
121                // not listening on this address
122                return;
123            }
124            if entry.syn_queue.len() >= LISTEN_QUEUE_SIZE {
125                // SYN queue is full, drop the packet
126                warn!("SYN queue overflow!");
127                return;
128            }
129            let mut socket = SocketSetWrapper::new_tcp_socket();
130            if socket.listen(entry.listen_endpoint).is_ok() {
131                let handle = sockets.add(socket);
132                debug!(
133                    "TCP socket {}: prepare for connection {} -> {}",
134                    handle, src, entry.listen_endpoint
135                );
136                entry.syn_queue.push_back(handle);
137            }
138        }
139    }
140}
141
142fn is_connected(handle: SocketHandle) -> bool {
143    SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
144        !matches!(socket.state(), State::Listen | State::SynReceived)
145    })
146}
147
148fn get_addr_tuple(handle: SocketHandle) -> (IpEndpoint, IpEndpoint) {
149    SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
150        (
151            socket.local_endpoint().unwrap(),
152            socket.remote_endpoint().unwrap(),
153        )
154    })
155}