Skip to main content

embassy_net/
udp.rs

1//! UDP sockets.
2
3use core::future::{poll_fn, Future};
4use core::mem;
5use core::task::{Context, Poll};
6
7use smoltcp::iface::{Interface, SocketHandle};
8use smoltcp::socket::udp;
9pub use smoltcp::socket::udp::{PacketMetadata, UdpMetadata};
10use smoltcp::wire::IpListenEndpoint;
11
12use crate::Stack;
13
/// Errors that [`UdpSocket::bind`] can report.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum BindError {
    /// The socket is already open and cannot be bound again.
    InvalidState,
    /// The requested local endpoint is unusable (smoltcp reported it as unaddressable).
    NoRoute,
}
23
/// Error returned by [`UdpSocket::send_to`], [`UdpSocket::poll_send_to`] and
/// [`UdpSocket::send_to_with`].
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SendError {
    /// No route to host.
    NoRoute,
    /// Socket not bound to an outgoing port.
    SocketNotBound,
}
33
/// Error returned by [`UdpSocket::recv_from`] and [`UdpSocket::poll_recv_from`].
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RecvError {
    /// Provided buffer was smaller than the received packet.
    Truncated,
}
41
42/// An UDP socket.
43pub struct UdpSocket<'a> {
44    stack: Stack<'a>,
45    handle: SocketHandle,
46}
47
48impl<'a> UdpSocket<'a> {
49    /// Create a new UDP socket using the provided stack and buffers.
50    pub fn new(
51        stack: Stack<'a>,
52        rx_meta: &'a mut [PacketMetadata],
53        rx_buffer: &'a mut [u8],
54        tx_meta: &'a mut [PacketMetadata],
55        tx_buffer: &'a mut [u8],
56    ) -> Self {
57        let handle = stack.with_mut(|i| {
58            let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) };
59            let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
60            let tx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(tx_meta) };
61            let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
62            i.sockets.add(udp::Socket::new(
63                udp::PacketBuffer::new(rx_meta, rx_buffer),
64                udp::PacketBuffer::new(tx_meta, tx_buffer),
65            ))
66        });
67
68        Self { stack, handle }
69    }
70
71    /// Bind the socket to a local endpoint.
72    pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError>
73    where
74        T: Into<IpListenEndpoint>,
75    {
76        let mut endpoint = endpoint.into();
77
78        if endpoint.port == 0 {
79            // If user didn't specify port allocate a dynamic port.
80            endpoint.port = self.stack.with_mut(|i| i.get_local_port());
81        }
82
83        match self.with_mut(|s, _| s.bind(endpoint)) {
84            Ok(()) => Ok(()),
85            Err(udp::BindError::InvalidState) => Err(BindError::InvalidState),
86            Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute),
87        }
88    }
89
90    fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
91        self.stack.with(|i| {
92            let socket = i.sockets.get::<udp::Socket>(self.handle);
93            f(socket, &i.iface)
94        })
95    }
96
97    fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
98        self.stack.with_mut(|i| {
99            let socket = i.sockets.get_mut::<udp::Socket>(self.handle);
100            let res = f(socket, &mut i.iface);
101            i.waker.wake();
102            res
103        })
104    }
105
106    /// Wait until the socket becomes readable.
107    ///
108    /// A socket is readable when a packet has been received, or when there are queued packets in
109    /// the buffer.
110    pub fn wait_recv_ready(&self) -> impl Future<Output = ()> + '_ {
111        poll_fn(move |cx| self.poll_recv_ready(cx))
112    }
113
114    /// Wait until a datagram can be read.
115    ///
116    /// When no datagram is readable, this method will return `Poll::Pending` and
117    /// register the current task to be notified when a datagram is received.
118    ///
119    /// When a datagram is received, this method will return `Poll::Ready`.
120    pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
121        self.with_mut(|s, _| {
122            if s.can_recv() {
123                Poll::Ready(())
124            } else {
125                // socket buffer is empty wait until at least one byte has arrived
126                s.register_recv_waker(cx.waker());
127                Poll::Pending
128            }
129        })
130    }
131
132    /// Receive a datagram.
133    ///
134    /// This method will wait until a datagram is received.
135    ///
136    /// Returns the number of bytes received and the remote endpoint.
137    pub fn recv_from<'s>(
138        &'s self,
139        buf: &'s mut [u8],
140    ) -> impl Future<Output = Result<(usize, UdpMetadata), RecvError>> + 's {
141        poll_fn(|cx| self.poll_recv_from(buf, cx))
142    }
143
144    /// Receive a datagram.
145    ///
146    /// When no datagram is available, this method will return `Poll::Pending` and
147    /// register the current task to be notified when a datagram is received.
148    ///
149    /// When a datagram is received, this method will return `Poll::Ready` with the
150    /// number of bytes received and the remote endpoint.
151    pub fn poll_recv_from(
152        &self,
153        buf: &mut [u8],
154        cx: &mut Context<'_>,
155    ) -> Poll<Result<(usize, UdpMetadata), RecvError>> {
156        self.with_mut(|s, _| match s.recv_slice(buf) {
157            Ok((n, meta)) => Poll::Ready(Ok((n, meta))),
158            // No data ready
159            Err(udp::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)),
160            Err(udp::RecvError::Exhausted) => {
161                s.register_recv_waker(cx.waker());
162                Poll::Pending
163            }
164        })
165    }
166
167    /// Receive a datagram with a zero-copy function.
168    ///
169    /// When no datagram is available, this method will return `Poll::Pending` and
170    /// register the current task to be notified when a datagram is received.
171    ///
172    /// When a datagram is received, this method will call the provided function
173    /// with the number of bytes received and the remote endpoint and return
174    /// `Poll::Ready` with the function's returned value.
175    pub async fn recv_from_with<F, R>(&mut self, f: F) -> R
176    where
177        F: FnOnce(&[u8], UdpMetadata) -> R,
178    {
179        let mut f = Some(f);
180        poll_fn(move |cx| {
181            self.with_mut(|s, _| {
182                match s.recv() {
183                    Ok((buffer, endpoint)) => Poll::Ready(unwrap!(f.take())(buffer, endpoint)),
184                    Err(udp::RecvError::Truncated) => unreachable!(),
185                    Err(udp::RecvError::Exhausted) => {
186                        // socket buffer is empty wait until at least one byte has arrived
187                        s.register_recv_waker(cx.waker());
188                        Poll::Pending
189                    }
190                }
191            })
192        })
193        .await
194    }
195
196    /// Wait until the socket becomes writable.
197    ///
198    /// A socket becomes writable when there is space in the buffer, from initial memory or after
199    /// dispatching datagrams on a full buffer.
200    pub fn wait_send_ready(&self) -> impl Future<Output = ()> + '_ {
201        poll_fn(|cx| self.poll_send_ready(cx))
202    }
203
204    /// Wait until a datagram can be sent.
205    ///
206    /// When no datagram can be sent (i.e. the buffer is full), this method will return
207    /// `Poll::Pending` and register the current task to be notified when
208    /// space is freed in the buffer after a datagram has been dispatched.
209    ///
210    /// When a datagram can be sent, this method will return `Poll::Ready`.
211    pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
212        self.with_mut(|s, _| {
213            if s.can_send() {
214                Poll::Ready(())
215            } else {
216                // socket buffer is full wait until a datagram has been dispatched
217                s.register_send_waker(cx.waker());
218                Poll::Pending
219            }
220        })
221    }
222
223    /// Send a datagram to the specified remote endpoint.
224    ///
225    /// This method will wait until the datagram has been sent.
226    ///
227    /// When the remote endpoint is not reachable, this method will return `Err(SendError::NoRoute)`
228    pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), SendError>
229    where
230        T: Into<UdpMetadata>,
231    {
232        let remote_endpoint: UdpMetadata = remote_endpoint.into();
233        poll_fn(move |cx| self.poll_send_to(buf, remote_endpoint, cx)).await
234    }
235
236    /// Send a datagram to the specified remote endpoint.
237    ///
238    /// When the datagram has been sent, this method will return `Poll::Ready(Ok())`.
239    ///
240    /// When the socket's send buffer is full, this method will return `Poll::Pending`
241    /// and register the current task to be notified when the buffer has space available.
242    ///
243    /// When the remote endpoint is not reachable, this method will return `Poll::Ready(Err(Error::NoRoute))`.
244    pub fn poll_send_to<T>(&self, buf: &[u8], remote_endpoint: T, cx: &mut Context<'_>) -> Poll<Result<(), SendError>>
245    where
246        T: Into<UdpMetadata>,
247    {
248        self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) {
249            // Entire datagram has been sent
250            Ok(()) => Poll::Ready(Ok(())),
251            Err(udp::SendError::BufferFull) => {
252                s.register_send_waker(cx.waker());
253                Poll::Pending
254            }
255            Err(udp::SendError::Unaddressable) => {
256                // If no sender/outgoing port is specified, there is not really "no route"
257                if s.endpoint().port == 0 {
258                    Poll::Ready(Err(SendError::SocketNotBound))
259                } else {
260                    Poll::Ready(Err(SendError::NoRoute))
261                }
262            }
263        })
264    }
265
266    /// Send a datagram to the specified remote endpoint with a zero-copy function.
267    ///
268    /// This method will wait until the buffer can fit the requested size before
269    /// calling the function to fill its contents.
270    ///
271    /// When the remote endpoint is not reachable, this method will return `Err(SendError::NoRoute)`
272    pub async fn send_to_with<T, F, R>(&mut self, size: usize, remote_endpoint: T, f: F) -> Result<R, SendError>
273    where
274        T: Into<UdpMetadata> + Copy,
275        F: FnOnce(&mut [u8]) -> R,
276    {
277        let mut f = Some(f);
278        poll_fn(move |cx| {
279            self.with_mut(|s, _| {
280                match s.send(size, remote_endpoint) {
281                    Ok(buffer) => Poll::Ready(Ok(unwrap!(f.take())(buffer))),
282                    Err(udp::SendError::BufferFull) => {
283                        s.register_send_waker(cx.waker());
284                        Poll::Pending
285                    }
286                    Err(udp::SendError::Unaddressable) => {
287                        // If no sender/outgoing port is specified, there is not really "no route"
288                        if s.endpoint().port == 0 {
289                            Poll::Ready(Err(SendError::SocketNotBound))
290                        } else {
291                            Poll::Ready(Err(SendError::NoRoute))
292                        }
293                    }
294                }
295            })
296        })
297        .await
298    }
299
300    /// Flush the socket.
301    ///
302    /// This method will wait until the socket is flushed.
303    pub fn flush(&mut self) -> impl Future<Output = ()> + '_ {
304        poll_fn(|cx| {
305            self.with_mut(|s, _| {
306                if s.send_queue() == 0 {
307                    Poll::Ready(())
308                } else {
309                    s.register_send_waker(cx.waker());
310                    Poll::Pending
311                }
312            })
313        })
314    }
315
316    /// Returns the local endpoint of the socket.
317    pub fn endpoint(&self) -> IpListenEndpoint {
318        self.with(|s, _| s.endpoint())
319    }
320
321    /// Returns whether the socket is open.
322
323    pub fn is_open(&self) -> bool {
324        self.with(|s, _| s.is_open())
325    }
326
327    /// Close the socket.
328    pub fn close(&mut self) {
329        self.with_mut(|s, _| s.close())
330    }
331
332    /// Returns whether the socket is ready to send data, i.e. it has enough buffer space to hold a packet.
333    pub fn may_send(&self) -> bool {
334        self.with(|s, _| s.can_send())
335    }
336
337    /// Returns whether the socket is ready to receive data, i.e. it has received a packet that's now in the buffer.
338    pub fn may_recv(&self) -> bool {
339        self.with(|s, _| s.can_recv())
340    }
341
342    /// Return the maximum number packets the socket can receive.
343    pub fn packet_recv_capacity(&self) -> usize {
344        self.with(|s, _| s.packet_recv_capacity())
345    }
346
347    /// Return the maximum number packets the socket can receive.
348    pub fn packet_send_capacity(&self) -> usize {
349        self.with(|s, _| s.packet_send_capacity())
350    }
351
352    /// Return the maximum number of bytes inside the recv buffer.
353    pub fn payload_recv_capacity(&self) -> usize {
354        self.with(|s, _| s.payload_recv_capacity())
355    }
356
357    /// Return the maximum number of bytes inside the transmit buffer.
358    pub fn payload_send_capacity(&self) -> usize {
359        self.with(|s, _| s.payload_send_capacity())
360    }
361
362    /// Set the hop limit field in the IP header of sent packets.
363    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
364        self.with_mut(|s, _| s.set_hop_limit(hop_limit))
365    }
366}
367
368impl Drop for UdpSocket<'_> {
369    fn drop(&mut self) {
370        self.stack.with_mut(|i| i.sockets.remove(self.handle));
371    }
372}
373
374fn _assert_covariant<'a, 'b: 'a>(x: UdpSocket<'b>) -> UdpSocket<'a> {
375    x
376}