Skip to content

Commit 91e6085

Browse files
committed
Add support for sendmmsg(2) on linux
https://man7.org/linux/man-pages/man2/sendmmsg.2.html Partially addresses bytecodealliance#1156. Signed-off-by: Colin Marc <[email protected]>
1 parent c16dcc7 commit 91e6085

File tree

6 files changed

+293
-7
lines changed

6 files changed

+293
-7
lines changed

Diff for: src/backend/libc/net/syscalls.rs

+21
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33
use super::read_sockaddr::initialize_family_to_unspec;
44
use super::send_recv::{RecvFlags, SendFlags};
55
use crate::backend::c;
6+
#[cfg(target_os = "linux")]
7+
use crate::backend::conv::ret_u32;
68
use crate::backend::conv::{borrowed_fd, ret, ret_owned_fd, ret_send_recv, send_recv_len};
79
use crate::fd::{BorrowedFd, OwnedFd};
810
use crate::io;
11+
#[cfg(target_os = "linux")]
12+
use crate::net::MMsgHdr;
913
use crate::net::SocketAddrBuf;
1014
use crate::net::{
1115
addr::SocketAddrArg, AddressFamily, Protocol, Shutdown, SocketAddrAny, SocketFlags, SocketType,
@@ -231,6 +235,23 @@ pub(crate) fn sendmsg_addr(
231235
})
232236
}
233237

238+
#[cfg(target_os = "linux")]
239+
pub(crate) fn sendmmsg(
240+
sockfd: BorrowedFd<'_>,
241+
msgs: &mut [MMsgHdr<'_>],
242+
flags: SendFlags,
243+
) -> io::Result<usize> {
244+
unsafe {
245+
ret_u32(c::sendmmsg(
246+
borrowed_fd(sockfd),
247+
msgs.as_mut_ptr() as _,
248+
msgs.len().try_into().unwrap_or(c::c_uint::MAX),
249+
bitflags_bits!(flags),
250+
))
251+
.map(|ret| ret as usize)
252+
}
253+
}
254+
234255
#[cfg(not(any(
235256
apple,
236257
windows,

Diff for: src/backend/linux_raw/c.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ pub(crate) use linux_raw_sys::{
7676
general::{O_CLOEXEC as SOCK_CLOEXEC, O_NONBLOCK as SOCK_NONBLOCK},
7777
if_ether::*,
7878
net::{
79-
linger, msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet,
8079
__kernel_sa_family_t as sa_family_t, __kernel_sockaddr_storage as sockaddr_storage,
81-
cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, AF_APPLETALK,
82-
AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, AF_ECONET,
83-
AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, AF_LLC,
84-
AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE,
80+
cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, linger, mmsghdr,
81+
msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet,
82+
AF_APPLETALK, AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN,
83+
AF_ECONET, AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY,
84+
AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE,
8585
AF_RXRPC, AF_SECURITY, AF_SNA, AF_TIPC, AF_UNIX, AF_UNSPEC, AF_WANPIPE, AF_X25, AF_XDP,
8686
IP6T_SO_ORIGINAL_DST, IPPROTO_FRAGMENT, IPPROTO_ICMPV6, IPPROTO_MH, IPPROTO_ROUTING,
8787
IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_FREEBIND, IPV6_MULTICAST_HOPS,

Diff for: src/backend/linux_raw/net/syscalls.rs

+30-2
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@ use super::msghdr::{with_msghdr, with_noaddr_msghdr, with_recv_msghdr};
99
use super::read_sockaddr::initialize_family_to_unspec;
1010
use super::send_recv::{RecvFlags, ReturnFlags, SendFlags};
1111
use crate::backend::c;
12+
#[cfg(target_os = "linux")]
13+
use crate::backend::conv::slice_mut;
1214
use crate::backend::conv::{
1315
by_mut, by_ref, c_int, c_uint, pass_usize, ret, ret_owned_fd, ret_usize, size_of, slice,
1416
socklen_t, zero,
1517
};
1618
use crate::backend::reg::raw_arg;
1719
use crate::fd::{BorrowedFd, OwnedFd};
1820
use crate::io::{self, IoSlice, IoSliceMut};
21+
#[cfg(target_os = "linux")]
22+
use crate::net::MMsgHdr;
1923
use crate::net::SocketAddrBuf;
2024
use crate::net::{
2125
addr::SocketAddrArg, AddressFamily, Protocol, RecvAncillaryBuffer, RecvMsg,
@@ -28,8 +32,8 @@ use {
2832
crate::backend::reg::{ArgReg, SocketArg},
2933
linux_raw_sys::net::{
3034
SYS_ACCEPT, SYS_ACCEPT4, SYS_BIND, SYS_CONNECT, SYS_GETPEERNAME, SYS_GETSOCKNAME,
31-
SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMSG, SYS_SENDTO,
32-
SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR,
35+
SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMMSG, SYS_SENDMSG,
36+
SYS_SENDTO, SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR,
3337
},
3438
};
3539

@@ -331,6 +335,30 @@ pub(crate) fn sendmsg_addr(
331335
})
332336
}
333337

338+
#[cfg(target_os = "linux")]
339+
#[inline]
340+
pub(crate) fn sendmmsg(
341+
sockfd: BorrowedFd<'_>,
342+
msgs: &mut [MMsgHdr<'_>],
343+
flags: SendFlags,
344+
) -> io::Result<usize> {
345+
let (msgs, len) = slice_mut(msgs);
346+
347+
#[cfg(not(target_arch = "x86"))]
348+
let result = unsafe { ret_usize(syscall!(__NR_sendmmsg, sockfd, msgs, len, flags)) };
349+
350+
#[cfg(target_arch = "x86")]
351+
let result = unsafe {
352+
ret_usize(syscall!(
353+
__NR_socketcall,
354+
x86_sys(SYS_SENDMMSG),
355+
slice_just_addr::<ArgReg<'_, SocketArg>, _>(&[sockfd.into(), msgs, len, flags.into()])
356+
))
357+
};
358+
359+
result
360+
}
361+
334362
#[inline]
335363
pub(crate) fn shutdown(fd: BorrowedFd<'_>, how: Shutdown) -> io::Result<()> {
336364
#[cfg(not(target_arch = "x86"))]

Diff for: src/net/send_recv/msg.rs

+60
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
33
#![allow(unsafe_code)]
44

5+
#[cfg(target_os = "linux")]
6+
use crate::backend::net::msghdr::{with_msghdr, with_noaddr_msghdr};
57
use crate::backend::{self, c};
68
use crate::fd::{AsFd, BorrowedFd, OwnedFd};
79
use crate::io::{self, IoSlice, IoSliceMut};
@@ -591,6 +593,48 @@ impl<'buf> Iterator for AncillaryDrain<'buf> {
591593

592594
impl FusedIterator for AncillaryDrain<'_> {}
593595

596+
/// An ABI-compatible wrapper for `mmsghdr`, for sending multiple messages with
597+
/// [sendmmsg].
598+
#[cfg(target_os = "linux")]
599+
#[repr(transparent)]
600+
pub struct MMsgHdr<'a> {
601+
raw: c::mmsghdr,
602+
_phantom: PhantomData<&'a mut ()>,
603+
}
604+
605+
#[cfg(target_os = "linux")]
606+
impl<'a> MMsgHdr<'a> {
607+
/// Constructs a new message with no destination address.
608+
pub fn new(iov: &'a [IoSlice<'_>], control: &'a mut SendAncillaryBuffer<'_, '_, '_>) -> Self {
609+
with_noaddr_msghdr(iov, control, Self::wrap)
610+
}
611+
612+
/// Constructs a new message to a specific address.
613+
pub fn new_with_addr(
614+
addr: &'a SocketAddrAny,
615+
iov: &'a [IoSlice<'_>],
616+
control: &'a mut SendAncillaryBuffer<'_, '_, '_>,
617+
) -> MMsgHdr<'a> {
618+
with_msghdr(addr, iov, control, Self::wrap)
619+
}
620+
621+
fn wrap(msg_hdr: c::msghdr) -> Self {
622+
Self {
623+
raw: c::mmsghdr {
624+
msg_hdr,
625+
msg_len: 0,
626+
},
627+
_phantom: PhantomData,
628+
}
629+
}
630+
631+
/// Returns the number of bytes sent. This will return 0 until after a
632+
/// successful call to [sendmmsg].
633+
pub fn bytes_sent(&self) -> usize {
634+
self.raw.msg_len as _
635+
}
636+
}
637+
594638
/// `sendmsg(msghdr)`—Sends a message on a socket.
595639
///
596640
/// This function is for use on connected sockets, as it doesn't have
@@ -656,6 +700,22 @@ pub fn sendmsg_addr(
656700
backend::net::syscalls::sendmsg_addr(socket.as_fd(), addr, iov, control, flags)
657701
}
658702

703+
/// `sendmmsg(msghdr)`—Sends multiple messages on a socket.
704+
///
705+
/// # References
706+
/// - [Linux]
707+
///
708+
/// [Linux]: https://man7.org/linux/man-pages/man2/sendmmsg.2.html
709+
#[inline]
710+
#[cfg(target_os = "linux")]
711+
pub fn sendmmsg(
712+
socket: impl AsFd,
713+
msgs: &mut [MMsgHdr<'_>],
714+
flags: SendFlags,
715+
) -> io::Result<usize> {
716+
backend::net::syscalls::sendmmsg(socket.as_fd(), msgs, flags)
717+
}
718+
659719
/// `recvmsg(msghdr)`—Receives a message from a socket.
660720
///
661721
/// # References

Diff for: tests/net/v4.rs

+86
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,89 @@ fn test_v4_msg() {
194194
client.join().unwrap();
195195
server.join().unwrap();
196196
}
197+
198+
#[test]
199+
#[cfg(target_os = "linux")]
200+
fn test_v4_sendmmsg() {
201+
crate::init();
202+
203+
use std::net::TcpStream;
204+
205+
use rustix::io::IoSlice;
206+
use rustix::net::{sendmmsg, MMsgHdr};
207+
208+
fn server(ready: Arc<(Mutex<u16>, Condvar)>) {
209+
let connection_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap();
210+
211+
let name = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0);
212+
bind(&connection_socket, &name).unwrap();
213+
214+
let who = getsockname(&connection_socket).unwrap();
215+
let who = SocketAddrV4::try_from(who).unwrap();
216+
217+
listen(&connection_socket, 1).unwrap();
218+
219+
{
220+
let (lock, cvar) = &*ready;
221+
let mut port = lock.lock().unwrap();
222+
*port = who.port();
223+
cvar.notify_all();
224+
}
225+
226+
let mut buffer = vec![0; 13];
227+
let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into();
228+
229+
std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap();
230+
assert_eq!(String::from_utf8_lossy(&buffer), "hello...world");
231+
}
232+
233+
fn client(ready: Arc<(Mutex<u16>, Condvar)>) {
234+
let port = {
235+
let (lock, cvar) = &*ready;
236+
let mut port = lock.lock().unwrap();
237+
while *port == 0 {
238+
port = cvar.wait(port).unwrap();
239+
}
240+
*port
241+
};
242+
243+
let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port);
244+
let data_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap();
245+
connect(&data_socket, &addr).unwrap();
246+
247+
let mut off = 0;
248+
while off < 2 {
249+
let sent = sendmmsg(
250+
&data_socket,
251+
&mut [
252+
MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()),
253+
MMsgHdr::new(&[IoSlice::new(b"...world")], &mut Default::default()),
254+
][off..],
255+
SendFlags::empty(),
256+
)
257+
.unwrap();
258+
259+
off += sent;
260+
}
261+
}
262+
263+
let ready = Arc::new((Mutex::new(0_u16), Condvar::new()));
264+
let ready_clone = Arc::clone(&ready);
265+
266+
let server = thread::Builder::new()
267+
.name("server".to_string())
268+
.spawn(move || {
269+
server(ready);
270+
})
271+
.unwrap();
272+
273+
let client = thread::Builder::new()
274+
.name("client".to_string())
275+
.spawn(move || {
276+
client(ready_clone);
277+
})
278+
.unwrap();
279+
280+
client.join().unwrap();
281+
server.join().unwrap();
282+
}

Diff for: tests/net/v6.rs

+91
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,94 @@ fn test_v6_msg() {
193193
client.join().unwrap();
194194
server.join().unwrap();
195195
}
196+
197+
#[test]
198+
#[cfg(target_os = "linux")]
199+
fn test_v6_sendmmsg() {
200+
crate::init();
201+
202+
use std::net::TcpStream;
203+
204+
use rustix::io::IoSlice;
205+
use rustix::net::addr::SocketAddrArg as _;
206+
use rustix::net::{sendmmsg, MMsgHdr};
207+
208+
fn server(ready: Arc<(Mutex<u16>, Condvar)>) {
209+
let connection_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap();
210+
211+
let name = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 0, 0, 0);
212+
bind(&connection_socket, &name).unwrap();
213+
214+
let who = getsockname(&connection_socket).unwrap();
215+
let who = SocketAddrV6::try_from(who).unwrap();
216+
217+
listen(&connection_socket, 1).unwrap();
218+
219+
{
220+
let (lock, cvar) = &*ready;
221+
let mut port = lock.lock().unwrap();
222+
*port = who.port();
223+
cvar.notify_all();
224+
}
225+
226+
let mut buffer = vec![0; 13];
227+
let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into();
228+
229+
std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap();
230+
assert_eq!(String::from_utf8_lossy(&buffer), "hello...world");
231+
}
232+
233+
fn client(ready: Arc<(Mutex<u16>, Condvar)>) {
234+
let port = {
235+
let (lock, cvar) = &*ready;
236+
let mut port = lock.lock().unwrap();
237+
while *port == 0 {
238+
port = cvar.wait(port).unwrap();
239+
}
240+
*port
241+
};
242+
243+
let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), port, 0, 0);
244+
let data_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap();
245+
connect(&data_socket, &addr).unwrap();
246+
247+
let mut off = 0;
248+
while off < 2 {
249+
let sent = sendmmsg(
250+
&data_socket,
251+
&mut [
252+
MMsgHdr::new_with_addr(
253+
&addr.as_any(),
254+
&[IoSlice::new(b"hello")],
255+
&mut Default::default(),
256+
),
257+
MMsgHdr::new(&[IoSlice::new(b"...world")], &mut Default::default()),
258+
][off..],
259+
SendFlags::empty(),
260+
)
261+
.unwrap();
262+
263+
off += sent;
264+
}
265+
}
266+
267+
let ready = Arc::new((Mutex::new(0_u16), Condvar::new()));
268+
let ready_clone = Arc::clone(&ready);
269+
270+
let server = thread::Builder::new()
271+
.name("server".to_string())
272+
.spawn(move || {
273+
server(ready);
274+
})
275+
.unwrap();
276+
277+
let client = thread::Builder::new()
278+
.name("client".to_string())
279+
.spawn(move || {
280+
client(ready_clone);
281+
})
282+
.unwrap();
283+
284+
client.join().unwrap();
285+
server.join().unwrap();
286+
}

0 commit comments

Comments
 (0)