]> git.meshlink.io Git - utcp/blobdiff - utcp.c
Fix potential segmentation fault when the receive callback is not set.
[utcp] / utcp.c
diff --git a/utcp.c b/utcp.c
index 10d026efabf2679db6ca419c71131e5fecd7e582..f7c1859de2abe5ef05873bb1ef6565d449810e3d 100644 (file)
--- a/utcp.c
+++ b/utcp.c
@@ -74,6 +74,7 @@ static bool timespec_lt(const struct timespec *a, const struct timespec *b) {
 
 static void timespec_clear(struct timespec *a) {
        a->tv_sec = 0;
+       a->tv_nsec = 0;
 }
 
 static bool timespec_isset(const struct timespec *a) {
@@ -400,7 +401,7 @@ static void buffer_exit(struct buffer *buf) {
 }
 
 static uint32_t buffer_free(const struct buffer *buf) {
-       return buf->maxsize - buf->used;
+       return buf->maxsize > buf->used ? buf->maxsize - buf->used : 0;
 }
 
 // Connections are stored in a sorted list.
@@ -524,6 +525,9 @@ static struct utcp_connection *allocate_connection(struct utcp *utcp, uint16_t s
        c->snd.cwnd = (utcp->mss > 2190 ? 2 : utcp->mss > 1095 ? 3 : 4) * utcp->mss;
        c->snd.ssthresh = ~0;
        debug_cwnd(c);
+       c->srtt = 0;
+       c->rttvar = 0;
+       c->rto = START_RTO;
        c->utcp = utcp;
 
        // Add it to the sorted list of connections
@@ -549,36 +553,34 @@ static void update_rtt(struct utcp_connection *c, uint32_t rtt) {
                return;
        }
 
-       struct utcp *utcp = c->utcp;
-
-       if(!utcp->srtt) {
-               utcp->srtt = rtt;
-               utcp->rttvar = rtt / 2;
+       if(!c->srtt) {
+               c->srtt = rtt;
+               c->rttvar = rtt / 2;
        } else {
-               utcp->rttvar = (utcp->rttvar * 3 + absdiff(utcp->srtt, rtt)) / 4;
-               utcp->srtt = (utcp->srtt * 7 + rtt) / 8;
+               c->rttvar = (c->rttvar * 3 + absdiff(c->srtt, rtt)) / 4;
+               c->srtt = (c->srtt * 7 + rtt) / 8;
        }
 
-       utcp->rto = utcp->srtt + max(4 * utcp->rttvar, CLOCK_GRANULARITY);
+       c->rto = c->srtt + max(4 * c->rttvar, CLOCK_GRANULARITY);
 
-       if(utcp->rto > MAX_RTO) {
-               utcp->rto = MAX_RTO;
+       if(c->rto > MAX_RTO) {
+               c->rto = MAX_RTO;
        }
 
-       debug(c, "rtt %u srtt %u rttvar %u rto %u\n", rtt, utcp->srtt, utcp->rttvar, utcp->rto);
+       debug(c, "rtt %u srtt %u rttvar %u rto %u\n", rtt, c->srtt, c->rttvar, c->rto);
 }
 
 static void start_retransmit_timer(struct utcp_connection *c) {
        clock_gettime(UTCP_CLOCK, &c->rtrx_timeout);
 
-       uint32_t rto = c->utcp->rto;
+       uint32_t rto = c->rto;
 
        while(rto > USEC_PER_SEC) {
                c->rtrx_timeout.tv_sec++;
                rto -= USEC_PER_SEC;
        }
 
-       c->rtrx_timeout.tv_nsec += c->utcp->rto * 1000;
+       c->rtrx_timeout.tv_nsec += rto * 1000;
 
        if(c->rtrx_timeout.tv_nsec >= NSEC_PER_SEC) {
                c->rtrx_timeout.tv_nsec -= NSEC_PER_SEC;
@@ -814,7 +816,6 @@ ssize_t utcp_send(struct utcp_connection *c, const void *data, size_t len) {
        if(!is_reliable(c)) {
                c->snd.una = c->snd.nxt = c->snd.last;
                buffer_discard(&c->sndbuf, c->sndbuf.used);
-               c->do_poll = true;
        }
 
        if(is_reliable(c) && !timespec_isset(&c->rtrx_timeout)) {
@@ -973,10 +974,10 @@ static void retransmit(struct utcp_connection *c) {
        }
 
        start_retransmit_timer(c);
-       utcp->rto *= 2;
+       c->rto *= 2;
 
-       if(utcp->rto > MAX_RTO) {
-               utcp->rto = MAX_RTO;
+       if(c->rto > MAX_RTO) {
+               c->rto = MAX_RTO;
        }
 
        c->rtt_start.tv_sec = 0; // invalidate RTT timer
@@ -1044,8 +1045,14 @@ static void handle_out_of_order(struct utcp_connection *c, uint32_t offset, cons
        // Packet loss or reordering occured. Store the data in the buffer.
        ssize_t rxd = buffer_put_at(&c->rcvbuf, offset, data, len);
 
-       if(rxd < 0 || (size_t)rxd < len) {
-               abort();
+       if(rxd <= 0) {
+               debug(c, "packet outside receive buffer, dropping\n");
+               return;
+       }
+
+       if((size_t)rxd < len) {
+               debug(c, "packet partially outside receive buffer\n");
+               len = rxd;
        }
 
        // Make note of where we put it.
@@ -1106,11 +1113,14 @@ static void handle_in_order(struct utcp_connection *c, const void *data, size_t
                        size_t offset = len;
                        len = c->sacks[0].offset + c->sacks[0].len;
                        size_t remainder = len - offset;
-                       ssize_t rxd = buffer_call(&c->rcvbuf, c->recv, c, offset, remainder);
 
-                       if(rxd != (ssize_t)remainder) {
-                               // TODO: handle the application not accepting all data.
-                               abort();
+                       if(c->recv) {
+                               ssize_t rxd = buffer_call(&c->rcvbuf, c->recv, c, offset, remainder);
+
+                               if(rxd != (ssize_t)remainder) {
+                                       // TODO: handle the application not accepting all data.
+                                       abort();
+                               }
                        }
                }
        }
@@ -1125,7 +1135,10 @@ static void handle_in_order(struct utcp_connection *c, const void *data, size_t
 static void handle_unreliable(struct utcp_connection *c, const struct hdr *hdr, const void *data, size_t len) {
        // Fast path for unfragmented packets
        if(!hdr->wnd && !(hdr->ctl & MF)) {
-               c->recv(c, data, len);
+               if(c->recv) {
+                       c->recv(c, data, len);
+               }
+
                c->rcv.nxt = hdr->seq + len;
                return;
        }
@@ -1152,7 +1165,7 @@ static void handle_unreliable(struct utcp_connection *c, const struct hdr *hdr,
        }
 
        // Send the packet if it's the final fragment
-       if(!(hdr->ctl & MF)) {
+       if(!(hdr->ctl & MF) && c->recv) {
                buffer_call(&c->rcvbuf, c->recv, c, 0, hdr->wnd + len);
        }
 
@@ -1167,10 +1180,6 @@ static void handle_incoming_data(struct utcp_connection *c, const struct hdr *hd
 
        uint32_t offset = seqdiff(hdr->seq, c->rcv.nxt);
 
-       if(offset + len > c->rcvbuf.maxsize) {
-               abort();
-       }
-
        if(offset) {
                handle_out_of_order(c, offset, data, len);
        } else {
@@ -1590,7 +1599,10 @@ synack:
 
                if(data_acked) {
                        buffer_discard(&c->sndbuf, data_acked);
-                       c->do_poll = true;
+
+                       if(is_reliable(c)) {
+                               c->do_poll = true;
+                       }
                }
 
                // Also advance snd.nxt if possible
@@ -2173,7 +2185,6 @@ struct utcp *utcp_init(utcp_accept_t accept, utcp_pre_accept_t pre_accept, utcp_
        utcp->priv = priv;
        utcp_set_mtu(utcp, DEFAULT_MTU);
        utcp->timeout = DEFAULT_USER_TIMEOUT; // sec
-       utcp->rto = START_RTO; // usec
 
        return utcp;
 }
@@ -2266,10 +2277,10 @@ void utcp_reset_timers(struct utcp *utcp) {
                }
 
                c->rtt_start.tv_sec = 0;
-       }
 
-       if(utcp->rto > START_RTO) {
-               utcp->rto = START_RTO;
+               if(c->rto > START_RTO) {
+                       c->rto = START_RTO;
+               }
        }
 }
 
@@ -2315,7 +2326,7 @@ void utcp_set_sndbuf(struct utcp_connection *c, size_t size) {
                c->sndbuf.maxsize = -1;
        }
 
-       c->do_poll = buffer_free(&c->sndbuf);
+       c->do_poll = is_reliable(c) && buffer_free(&c->sndbuf);
 }
 
 size_t utcp_get_rcvbuf(struct utcp_connection *c) {
@@ -2383,7 +2394,7 @@ void utcp_set_recv_cb(struct utcp_connection *c, utcp_recv_t recv) {
 void utcp_set_poll_cb(struct utcp_connection *c, utcp_poll_t poll) {
        if(c) {
                c->poll = poll;
-               c->do_poll = buffer_free(&c->sndbuf);
+               c->do_poll = is_reliable(c) && buffer_free(&c->sndbuf);
        }
 }
 
@@ -2436,11 +2447,11 @@ void utcp_offline(struct utcp *utcp, bool offline) {
                        }
 
                        utcp->connections[i]->rtt_start.tv_sec = 0;
-               }
-       }
 
-       if(!offline && utcp->rto > START_RTO) {
-               utcp->rto = START_RTO;
+                       if(c->rto > START_RTO) {
+                               c->rto = START_RTO;
+                       }
+               }
        }
 }