]> git.meshlink.io Git - utcp/blob - utcp.c
Fix free_connection() moving the wrong memory.
[utcp] / utcp.c
1 /*
2     utcp.c -- Userspace TCP
3     Copyright (C) 2014 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #define _GNU_SOURCE
21
22 #include <assert.h>
23 #include <errno.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <stdint.h>
27 #include <stdbool.h>
28 #include <string.h>
29 #include <unistd.h>
30 #include <sys/time.h>
31 #include <sys/socket.h>
32
33 #include "utcp_priv.h"
34
35 #ifndef EBADMSG
36 #define EBADMSG         104
37 #endif
38
39 #ifndef SHUT_RDWR
40 #define SHUT_RDWR 2
41 #endif
42
43 #ifdef poll
44 #undef poll
45 #endif
46
#ifndef timersub
// Fallback for platforms without timersub(): stores a - b in r,
// borrowing one second whenever the microsecond difference is negative.
#define timersub(a, b, r) do {\
	(r)->tv_sec = (a)->tv_sec - (b)->tv_sec;\
	(r)->tv_usec = (a)->tv_usec - (b)->tv_usec;\
	if((r)->tv_usec < 0) {\
		(r)->tv_sec--;\
		(r)->tv_usec += 1000000;\
	}\
} while (0)
#endif
55
56 #ifdef UTCP_DEBUG
57 #include <stdarg.h>
58
// Write a printf-style formatted debug message to stderr.
static void debug(const char *format, ...) {
	va_list args;

	va_start(args, format);
	vfprintf(stderr, format, args);
	va_end(args);
}
65
66 static void print_packet(struct utcp *utcp, const char *dir, const void *pkt, size_t len) {
67         struct hdr hdr;
68         if(len < sizeof hdr) {
69                 debug("%p %s: short packet (%zu bytes)\n", utcp, dir, len);
70                 return;
71         }
72
73         memcpy(&hdr, pkt, sizeof hdr);
74         fprintf (stderr, "%p %s: len=%zu, src=%u dst=%u seq=%u ack=%u wnd=%u ctl=", utcp, dir, len, hdr.src, hdr.dst, hdr.seq, hdr.ack, hdr.wnd);
75         if(hdr.ctl & SYN)
76                 debug("SYN");
77         if(hdr.ctl & RST)
78                 debug("RST");
79         if(hdr.ctl & FIN)
80                 debug("FIN");
81         if(hdr.ctl & ACK)
82                 debug("ACK");
83
84         if(len > sizeof hdr) {
85                 debug(" data=");
86                 for(int i = sizeof hdr; i < len; i++) {
87                         const char *data = pkt;
88                         debug("%c", data[i] >= 32 ? data[i] : '.');
89                 }
90         }
91
92         debug("\n");
93 }
94 #else
95 #define debug(...)
96 #define print_packet(...)
97 #endif
98
99 static void set_state(struct utcp_connection *c, enum state state) {
100         c->state = state;
101         if(state == ESTABLISHED)
102                 timerclear(&c->conn_timeout);
103         debug("%p new state: %s\n", c->utcp, strstate[state]);
104 }
105
106 static inline void list_connections(struct utcp *utcp) {
107         debug("%p has %d connections:\n", utcp, utcp->nconnections);
108         for(int i = 0; i < utcp->nconnections; i++)
109                 debug("  %u -> %u state %s\n", utcp->connections[i]->src, utcp->connections[i]->dst, strstate[utcp->connections[i]->state]);
110 }
111
// Signed difference a - b of two 32-bit sequence numbers.
// The subtraction is done modulo 2^32 and then reinterpreted as signed,
// so the result stays meaningful across sequence number wrap-around.
static int32_t seqdiff(uint32_t a, uint32_t b) {
	uint32_t delta = a - b;
	return delta;
}
115
116 // Connections are stored in a sorted list.
117 // This gives O(log(N)) lookup time, O(N log(N)) insertion time and O(N) deletion time.
118
119 static int compare(const void *va, const void *vb) {
120         assert(va && vb);
121
122         const struct utcp_connection *a = *(struct utcp_connection **)va;
123         const struct utcp_connection *b = *(struct utcp_connection **)vb;
124
125         assert(a && b);
126         assert(a->src && b->src);
127
128         int c = (int)a->src - (int)b->src;
129         if(c)
130                 return c;
131         c = (int)a->dst - (int)b->dst;
132         return c;
133 }
134
135 static struct utcp_connection *find_connection(const struct utcp *utcp, uint16_t src, uint16_t dst) {
136         if(!utcp->nconnections)
137                 return NULL;
138         struct utcp_connection key = {
139                 .src = src,
140                 .dst = dst,
141         }, *keyp = &key;
142         struct utcp_connection **match = bsearch(&keyp, utcp->connections, utcp->nconnections, sizeof *utcp->connections, compare);
143         return match ? *match : NULL;
144 }
145
146 static void free_connection(struct utcp_connection *c) {
147         struct utcp *utcp = c->utcp;
148         struct utcp_connection **cp = bsearch(&c, utcp->connections, utcp->nconnections, sizeof *utcp->connections, compare);
149
150         assert(cp);
151
152         int i = cp - utcp->connections;
153         memmove(cp, cp + 1, (utcp->nconnections - i - 1) * sizeof *cp);
154         utcp->nconnections--;
155
156         free(c->sndbuf);
157         free(c);
158 }
159
160 static struct utcp_connection *allocate_connection(struct utcp *utcp, uint16_t src, uint16_t dst) {
161         // Check whether this combination of src and dst is free
162
163         if(src) {
164                 if(find_connection(utcp, src, dst)) {
165                         errno = EADDRINUSE;
166                         return NULL;
167                 }
168         } else { // If src == 0, generate a random port number with the high bit set
169                 if(utcp->nconnections >= 32767) {
170                         errno = ENOMEM;
171                         return NULL;
172                 }
173                 src = rand() | 0x8000;
174                 while(find_connection(utcp, src, dst))
175                         src++;
176         }
177
178         // Allocate memory for the new connection
179
180         if(utcp->nconnections >= utcp->nallocated) {
181                 if(!utcp->nallocated)
182                         utcp->nallocated = 4;
183                 else
184                         utcp->nallocated *= 2;
185                 struct utcp_connection **new_array = realloc(utcp->connections, utcp->nallocated * sizeof *utcp->connections);
186                 if(!new_array)
187                         return NULL;
188                 utcp->connections = new_array;
189         }
190
191         struct utcp_connection *c = calloc(1, sizeof *c);
192         if(!c)
193                 return NULL;
194
195         c->sndbufsize = DEFAULT_SNDBUFSIZE;
196         c->maxsndbufsize = DEFAULT_MAXSNDBUFSIZE;
197         c->sndbuf = malloc(c->sndbufsize);
198         if(!c->sndbuf) {
199                 free(c);
200                 return NULL;
201         }
202
203         // Fill in the details
204
205         c->src = src;
206         c->dst = dst;
207         c->snd.iss = rand();
208         c->snd.una = c->snd.iss;
209         c->snd.nxt = c->snd.iss + 1;
210         c->rcv.wnd = utcp->mtu;
211         c->snd.last = c->snd.nxt;
212         c->snd.cwnd = utcp->mtu;
213         c->utcp = utcp;
214
215         // Add it to the sorted list of connections
216
217         utcp->connections[utcp->nconnections++] = c;
218         qsort(utcp->connections, utcp->nconnections, sizeof *utcp->connections, compare);
219
220         return c;
221 }
222
223 struct utcp_connection *utcp_connect(struct utcp *utcp, uint16_t dst, utcp_recv_t recv, void *priv) {
224         struct utcp_connection *c = allocate_connection(utcp, 0, dst);
225         if(!c)
226                 return NULL;
227
228         c->recv = recv;
229         c->priv = priv;
230
231         struct hdr hdr;
232
233         hdr.src = c->src;
234         hdr.dst = c->dst;
235         hdr.seq = c->snd.iss;
236         hdr.ack = 0;
237         hdr.wnd = c->rcv.wnd;
238         hdr.ctl = SYN;
239         hdr.aux = 0;
240
241         set_state(c, SYN_SENT);
242
243         print_packet(utcp, "send", &hdr, sizeof hdr);
244         utcp->send(utcp, &hdr, sizeof hdr);
245
246         gettimeofday(&c->conn_timeout, NULL);
247         c->conn_timeout.tv_sec += utcp->timeout;
248
249         return c;
250 }
251
252 void utcp_accept(struct utcp_connection *c, utcp_recv_t recv, void *priv) {
253         if(c->reapable || c->state != SYN_RECEIVED) {
254                 debug("Error: accept() called on invalid connection %p in state %s\n", c, strstate[c->state]);
255                 return;
256         }
257
258         debug("%p accepted, %p %p\n", c, recv, priv);
259         c->recv = recv;
260         c->priv = priv;
261         set_state(c, ESTABLISHED);
262 }
263
264 static void ack(struct utcp_connection *c, bool sendatleastone) {
265         int32_t left = seqdiff(c->snd.last, c->snd.nxt);
266         int32_t cwndleft = c->snd.cwnd - seqdiff(c->snd.nxt, c->snd.una);
267         char *data = c->sndbuf + seqdiff(c->snd.nxt, c->snd.una);
268
269         assert(left >= 0);
270
271         if(cwndleft <= 0)
272                 cwndleft = 0;
273
274         if(cwndleft < left)
275                 left = cwndleft;
276
277         if(!left && !sendatleastone)
278                 return;
279
280         struct {
281                 struct hdr hdr;
282                 char data[];
283         } *pkt;
284
285         pkt = malloc(sizeof pkt->hdr + c->utcp->mtu);
286         if(!pkt->data)
287                 return;
288
289         pkt->hdr.src = c->src;
290         pkt->hdr.dst = c->dst;
291         pkt->hdr.ack = c->rcv.nxt;
292         pkt->hdr.wnd = c->snd.wnd;
293         pkt->hdr.ctl = ACK;
294         pkt->hdr.aux = 0;
295
296         do {
297                 uint32_t seglen = left > c->utcp->mtu ? c->utcp->mtu : left;
298                 pkt->hdr.seq = c->snd.nxt;
299
300                 memcpy(pkt->data, data, seglen);
301
302                 c->snd.nxt += seglen;
303                 data += seglen;
304                 left -= seglen;
305
306                 if(c->state != ESTABLISHED && !left && seglen) {
307                         switch(c->state) {
308                         case FIN_WAIT_1:
309                         case CLOSING:
310                                 seglen--;
311                                 pkt->hdr.ctl |= FIN;
312                                 break;
313                         default:
314                                 break;
315                         }
316                 }
317
318                 print_packet(c->utcp, "send", pkt, sizeof pkt->hdr + seglen);
319                 c->utcp->send(c->utcp, pkt, sizeof pkt->hdr + seglen);
320         } while(left);
321
322         free(pkt);
323 }
324
325 ssize_t utcp_send(struct utcp_connection *c, const void *data, size_t len) {
326         if(c->reapable) {
327                 debug("Error: send() called on closed connection %p\n", c);
328                 errno = EBADF;
329                 return -1;
330         }
331
332         switch(c->state) {
333         case CLOSED:
334         case LISTEN:
335         case SYN_SENT:
336         case SYN_RECEIVED:
337                 debug("Error: send() called on unconnected connection %p\n", c);
338                 errno = ENOTCONN;
339                 return -1;
340         case ESTABLISHED:
341         case CLOSE_WAIT:
342                 break;
343         case FIN_WAIT_1:
344         case FIN_WAIT_2:
345         case CLOSING:
346         case LAST_ACK:
347         case TIME_WAIT:
348                 debug("Error: send() called on closing connection %p\n", c);
349                 errno = EPIPE;
350                 return -1;
351         }
352
353         // Add data to send buffer
354
355         if(!len)
356                 return 0;
357
358         if(!data) {
359                 errno = EFAULT;
360                 return -1;
361         }
362
363         uint32_t bufused = seqdiff(c->snd.nxt, c->snd.una);
364
365         /* Check our send buffer.
366          * - If it's big enough, just put the data in there.
367          * - If not, decide whether to enlarge if possible.
368          * - Cap len so it doesn't overflow our buffer.
369          */
370
371         if(len > c->sndbufsize - bufused && c->sndbufsize < c->maxsndbufsize) {
372                 uint32_t newbufsize;
373                 if(c->sndbufsize > c->maxsndbufsize / 2)
374                         newbufsize = c->maxsndbufsize;
375                 else
376                         newbufsize = c->sndbufsize * 2;
377                 if(bufused + len > newbufsize) {
378                         if(bufused + len > c->maxsndbufsize)
379                                 newbufsize = c->maxsndbufsize;
380                         else
381                                 newbufsize = bufused + len;
382                 }
383                 char *newbuf = realloc(c->sndbuf, newbufsize);
384                 if(newbuf) {
385                         c->sndbuf = newbuf;
386                         c->sndbufsize = newbufsize;
387                 }
388         }
389
390         if(len > c->sndbufsize - bufused)
391                 len = c->sndbufsize - bufused;
392
393         if(!len) {
394                 errno == EWOULDBLOCK;
395                 return 0;
396         }
397
398         memcpy(c->sndbuf + bufused, data, len);
399         c->snd.last += len;
400
401         ack(c, false);
402         return len;
403 }
404
405 static void swap_ports(struct hdr *hdr) {
406         uint16_t tmp = hdr->src;
407         hdr->src = hdr->dst;
408         hdr->dst = tmp;
409 }
410
411 ssize_t utcp_recv(struct utcp *utcp, const void *data, size_t len) {
412         if(!utcp) {
413                 errno = EFAULT;
414                 return -1;
415         }
416
417         if(!len)
418                 return 0;
419
420         if(!data) {
421                 errno = EFAULT;
422                 return -1;
423         }
424
425         print_packet(utcp, "recv", data, len);
426
427         // Drop packets smaller than the header
428
429         struct hdr hdr;
430         if(len < sizeof hdr) {
431                 errno = EBADMSG;
432                 return -1;
433         }
434
435         // Make a copy from the potentially unaligned data to a struct hdr
436
437         memcpy(&hdr, data, sizeof hdr);
438         data += sizeof hdr;
439         len -= sizeof hdr;
440
441         // Drop packets with an unknown CTL flag
442
443         if(hdr.ctl & ~(SYN | ACK | RST | FIN)) {
444                 errno = EBADMSG;
445                 return -1;
446         }
447
448         // Try to match the packet to an existing connection
449
450         struct utcp_connection *c = find_connection(utcp, hdr.dst, hdr.src);
451
452         // Is it for a new connection?
453
454         if(!c) {
455                 // Ignore RST packets
456
457                 if(hdr.ctl & RST)
458                         return 0;
459
460                 // Is it a SYN packet and are we LISTENing?
461
462                 if(hdr.ctl & SYN && !(hdr.ctl & ACK) && utcp->accept) {
463                         // If we don't want to accept it, send a RST back
464                         if((utcp->pre_accept && !utcp->pre_accept(utcp, hdr.dst))) {
465                                 len = 1;
466                                 goto reset;
467                         }
468
469                         // Try to allocate memory, otherwise send a RST back
470                         c = allocate_connection(utcp, hdr.dst, hdr.src);
471                         if(!c) {
472                                 len = 1;
473                                 goto reset;
474                         }
475
476                         // Return SYN+ACK, go to SYN_RECEIVED state
477                         c->snd.wnd = hdr.wnd;
478                         c->rcv.irs = hdr.seq;
479                         c->rcv.nxt = c->rcv.irs + 1;
480                         set_state(c, SYN_RECEIVED);
481
482                         hdr.dst = c->dst;
483                         hdr.src = c->src;
484                         hdr.ack = c->rcv.irs + 1;
485                         hdr.seq = c->snd.iss;
486                         hdr.ctl = SYN | ACK;
487                         print_packet(c->utcp, "send", &hdr, sizeof hdr);
488                         utcp->send(utcp, &hdr, sizeof hdr);
489                 } else {
490                         // No, we don't want your packets, send a RST back
491                         len = 1;
492                         goto reset;
493                 }
494
495                 return 0;
496         }
497
498         debug("%p state %s\n", c->utcp, strstate[c->state]);
499
500         // In case this is for a CLOSED connection, ignore the packet.
501         // TODO: make it so incoming packets can never match a CLOSED connection.
502
503         if(c->state == CLOSED)
504                 return 0;
505
506         // It is for an existing connection.
507
508         // 1. Drop invalid packets.
509
510         // 1a. Drop packets that should not happen in our current state.
511
512         switch(c->state) {
513         case SYN_SENT:
514         case SYN_RECEIVED:
515         case ESTABLISHED:
516         case FIN_WAIT_1:
517         case FIN_WAIT_2:
518         case CLOSE_WAIT:
519         case CLOSING:
520         case LAST_ACK:
521         case TIME_WAIT:
522                 break;
523         default:
524                 abort();
525         }
526
527         // 1b. Drop packets with a sequence number not in our receive window.
528
529         bool acceptable;
530
531         if(c->state == SYN_SENT)
532                 acceptable = true;
533
534         // TODO: handle packets overlapping c->rcv.nxt.
535 #if 0
536         // Only use this when accepting out-of-order packets.
537         else if(len == 0)
538                 if(c->rcv.wnd == 0)
539                         acceptable = hdr.seq == c->rcv.nxt;
540                 else
541                         acceptable = (seqdiff(hdr.seq, c->rcv.nxt) >= 0 && seqdiff(hdr.seq, c->rcv.nxt + c->rcv.wnd) < 0);
542         else
543                 if(c->rcv.wnd == 0)
544                         // We don't accept data when the receive window is zero.
545                         acceptable = false;
546                 else
547                         // Both start and end of packet must be within the receive window
548                         acceptable = (seqdiff(hdr.seq, c->rcv.nxt) >= 0 && seqdiff(hdr.seq, c->rcv.nxt + c->rcv.wnd) < 0)
549                                 || (seqdiff(hdr.seq + len + 1, c->rcv.nxt) >= 0 && seqdiff(hdr.seq + len - 1, c->rcv.nxt + c->rcv.wnd) < 0);
550 #else
551         if(c->state != SYN_SENT)
552                 acceptable = hdr.seq == c->rcv.nxt;
553 #endif
554
555         if(!acceptable) {
556                 debug("Packet not acceptable, %u  <= %u + %zu < %u\n", c->rcv.nxt, hdr.seq, len, c->rcv.nxt + c->rcv.wnd);
557                 // Ignore unacceptable RST packets.
558                 if(hdr.ctl & RST)
559                         return 0;
560                 // Otherwise, send an ACK back in the hope things improve.
561                 goto ack;
562         }
563
564         c->snd.wnd = hdr.wnd; // TODO: move below
565
566         // 1c. Drop packets with an invalid ACK.
567         // ackno should not roll back, and it should also not be bigger than snd.nxt.
568
569         if(hdr.ctl & ACK && (seqdiff(hdr.ack, c->snd.nxt) > 0 || seqdiff(hdr.ack, c->snd.una) < 0)) {
570                 debug("Packet ack seqno out of range, %u %u %u\n", hdr.ack, c->snd.una, c->snd.nxt);
571                 // Ignore unacceptable RST packets.
572                 if(hdr.ctl & RST)
573                         return 0;
574                 goto reset;
575         }
576
577         // 2. Handle RST packets
578
579         if(hdr.ctl & RST) {
580                 switch(c->state) {
581                 case SYN_SENT:
582                         if(!(hdr.ctl & ACK))
583                                 return 0;
584                         // The peer has refused our connection.
585                         set_state(c, CLOSED);
586                         errno = ECONNREFUSED;
587                         if(c->recv)
588                                 c->recv(c, NULL, 0);
589                         return 0;
590                 case SYN_RECEIVED:
591                         if(hdr.ctl & ACK)
592                                 return 0;
593                         // We haven't told the application about this connection yet. Silently delete.
594                         free_connection(c);
595                         return 0;
596                 case ESTABLISHED:
597                 case FIN_WAIT_1:
598                 case FIN_WAIT_2:
599                 case CLOSE_WAIT:
600                         if(hdr.ctl & ACK)
601                                 return 0;
602                         // The peer has aborted our connection.
603                         set_state(c, CLOSED);
604                         errno = ECONNRESET;
605                         if(c->recv)
606                                 c->recv(c, NULL, 0);
607                         return 0;
608                 case CLOSING:
609                 case LAST_ACK:
610                 case TIME_WAIT:
611                         if(hdr.ctl & ACK)
612                                 return 0;
613                         // As far as the application is concerned, the connection has already been closed.
614                         // If it has called utcp_close() already, we can immediately free this connection.
615                         if(c->reapable) {
616                                 free_connection(c);
617                                 return 0;
618                         }
619                         // Otherwise, immediately move to the CLOSED state.
620                         set_state(c, CLOSED);
621                         return 0;
622                 default:
623                         abort();
624                 }
625         }
626
627         // 3. Advance snd.una
628
629         uint32_t advanced = seqdiff(hdr.ack, c->snd.una);
630         uint32_t prevrcvnxt = c->rcv.nxt;
631
632         if(advanced) {
633                 int32_t data_acked = advanced;
634
635                 switch(c->state) {
636                         case SYN_SENT:
637                         case SYN_RECEIVED:
638                                 data_acked--;
639                                 break;
640                         // TODO: handle FIN as well.
641                         default:
642                                 break;
643                 }
644
645                 assert(data_acked >= 0);
646
647                 int32_t bufused = seqdiff(c->snd.last, c->snd.una);
648                 assert(data_acked <= bufused);
649
650                 // Make room in the send buffer.
651                 // TODO: try to avoid memmoving too much. Circular buffer?
652                 uint32_t left = bufused - data_acked;
653                 if(data_acked && left)
654                         memmove(c->sndbuf, c->sndbuf + data_acked, left);
655
656                 c->snd.una = hdr.ack;
657
658                 c->dupack = 0;
659                 c->snd.cwnd += utcp->mtu;
660                 if(c->snd.cwnd > c->maxsndbufsize)
661                         c->snd.cwnd = c->maxsndbufsize;
662
663                 // Check if we have sent a FIN that is now ACKed.
664                 switch(c->state) {
665                 case FIN_WAIT_1:
666                         if(c->snd.una == c->snd.last)
667                                 set_state(c, FIN_WAIT_2);
668                         break;
669                 case CLOSING:
670                         if(c->snd.una == c->snd.last) {
671                                 gettimeofday(&c->conn_timeout, NULL);
672                                 c->conn_timeout.tv_sec += 60;
673                                 set_state(c, TIME_WAIT);
674                         }
675                         break;
676                 default:
677                         break;
678                 }
679         } else {
680                 if(!len) {
681                         c->dupack++;
682                         if(c->dupack >= 3) {
683                                 debug("Triplicate ACK\n");
684                                 //TODO: Resend one packet and go to fast recovery mode. See RFC 6582.
685                                 //abort();
686                         }
687                 }
688         }
689
690         // 4. Update timers
691
692         if(advanced) {
693                 timerclear(&c->conn_timeout); // It will be set anew in utcp_timeout() if c->snd.una != c->snd.nxt.
694                 if(c->snd.una == c->snd.nxt)
695                         timerclear(&c->rtrx_timeout);
696         }
697
698         // 5. Process SYN stuff
699
700         if(hdr.ctl & SYN) {
701                 switch(c->state) {
702                 case SYN_SENT:
703                         // This is a SYNACK. It should always have ACKed the SYN.
704                         if(!advanced)
705                                 goto reset;
706                         c->rcv.irs = hdr.seq;
707                         c->rcv.nxt = hdr.seq;
708                         set_state(c, ESTABLISHED);
709                         // TODO: notify application of this somehow.
710                         break;
711                 case SYN_RECEIVED:
712                 case ESTABLISHED:
713                 case FIN_WAIT_1:
714                 case FIN_WAIT_2:
715                 case CLOSE_WAIT:
716                 case CLOSING:
717                 case LAST_ACK:
718                 case TIME_WAIT:
719                         // Ehm, no. We should never receive a second SYN.
720                         goto reset;
721                 default:
722                         abort();
723                 }
724
725                 // SYN counts as one sequence number
726                 c->rcv.nxt++;
727         }
728
729         // 6. Process new data
730
731         if(c->state == SYN_RECEIVED) {
732                 // This is the ACK after the SYNACK. It should always have ACKed the SYNACK.
733                 if(!advanced)
734                         goto reset;
735
736                 // Are we still LISTENing?
737                 if(utcp->accept)
738                         utcp->accept(c, c->src);
739
740                 if(c->state != ESTABLISHED) {
741                         set_state(c, CLOSED);
742                         c->reapable = true;
743                         goto reset;
744                 }
745         }
746
747         if(len) {
748                 switch(c->state) {
749                 case SYN_SENT:
750                 case SYN_RECEIVED:
751                         // This should never happen.
752                         abort();
753                 case ESTABLISHED:
754                 case FIN_WAIT_1:
755                 case FIN_WAIT_2:
756                         break;
757                 case CLOSE_WAIT:
758                 case CLOSING:
759                 case LAST_ACK:
760                 case TIME_WAIT:
761                         // Ehm no, We should never receive more data after a FIN.
762                         goto reset;
763                 default:
764                         abort();
765                 }
766
767                 ssize_t rxd;
768
769                 if(c->recv) {
770                         rxd = c->recv(c, data, len);
771                         if(rxd != len) {
772                                 // TODO: once we have a receive buffer, handle the application not accepting all data.
773                                 fprintf(stderr, "c->recv(%p, %p, %zu) returned %zd\n", c, data, len, rxd);
774                                 abort();
775                         }
776                         if(rxd < 0)
777                                 rxd = 0;
778                         else if(rxd > len)
779                                 rxd = len; // Bad application, bad!
780                 } else {
781                         rxd = len;
782                 }
783
784                 c->rcv.nxt += len;
785         }
786
787         // 7. Process FIN stuff
788
789         if(hdr.ctl & FIN) {
790                 switch(c->state) {
791                 case SYN_SENT:
792                 case SYN_RECEIVED:
793                         // This should never happen.
794                         abort();
795                 case ESTABLISHED:
796                         set_state(c, CLOSE_WAIT);
797                         break;
798                 case FIN_WAIT_1:
799                         set_state(c, CLOSING);
800                         break;
801                 case FIN_WAIT_2:
802                         gettimeofday(&c->conn_timeout, NULL);
803                         c->conn_timeout.tv_sec += 60;
804                         set_state(c, TIME_WAIT);
805                         break;
806                 case CLOSE_WAIT:
807                 case CLOSING:
808                 case LAST_ACK:
809                 case TIME_WAIT:
810                         // Ehm, no. We should never receive a second FIN.
811                         goto reset;
812                 default:
813                         abort();
814                 }
815
816                 // FIN counts as one sequence number
817                 c->rcv.nxt++;
818                 len++;
819
820                 // Inform the application that the peer closed the connection.
821                 if(c->recv) {
822                         errno = 0;
823                         c->recv(c, NULL, 0);
824                 }
825         }
826
827         // Now we send something back if:
828         // - we advanced rcv.nxt (ie, we got some data that needs to be ACKed)
829         //   -> sendatleastone = true
830         // - or we got an ack, so we should maybe send a bit more data
831         //   -> sendatleastone = false
832
833 ack:
834         ack(c, prevrcvnxt != c->rcv.nxt);
835         return 0;
836
837 reset:
838         swap_ports(&hdr);
839         hdr.wnd = 0;
840         if(hdr.ctl & ACK) {
841                 hdr.seq = hdr.ack;
842                 hdr.ctl = RST;
843         } else {
844                 hdr.ack = hdr.seq + len;
845                 hdr.seq = 0;
846                 hdr.ctl = RST | ACK;
847         }
848         print_packet(utcp, "send", &hdr, sizeof hdr);
849         utcp->send(utcp, &hdr, sizeof hdr);
850         return 0;
851
852 }
853
854 int utcp_shutdown(struct utcp_connection *c, int dir) {
855         debug("%p shutdown %d\n", c ? c->utcp : NULL, dir);
856         if(!c) {
857                 errno = EFAULT;
858                 return -1;
859         }
860
861         if(c->reapable) {
862                 debug("Error: shutdown() called on closed connection %p\n", c);
863                 errno = EBADF;
864                 return -1;
865         }
866
867         // TODO: handle dir
868
869         switch(c->state) {
870         case CLOSED:
871                 return 0;
872         case LISTEN:
873         case SYN_SENT:
874                 set_state(c, CLOSED);
875                 return 0;
876
877         case SYN_RECEIVED:
878         case ESTABLISHED:
879                 set_state(c, FIN_WAIT_1);
880                 break;
881         case FIN_WAIT_1:
882         case FIN_WAIT_2:
883                 return 0;
884         case CLOSE_WAIT:
885                 set_state(c, CLOSING);
886                 break;
887
888         case CLOSING:
889         case LAST_ACK:
890         case TIME_WAIT:
891                 return 0;
892         }
893
894         c->snd.last++;
895
896         ack(c, false);
897         return 0;
898 }
899
900 int utcp_close(struct utcp_connection *c) {
901         if(utcp_shutdown(c, SHUT_RDWR))
902                 return -1;
903         c->reapable = true;
904         return 0;
905 }
906
907 int utcp_abort(struct utcp_connection *c) {
908         if(!c) {
909                 errno = EFAULT;
910                 return -1;
911         }
912
913         if(c->reapable) {
914                 debug("Error: abort() called on closed connection %p\n", c);
915                 errno = EBADF;
916                 return -1;
917         }
918
919         c->reapable = true;
920
921         switch(c->state) {
922         case CLOSED:
923                 return 0;
924         case LISTEN:
925         case SYN_SENT:
926         case CLOSING:
927         case LAST_ACK:
928         case TIME_WAIT:
929                 set_state(c, CLOSED);
930                 return 0;
931
932         case SYN_RECEIVED:
933         case ESTABLISHED:
934         case FIN_WAIT_1:
935         case FIN_WAIT_2:
936         case CLOSE_WAIT:
937                 set_state(c, CLOSED);
938                 break;
939         }
940
941         // Send RST
942
943         struct hdr hdr;
944
945         hdr.src = c->src;
946         hdr.dst = c->dst;
947         hdr.seq = c->snd.nxt;
948         hdr.ack = 0;
949         hdr.wnd = 0;
950         hdr.ctl = RST;
951
952         print_packet(c->utcp, "send", &hdr, sizeof hdr);
953         c->utcp->send(c->utcp, &hdr, sizeof hdr);
954         return 0;
955 }
956
957 static void retransmit(struct utcp_connection *c) {
958         if(c->state == CLOSED || c->snd.nxt == c->snd.una)
959                 return;
960
961         struct utcp *utcp = c->utcp;
962
963         struct {
964                 struct hdr hdr;
965                 char data[];
966         } *pkt;
967
968         pkt = malloc(sizeof pkt->hdr + c->utcp->mtu);
969         if(!pkt)
970                 return;
971
972         pkt->hdr.src = c->src;
973         pkt->hdr.dst = c->dst;
974
975         switch(c->state) {
976                 case LISTEN:
977                         // TODO: this should not happen
978                         break;
979
980                 case SYN_SENT:
981                         pkt->hdr.seq = c->snd.iss;
982                         pkt->hdr.ack = 0;
983                         pkt->hdr.wnd = c->rcv.wnd;
984                         pkt->hdr.ctl = SYN;
985                         print_packet(c->utcp, "rtrx", pkt, sizeof pkt->hdr);
986                         utcp->send(utcp, pkt, sizeof pkt->hdr);
987                         break;
988
989                 case SYN_RECEIVED:
990                         pkt->hdr.seq = c->snd.nxt;
991                         pkt->hdr.ack = c->rcv.nxt;
992                         pkt->hdr.ctl = SYN | ACK;
993                         print_packet(c->utcp, "rtrx", pkt, sizeof pkt->hdr);
994                         utcp->send(utcp, pkt, sizeof pkt->hdr);
995                         break;
996
997                 case ESTABLISHED:
998                 case FIN_WAIT_1:
999                         pkt->hdr.seq = c->snd.una;
1000                         pkt->hdr.ack = c->rcv.nxt;
1001                         pkt->hdr.ctl = ACK;
1002                         uint32_t len = seqdiff(c->snd.nxt, c->snd.una);
1003                         if(c->state == FIN_WAIT_1)
1004                                 len--;
1005                         if(len > utcp->mtu)
1006                                 len = utcp->mtu;
1007                         else {
1008                                 if(c->state == FIN_WAIT_1)
1009                                         pkt->hdr.ctl |= FIN;
1010                         }
1011                         memcpy(pkt->data, c->sndbuf, len);
1012                         print_packet(c->utcp, "rtrx", pkt, sizeof pkt->hdr + len);
1013                         utcp->send(utcp, pkt, sizeof pkt->hdr + len);
1014                         break;
1015
1016                 default:
1017                         // TODO: implement
1018                         abort();
1019         }
1020
1021         free(pkt);
1022 }
1023
1024 /* Handle timeouts.
1025  * One call to this function will loop through all connections,
1026  * checking if something needs to be resent or not.
1027  * The return value is the time to the next timeout in milliseconds,
1028  * or maybe a negative value if the timeout is infinite.
1029  */
1030 int utcp_timeout(struct utcp *utcp) {
1031         struct timeval now;
1032         gettimeofday(&now, NULL);
1033         struct timeval next = {now.tv_sec + 3600, now.tv_usec};
1034
1035         for(int i = 0; i < utcp->nconnections; i++) {
1036                 struct utcp_connection *c = utcp->connections[i];
1037                 if(!c)
1038                         continue;
1039
1040                 if(c->state == CLOSED) {
1041                         if(c->reapable) {
1042                                 debug("Reaping %p\n", c);
1043                                 free_connection(c);
1044                                 i--;
1045                         }
1046                         continue;
1047                 }
1048
1049                 if(timerisset(&c->conn_timeout) && timercmp(&c->conn_timeout, &now, <)) {
1050                         errno = ETIMEDOUT;
1051                         c->state = CLOSED;
1052                         if(c->recv)
1053                                 c->recv(c, NULL, 0);
1054                         continue;
1055                 }
1056
1057                 if(timerisset(&c->rtrx_timeout) && timercmp(&c->rtrx_timeout, &now, <)) {
1058                         retransmit(c);
1059                 }
1060
1061                 if(c->poll && c->sndbufsize < c->maxsndbufsize / 2 && (c->state == ESTABLISHED || c->state == CLOSE_WAIT))
1062                         c->poll(c, c->maxsndbufsize - c->sndbufsize);
1063
1064                 if(timerisset(&c->conn_timeout) && timercmp(&c->conn_timeout, &next, <))
1065                         next = c->conn_timeout;
1066
1067                 if(c->snd.nxt != c->snd.una) {
1068                         c->rtrx_timeout = now;
1069                         c->rtrx_timeout.tv_sec++;
1070                 } else {
1071                         timerclear(&c->rtrx_timeout);
1072                 }
1073
1074                 if(timerisset(&c->rtrx_timeout) && timercmp(&c->rtrx_timeout, &next, <))
1075                         next = c->rtrx_timeout;
1076         }
1077
1078         struct timeval diff;
1079         timersub(&next, &now, &diff);
1080         if(diff.tv_sec < 0)
1081                 return 0;
1082         return diff.tv_sec * 1000 + diff.tv_usec / 1000;
1083 }
1084
1085 struct utcp *utcp_init(utcp_accept_t accept, utcp_pre_accept_t pre_accept, utcp_send_t send, void *priv) {
1086         struct utcp *utcp = calloc(1, sizeof *utcp);
1087         if(!utcp)
1088                 return NULL;
1089
1090         if(!send) {
1091                 errno = EFAULT;
1092                 return NULL;
1093         }
1094
1095         utcp->accept = accept;
1096         utcp->pre_accept = pre_accept;
1097         utcp->send = send;
1098         utcp->priv = priv;
1099         utcp->mtu = 1000;
1100         utcp->timeout = 60;
1101
1102         return utcp;
1103 }
1104
1105 void utcp_exit(struct utcp *utcp) {
1106         if(!utcp)
1107                 return;
1108         for(int i = 0; i < utcp->nconnections; i++) {
1109                 if(!utcp->connections[i]->reapable)
1110                         debug("Warning, freeing unclosed connection %p\n", utcp->connections[i]);
1111                 free(utcp->connections[i]->sndbuf);
1112                 free(utcp->connections[i]);
1113         }
1114         free(utcp->connections);
1115         free(utcp);
1116 }
1117
1118 uint16_t utcp_get_mtu(struct utcp *utcp) {
1119         return utcp->mtu;
1120 }
1121
1122 void utcp_set_mtu(struct utcp *utcp, uint16_t mtu) {
1123         // TODO: handle overhead of the header
1124         utcp->mtu = mtu;
1125 }
1126
1127 int utcp_get_user_timeout(struct utcp *u) {
1128         return u->timeout;
1129 }
1130
1131 void utcp_set_user_timeout(struct utcp *u, int timeout) {
1132         u->timeout = timeout;
1133 }
1134
1135 size_t utcp_get_sndbuf(struct utcp_connection *c) {
1136         return c->maxsndbufsize;
1137 }
1138
1139 size_t utcp_get_sndbuf_free(struct utcp_connection *c) {
1140         return c->maxsndbufsize - c->sndbufsize;
1141 }
1142
1143 void utcp_set_sndbuf(struct utcp_connection *c, size_t size) {
1144         c->maxsndbufsize = size;
1145         if(c->maxsndbufsize != size)
1146                 c->maxsndbufsize = -1;
1147 }
1148
1149 bool utcp_get_nodelay(struct utcp_connection *c) {
1150         return c->nodelay;
1151 }
1152
1153 void utcp_set_nodelay(struct utcp_connection *c, bool nodelay) {
1154         c->nodelay = nodelay;
1155 }
1156
1157 bool utcp_get_keepalive(struct utcp_connection *c) {
1158         return c->keepalive;
1159 }
1160
1161 void utcp_set_keepalive(struct utcp_connection *c, bool keepalive) {
1162         c->keepalive = keepalive;
1163 }
1164
1165 size_t utcp_get_outq(struct utcp_connection *c) {
1166         return seqdiff(c->snd.nxt, c->snd.una);
1167 }
1168
1169 void utcp_set_recv_cb(struct utcp_connection *c, utcp_recv_t recv) {
1170         c->recv = recv;
1171 }
1172
1173 void utcp_set_poll_cb(struct utcp_connection *c, utcp_poll_t poll) {
1174         c->poll = poll;
1175 }