From ea7be103c14b4469b10a96035464f5f3a41167f0 Mon Sep 17 00:00:00 2001 From: Guus Sliepen Date: Thu, 16 Apr 2020 01:48:33 +0200 Subject: [PATCH] Add support for sendmmsg(). Try to batch consecutive packets sent to the same filedescriptor. To avoid calling malloc() or memcpy(), we tell SPTPS to store the encrypted packet into the next available pre-allocated buffer. --- configure.ac | 2 +- src/meshlink.c | 8 +++++ src/net.h | 5 ++- src/net_packet.c | 82 ++++++++++++++++++++++++++++++++++++++++++++++-- src/sptps.c | 35 ++++++++++++++------- src/sptps.h | 6 +++- 6 files changed, 120 insertions(+), 18 deletions(-) diff --git a/configure.ac b/configure.ac index 45e83b94..8bcb9f1c 100644 --- a/configure.ac +++ b/configure.ac @@ -126,7 +126,7 @@ MeshLink_ATTRIBUTE(__warn_unused_result__) dnl Checks for library functions. AC_TYPE_SIGNAL -AC_CHECK_FUNCS([asprintf fchmod fork gettimeofday random pselect recvmmsg select setns strdup usleep getifaddrs freeifaddrs], +AC_CHECK_FUNCS([asprintf fchmod fork gettimeofday random pselect recvmmsg select sendmmsg setns strdup usleep getifaddrs freeifaddrs], [], [], [#include "$srcdir/src/have.h"] ) diff --git a/src/meshlink.c b/src/meshlink.c index f14c4940..283a9492 100644 --- a/src/meshlink.c +++ b/src/meshlink.c @@ -975,6 +975,10 @@ static struct timespec idle(event_loop_t *loop, void *data) { meshlink_handle_t *mesh = data; struct timespec t, tmin = {3600, 0}; +#ifdef HAVE_SENDMMSG + flush_mmsg(mesh); +#endif + for splay_each(node_t, n, mesh->nodes) { if(!n->utcp) { continue; @@ -1652,7 +1656,9 @@ bool meshlink_start(meshlink_handle_t *mesh) { return false; } +#if defined(HAVE_RECVMMSG) || defined(HAVE_SENDMMSG) init_mmsg(mesh); +#endif init_outgoings(mesh); init_adns(mesh); @@ -1727,7 +1733,9 @@ void meshlink_stop(meshlink_handle_t *mesh) { exit_adns(mesh); exit_outgoings(mesh); +#if defined(HAVE_RECVMMSG) || defined(HAVE_SENDMMSG) exit_mmsg(mesh); +#endif // Ensure we are considered unreachable if(mesh->nodes) { diff --git a/src/net.h b/src/net.h index a198dfe4..9b994425 100644 --- a/src/net.h +++ b/src/net.h @@ -82,10 +82,13 @@ typedef struct outgoing_t { extern void init_outgoings(struct meshlink_handle *mesh); extern void exit_outgoings(struct meshlink_handle *mesh); -#ifdef HAVE_RECVMMSG +#if defined(HAVE_RECVMMSG) || defined(HAVE_SENDMMSG) extern void init_mmsg(struct meshlink_handle *mesh); extern void exit_mmsg(struct meshlink_handle *mesh); #endif +#ifdef HAVE_SENDMMSG +extern void flush_mmsg(struct meshlink_handle *mesh); +#endif extern void retry_outgoing(struct meshlink_handle *mesh, outgoing_t *); extern void handle_incoming_vpn_data(struct event_loop_t *loop, void *, int); diff --git a/src/net_packet.c b/src/net_packet.c index 4d3b2002..6a1d28a3 100644 --- a/src/net_packet.c +++ b/src/net_packet.c @@ -226,10 +226,13 @@ static void mtu_probe_h(meshlink_handle_t *mesh, node_t *n, vpn_packet_t *packet /* VPN packet I/O */ -#ifdef HAVE_RECVMMSG -#define MAX_MMSG 16 +#if defined(HAVE_RECVMMSG) || defined(HAVE_SENDMMSG) +#define MAX_MMSG 32 struct mmsgs { + int count; + int offset; + int fd; struct mmsghdr hdrs[MAX_MMSG]; struct iovec iovs[MAX_MMSG]; sockaddr_t addrs[MAX_MMSG]; @@ -237,6 +240,9 @@ struct mmsgs { }; static void init_mmsg_array(struct mmsgs *mmsgs) { + mmsgs->count = 0; + mmsgs->offset = 0; + for(int i = 0; i < MAX_MMSG; i++) { mmsgs->hdrs[i].msg_hdr.msg_name = &mmsgs->addrs[i]; mmsgs->hdrs[i].msg_hdr.msg_namelen = sizeof(mmsgs->addrs[i]); @@ -253,7 +259,6 @@ void init_mmsg(meshlink_handle_t *mesh) { init_mmsg_array(mesh->in_mmsgs); init_mmsg_array(mesh->out_mmsgs); - } void exit_mmsg(meshlink_handle_t *mesh) { @@ -265,6 +270,64 @@ void exit_mmsg(meshlink_handle_t *mesh) { } #endif +#ifdef HAVE_SENDMMSG +void flush_mmsg(meshlink_handle_t *mesh) { + struct mmsgs *mmsgs = mesh->out_mmsgs; + + int todo = mmsgs->count - mmsgs->offset; + int offset = mmsgs->offset; + + while(todo) { + int result = sendmmsg(mmsgs->fd, mmsgs->hdrs + offset, todo, 0); + + if(result <= 0) { + logger(mesh, MESHLINK_WARNING, "Error sending packet: %s", sockstrerror(errno)); + break; + } + + todo -= result; + offset += result; + } + + mmsgs->count = 0; + mmsgs->offset = 0; +} + +static vpn_packet_t *get_next_mmsg_pkt(meshlink_handle_t *mesh) { + struct mmsgs *mmsgs = mesh->out_mmsgs; + + if(mmsgs->count >= MAX_MMSG) { + flush_mmsg(mesh); + } + + return &mmsgs->pkts[mmsgs->count]; +} + +static void add_mmsg(meshlink_handle_t *mesh, int fd, const void *data, size_t len, const struct sockaddr *sa, socklen_t salen) { + struct mmsgs *mmsgs = mesh->out_mmsgs; + assert(mmsgs->count < MAX_MMSG); + assert(data == get_next_mmsg_pkt(mesh)->data); + + if(mmsgs->fd != fd) { + // Flush all packets from the previous fd + int oldcount = mmsgs->count; + flush_mmsg(mesh); + + // Adjust offset and count to start the next flush with this packet + mmsgs->fd = fd; + mmsgs->count = oldcount; + mmsgs->offset = oldcount; + } + + assert(mmsgs->iovs[mmsgs->count].iov_base == mmsgs->pkts[mmsgs->count].data); + assert(mmsgs->hdrs[mmsgs->count].msg_hdr.msg_iovlen == 1); + mmsgs->iovs[mmsgs->count].iov_len = len; + memcpy(mmsgs->hdrs[mmsgs->count].msg_hdr.msg_name, sa, salen); + mmsgs->hdrs[mmsgs->count].msg_hdr.msg_namelen = salen; + mmsgs->count++; +} +#endif + static void receive_packet(meshlink_handle_t *mesh, node_t *n, vpn_packet_t *packet) { logger(mesh, MESHLINK_DEBUG, "Received packet of %d bytes from %s", packet->len, n->name); @@ -318,6 +381,10 @@ static void send_sptps_packet(meshlink_handle_t *mesh, node_t *n, vpn_packet_t * uint8_t type = 0; +#ifdef HAVE_SENDMMSG + sptps_set_send_buffer(&n->sptps, get_next_mmsg_pkt(mesh)->data, MAXSIZE); +#endif + // If it's a probe, send it immediately without trying to compress it. if(origpkt->probe) { sptps_send_record(&n->sptps, PKT_PROBE, origpkt->data, origpkt->len); @@ -441,6 +508,15 @@ bool send_sptps_data(void *handle, uint8_t type, const void *data, size_t len) { choose_udp_address(mesh, to, &sa, &sock); } +#ifdef HAVE_SENDMMSG + + if(type != PKT_PROBE) { + add_mmsg(mesh, mesh->listen_socket[sock].udp.fd, data, len, &sa->sa, SALEN(sa->sa)); + return true; + } + +#endif + if(sendto(mesh->listen_socket[sock].udp.fd, data, len, 0, &sa->sa, SALEN(sa->sa)) < 0 && !sockwouldblock(sockerrno)) { if(sockmsgsize(sockerrno)) { if(to->maxmtu >= len) { diff --git a/src/sptps.c b/src/sptps.c index ed1f67ff..8d3dd8b1 100644 --- a/src/sptps.c +++ b/src/sptps.c @@ -95,7 +95,12 @@ static void warning(sptps_t *s, const char *format, ...) { // Send a record (datagram version, accepts all record types, handles encryption and authentication). static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) { - char buffer[len + 21UL]; + char *buffer = s->outbuf; + char local_buffer[len + 21UL]; + + if(!buffer || (len + 21UL) > s->outbuflen) { + buffer = local_buffer; + } // Create header with sequence number, length and record type uint32_t seqno = s->outseqno++; @@ -585,21 +590,21 @@ bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { while(len) { // First read the 2 length bytes. - if(s->buflen < 2) { - size_t toread = 2 - s->buflen; + if(s->inbuflen < 2) { + size_t toread = 2 - s->inbuflen; if(toread > len) { toread = len; } - memcpy(s->inbuf + s->buflen, ptr, toread); + memcpy(s->inbuf + s->inbuflen, ptr, toread); - s->buflen += toread; + s->inbuflen += toread; len -= toread; ptr += toread; // Exit early if we don't have the full length. - if(s->buflen < 2) { + if(s->inbuflen < 2) { return true; } @@ -622,19 +627,19 @@ bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { } // Read up to the end of the record. - size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen; + size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->inbuflen; if(toread > len) { toread = len; } - memcpy(s->inbuf + s->buflen, ptr, toread); - s->buflen += toread; + memcpy(s->inbuf + s->inbuflen, ptr, toread); + s->inbuflen += toread; len -= toread; ptr += toread; // If we don't have a whole record, exit. - if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL)) { + if(s->inbuflen < s->reclen + (s->instate ? 19UL : 3UL)) { return true; } @@ -670,7 +675,7 @@ bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { return error(s, EIO, "Invalid record type %d", type); } - s->buflen = 0; + s->inbuflen = 0; } return true; @@ -721,7 +726,7 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_ return error(s, errno, strerror(errno)); } - s->buflen = 0; + s->inbuflen = 0; } memcpy(s->label, label, labellen); @@ -752,3 +757,9 @@ bool sptps_stop(sptps_t *s) { memset(s, 0, sizeof(*s)); return true; } + +// Set the buffer to use for outgoing packets. +void sptps_set_send_buffer(sptps_t *s, void *data, size_t len) { + s->outbuf = data; + s->outbuflen = len; +} diff --git a/src/sptps.h b/src/sptps.h index c91d3882..5d567116 100644 --- a/src/sptps.h +++ b/src/sptps.h @@ -53,7 +53,7 @@ typedef struct sptps { // Main member variables char *inbuf; - size_t buflen; + size_t inbuflen; chacha_poly1305_ctx_t *incipher; uint32_t replaywin; @@ -61,6 +61,9 @@ typedef struct sptps { uint32_t received; uint16_t reclen; + char *outbuf; + size_t outbuflen; + chacha_poly1305_ctx_t *outcipher; uint32_t outseqno; @@ -96,5 +99,6 @@ extern bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16 extern bool sptps_receive_data(sptps_t *s, const void *data, size_t len) __attribute__((__warn_unused_result__)); extern bool sptps_force_kex(sptps_t *s) __attribute__((__warn_unused_result__)); extern bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len) __attribute__((__warn_unused_result__)); +extern void sptps_set_send_buffer(sptps_t *s, void *data, size_t len); #endif -- 2.39.5