1use core::future::{poll_fn, Future};
4use core::mem;
5use core::task::{Context, Poll};
6
7use smoltcp::iface::{Interface, SocketHandle};
8use smoltcp::socket::udp;
9pub use smoltcp::socket::udp::{PacketMetadata, UdpMetadata};
10use smoltcp::wire::IpListenEndpoint;
11
12use crate::Stack;
13
/// Error returned by [`UdpSocket::bind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum BindError {
    /// smoltcp refused the bind because of the socket's current state
    /// (presumably it is already bound — confirm against smoltcp's `BindError`).
    InvalidState,
    /// smoltcp reported the requested local endpoint as unaddressable.
    NoRoute,
}
23
/// Error returned by the datagram-sending methods of [`UdpSocket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SendError {
    /// The socket is bound, but smoltcp reported the destination as unaddressable.
    NoRoute,
    /// The socket has no local port yet; bind it with [`UdpSocket::bind`] first.
    SocketNotBound,
}
33
/// Error returned when receiving a datagram into a caller-provided buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RecvError {
    /// The caller's buffer was too small for the received datagram.
    Truncated,
}
41
42pub struct UdpSocket<'a> {
44 stack: Stack<'a>,
45 handle: SocketHandle,
46}
47
48impl<'a> UdpSocket<'a> {
49 pub fn new(
51 stack: Stack<'a>,
52 rx_meta: &'a mut [PacketMetadata],
53 rx_buffer: &'a mut [u8],
54 tx_meta: &'a mut [PacketMetadata],
55 tx_buffer: &'a mut [u8],
56 ) -> Self {
57 let handle = stack.with_mut(|i| {
58 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) };
59 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
60 let tx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(tx_meta) };
61 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
62 i.sockets.add(udp::Socket::new(
63 udp::PacketBuffer::new(rx_meta, rx_buffer),
64 udp::PacketBuffer::new(tx_meta, tx_buffer),
65 ))
66 });
67
68 Self { stack, handle }
69 }
70
71 pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError>
73 where
74 T: Into<IpListenEndpoint>,
75 {
76 let mut endpoint = endpoint.into();
77
78 if endpoint.port == 0 {
79 endpoint.port = self.stack.with_mut(|i| i.get_local_port());
81 }
82
83 match self.with_mut(|s, _| s.bind(endpoint)) {
84 Ok(()) => Ok(()),
85 Err(udp::BindError::InvalidState) => Err(BindError::InvalidState),
86 Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute),
87 }
88 }
89
90 fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
91 self.stack.with(|i| {
92 let socket = i.sockets.get::<udp::Socket>(self.handle);
93 f(socket, &i.iface)
94 })
95 }
96
97 fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
98 self.stack.with_mut(|i| {
99 let socket = i.sockets.get_mut::<udp::Socket>(self.handle);
100 let res = f(socket, &mut i.iface);
101 i.waker.wake();
102 res
103 })
104 }
105
106 pub fn wait_recv_ready(&self) -> impl Future<Output = ()> + '_ {
111 poll_fn(move |cx| self.poll_recv_ready(cx))
112 }
113
114 pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
121 self.with_mut(|s, _| {
122 if s.can_recv() {
123 Poll::Ready(())
124 } else {
125 s.register_recv_waker(cx.waker());
127 Poll::Pending
128 }
129 })
130 }
131
132 pub fn recv_from<'s>(
138 &'s self,
139 buf: &'s mut [u8],
140 ) -> impl Future<Output = Result<(usize, UdpMetadata), RecvError>> + 's {
141 poll_fn(|cx| self.poll_recv_from(buf, cx))
142 }
143
144 pub fn poll_recv_from(
152 &self,
153 buf: &mut [u8],
154 cx: &mut Context<'_>,
155 ) -> Poll<Result<(usize, UdpMetadata), RecvError>> {
156 self.with_mut(|s, _| match s.recv_slice(buf) {
157 Ok((n, meta)) => Poll::Ready(Ok((n, meta))),
158 Err(udp::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)),
160 Err(udp::RecvError::Exhausted) => {
161 s.register_recv_waker(cx.waker());
162 Poll::Pending
163 }
164 })
165 }
166
167 pub async fn recv_from_with<F, R>(&mut self, f: F) -> R
176 where
177 F: FnOnce(&[u8], UdpMetadata) -> R,
178 {
179 let mut f = Some(f);
180 poll_fn(move |cx| {
181 self.with_mut(|s, _| {
182 match s.recv() {
183 Ok((buffer, endpoint)) => Poll::Ready(unwrap!(f.take())(buffer, endpoint)),
184 Err(udp::RecvError::Truncated) => unreachable!(),
185 Err(udp::RecvError::Exhausted) => {
186 s.register_recv_waker(cx.waker());
188 Poll::Pending
189 }
190 }
191 })
192 })
193 .await
194 }
195
196 pub fn wait_send_ready(&self) -> impl Future<Output = ()> + '_ {
201 poll_fn(|cx| self.poll_send_ready(cx))
202 }
203
204 pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
212 self.with_mut(|s, _| {
213 if s.can_send() {
214 Poll::Ready(())
215 } else {
216 s.register_send_waker(cx.waker());
218 Poll::Pending
219 }
220 })
221 }
222
223 pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), SendError>
229 where
230 T: Into<UdpMetadata>,
231 {
232 let remote_endpoint: UdpMetadata = remote_endpoint.into();
233 poll_fn(move |cx| self.poll_send_to(buf, remote_endpoint, cx)).await
234 }
235
236 pub fn poll_send_to<T>(&self, buf: &[u8], remote_endpoint: T, cx: &mut Context<'_>) -> Poll<Result<(), SendError>>
245 where
246 T: Into<UdpMetadata>,
247 {
248 self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) {
249 Ok(()) => Poll::Ready(Ok(())),
251 Err(udp::SendError::BufferFull) => {
252 s.register_send_waker(cx.waker());
253 Poll::Pending
254 }
255 Err(udp::SendError::Unaddressable) => {
256 if s.endpoint().port == 0 {
258 Poll::Ready(Err(SendError::SocketNotBound))
259 } else {
260 Poll::Ready(Err(SendError::NoRoute))
261 }
262 }
263 })
264 }
265
266 pub async fn send_to_with<T, F, R>(&mut self, size: usize, remote_endpoint: T, f: F) -> Result<R, SendError>
273 where
274 T: Into<UdpMetadata> + Copy,
275 F: FnOnce(&mut [u8]) -> R,
276 {
277 let mut f = Some(f);
278 poll_fn(move |cx| {
279 self.with_mut(|s, _| {
280 match s.send(size, remote_endpoint) {
281 Ok(buffer) => Poll::Ready(Ok(unwrap!(f.take())(buffer))),
282 Err(udp::SendError::BufferFull) => {
283 s.register_send_waker(cx.waker());
284 Poll::Pending
285 }
286 Err(udp::SendError::Unaddressable) => {
287 if s.endpoint().port == 0 {
289 Poll::Ready(Err(SendError::SocketNotBound))
290 } else {
291 Poll::Ready(Err(SendError::NoRoute))
292 }
293 }
294 }
295 })
296 })
297 .await
298 }
299
300 pub fn flush(&mut self) -> impl Future<Output = ()> + '_ {
304 poll_fn(|cx| {
305 self.with_mut(|s, _| {
306 if s.send_queue() == 0 {
307 Poll::Ready(())
308 } else {
309 s.register_send_waker(cx.waker());
310 Poll::Pending
311 }
312 })
313 })
314 }
315
316 pub fn endpoint(&self) -> IpListenEndpoint {
318 self.with(|s, _| s.endpoint())
319 }
320
321 pub fn is_open(&self) -> bool {
324 self.with(|s, _| s.is_open())
325 }
326
327 pub fn close(&mut self) {
329 self.with_mut(|s, _| s.close())
330 }
331
332 pub fn may_send(&self) -> bool {
334 self.with(|s, _| s.can_send())
335 }
336
337 pub fn may_recv(&self) -> bool {
339 self.with(|s, _| s.can_recv())
340 }
341
342 pub fn packet_recv_capacity(&self) -> usize {
344 self.with(|s, _| s.packet_recv_capacity())
345 }
346
347 pub fn packet_send_capacity(&self) -> usize {
349 self.with(|s, _| s.packet_send_capacity())
350 }
351
352 pub fn payload_recv_capacity(&self) -> usize {
354 self.with(|s, _| s.payload_recv_capacity())
355 }
356
357 pub fn payload_send_capacity(&self) -> usize {
359 self.with(|s, _| s.payload_send_capacity())
360 }
361
362 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
364 self.with_mut(|s, _| s.set_hop_limit(hop_limit))
365 }
366}
367
368impl Drop for UdpSocket<'_> {
369 fn drop(&mut self) {
370 self.stack.with_mut(|i| i.sockets.remove(self.handle));
371 }
372}
373
374fn _assert_covariant<'a, 'b: 'a>(x: UdpSocket<'b>) -> UdpSocket<'a> {
375 x
376}