Blob


/*
 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
17 #include <sys/types.h>
18 #include <sys/queue.h>
19 #include <sys/socket.h>
20 #include <sys/uio.h>
22 #include <errno.h>
23 #include <event.h>
24 #include <siphash.h>
25 #include <stdint.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <imsg.h>
30 #include <limits.h>
31 #include <signal.h>
32 #include <unistd.h>
34 #include "got_error.h"
36 #include "gotd.h"
37 #include "log.h"
38 #include "listen.h"
#ifndef nitems
/*
 * Element count of a true array object.  Does not work on pointers
 * (including array function parameters, which decay to pointers).
 */
#define nitems(_a) (sizeof(_a) / sizeof(*(_a)))
#endif
/*
 * One accepted client connection.  Instances are linked into the
 * gotd_listen_clients[] hash table, bucketed by 'id'.
 */
struct gotd_listen_client {
	STAILQ_ENTRY(gotd_listen_client)	 entry;	/* bucket linkage */
	uint32_t				 id;	/* random, non-zero */
	int					 fd;	/* accepted socket */
	uid_t					 euid;	/* peer's effective uid */
};
STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
52 static struct gotd_listen_clients gotd_listen_clients[GOTD_CLIENT_TABLE_SIZE];
53 static SIPHASH_KEY clients_hash_key;
54 static volatile int listen_client_cnt;
55 static int inflight;
/*
 * Number of live connections owned by one effective uid.  Instances
 * live in the gotd_client_uids[] hash table, bucketed by 'euid'.
 */
struct gotd_uid_connection_counter {
	STAILQ_ENTRY(gotd_uid_connection_counter)	 entry;	/* bucket linkage */
	uid_t						 euid;
	int						 nconnections;
};
STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
63 static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
64 static SIPHASH_KEY uid_hash_key;
66 static struct {
67 pid_t pid;
68 const char *title;
69 int fd;
70 struct gotd_imsgev iev;
71 struct gotd_imsgev pause;
72 struct gotd_uid_connection_limit *connection_limits;
73 size_t nconnection_limits;
74 } gotd_listen;
76 static int inflight;
78 static void listen_shutdown(void);
static void
listen_sighdlr(int sig, short event, void *arg)
{
	/*
	 * libevent invokes this callback from its event loop instead of
	 * from asynchronous signal context, so the usual restrictions on
	 * signal handlers do not apply here.
	 */
	if (sig == SIGTERM || sig == SIGINT) {
		listen_shutdown();
		/* NOTREACHED */
	} else if (sig != SIGHUP && sig != SIGUSR1) {
		fatalx("unexpected signal");
	}
	/* SIGHUP and SIGUSR1 are accepted and currently ignored. */
}
103 static uint64_t
104 client_hash(uint32_t client_id)
106 return SipHash24(&clients_hash_key, &client_id, sizeof(client_id));
109 static void
110 add_client(struct gotd_listen_client *client)
112 uint64_t slot = client_hash(client->id) % nitems(gotd_listen_clients);
113 STAILQ_INSERT_HEAD(&gotd_listen_clients[slot], client, entry);
114 listen_client_cnt++;
117 static struct gotd_listen_client *
118 find_client(uint32_t client_id)
120 uint64_t slot;
121 struct gotd_listen_client *c;
123 slot = client_hash(client_id) % nitems(gotd_listen_clients);
124 STAILQ_FOREACH(c, &gotd_listen_clients[slot], entry) {
125 if (c->id == client_id)
126 return c;
129 return NULL;
132 static uint32_t
133 get_client_id(void)
135 int duplicate = 0;
136 uint32_t id;
138 do {
139 id = arc4random();
140 duplicate = (find_client(id) != NULL);
141 } while (duplicate || id == 0);
143 return id;
146 static uint64_t
147 uid_hash(uid_t euid)
149 return SipHash24(&uid_hash_key, &euid, sizeof(euid));
152 static void
153 add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
155 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
156 STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
159 static void
160 remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
162 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
163 STAILQ_REMOVE(&gotd_client_uids[slot], counter,
164 gotd_uid_connection_counter, entry);
167 static struct gotd_uid_connection_counter *
168 find_uid_connection_counter(uid_t euid)
170 uint64_t slot;
171 struct gotd_uid_connection_counter *c;
173 slot = uid_hash(euid) % nitems(gotd_client_uids);
174 STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
175 if (c->euid == euid)
176 return c;
179 return NULL;
182 struct gotd_uid_connection_limit *
183 gotd_find_uid_connection_limit(struct gotd_uid_connection_limit *limits,
184 size_t nlimits, uid_t uid)
186 /* This array is always sorted to allow for binary search. */
187 int i, left = 0, right = nlimits - 1;
189 while (left <= right) {
190 i = ((left + right) / 2);
191 if (limits[i].uid == uid)
192 return &limits[i];
193 if (limits[i].uid > uid)
194 left = i + 1;
195 else
196 right = i - 1;
199 return NULL;
202 static const struct got_error *
203 disconnect(struct gotd_listen_client *client)
205 struct gotd_uid_connection_counter *counter;
206 uint64_t slot;
207 int client_fd;
209 log_debug("client on fd %d disconnecting", client->fd);
211 slot = client_hash(client->id) % nitems(gotd_listen_clients);
212 STAILQ_REMOVE(&gotd_listen_clients[slot], client,
213 gotd_listen_client, entry);
215 counter = find_uid_connection_counter(client->euid);
216 if (counter) {
217 if (counter->nconnections > 0)
218 counter->nconnections--;
219 if (counter->nconnections == 0) {
220 remove_uid_connection_counter(counter);
221 free(counter);
225 client_fd = client->fd;
226 free(client);
227 inflight--;
228 listen_client_cnt--;
229 if (close(client_fd) == -1)
230 return got_error_from_errno("close");
232 return NULL;
235 static int
236 accept_reserve(int fd, struct sockaddr *addr, socklen_t *addrlen,
237 int reserve, volatile int *counter)
239 int ret;
241 if (getdtablecount() + reserve +
242 ((*counter + 1) * GOTD_FD_NEEDED) >= getdtablesize()) {
243 log_debug("inflight fds exceeded");
244 errno = EMFILE;
245 return -1;
248 if ((ret = accept4(fd, addr, addrlen,
249 SOCK_NONBLOCK | SOCK_CLOEXEC)) > -1) {
250 (*counter)++;
253 return ret;
256 static void
257 gotd_accept_paused(int fd, short event, void *arg)
259 event_add(&gotd_listen.iev.ev, NULL);
262 static void
263 gotd_accept(int fd, short event, void *arg)
265 struct gotd_imsgev *iev = arg;
266 struct sockaddr_storage ss;
267 struct timeval backoff;
268 socklen_t len;
269 int s = -1;
270 struct gotd_listen_client *client = NULL;
271 struct gotd_uid_connection_counter *counter = NULL;
272 struct gotd_imsg_connect iconn;
273 uid_t euid;
274 gid_t egid;
276 backoff.tv_sec = 1;
277 backoff.tv_usec = 0;
279 if (event_add(&gotd_listen.iev.ev, NULL) == -1) {
280 log_warn("event_add");
281 return;
283 if (event & EV_TIMEOUT)
284 return;
286 len = sizeof(ss);
288 /* Other backoff conditions apart from EMFILE/ENFILE? */
289 s = accept_reserve(fd, (struct sockaddr *)&ss, &len, GOTD_FD_RESERVE,
290 &inflight);
291 if (s == -1) {
292 switch (errno) {
293 case EINTR:
294 case EWOULDBLOCK:
295 case ECONNABORTED:
296 return;
297 case EMFILE:
298 case ENFILE:
299 event_del(&gotd_listen.iev.ev);
300 evtimer_add(&gotd_listen.pause.ev, &backoff);
301 return;
302 default:
303 log_warn("accept");
304 return;
308 if (listen_client_cnt >= GOTD_MAXCLIENTS)
309 goto err;
311 if (getpeereid(s, &euid, &egid) == -1) {
312 log_warn("getpeerid");
313 goto err;
316 counter = find_uid_connection_counter(euid);
317 if (counter == NULL) {
318 counter = calloc(1, sizeof(*counter));
319 if (counter == NULL) {
320 log_warn("%s: calloc", __func__);
321 goto err;
323 counter->euid = euid;
324 counter->nconnections = 1;
325 add_uid_connection_counter(counter);
326 } else {
327 int max_connections = GOTD_MAX_CONN_PER_UID;
328 struct gotd_uid_connection_limit *limit;
330 limit = gotd_find_uid_connection_limit(
331 gotd_listen.connection_limits,
332 gotd_listen.nconnection_limits, euid);
333 if (limit)
334 max_connections = limit->max_connections;
336 if (counter->nconnections >= max_connections) {
337 log_warnx("maximum connections exceeded for uid %d",
338 euid);
339 goto err;
341 counter->nconnections++;
344 client = calloc(1, sizeof(*client));
345 if (client == NULL) {
346 log_warn("%s: calloc", __func__);
347 goto err;
349 client->id = get_client_id();
350 client->fd = s;
351 client->euid = euid;
352 s = -1;
353 add_client(client);
354 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
355 client->fd, euid, egid);
357 memset(&iconn, 0, sizeof(iconn));
358 iconn.client_id = client->id;
359 iconn.euid = euid;
360 iconn.egid = egid;
361 s = dup(client->fd);
362 if (s == -1) {
363 log_warn("%s: dup", __func__);
364 goto err;
366 if (gotd_imsg_compose_event(iev, GOTD_IMSG_CONNECT, PROC_LISTEN, s,
367 &iconn, sizeof(iconn)) == -1) {
368 log_warn("imsg compose CONNECT");
369 goto err;
372 return;
373 err:
374 inflight--;
375 if (client)
376 disconnect(client);
377 if (s != -1)
378 close(s);
381 static const struct got_error *
382 recv_disconnect(struct imsg *imsg)
384 struct gotd_imsg_disconnect idisconnect;
385 size_t datalen;
386 struct gotd_listen_client *client = NULL;
388 datalen = imsg->hdr.len - IMSG_HEADER_SIZE;
389 if (datalen != sizeof(idisconnect))
390 return got_error(GOT_ERR_PRIVSEP_LEN);
391 memcpy(&idisconnect, imsg->data, sizeof(idisconnect));
393 log_debug("client disconnecting");
395 client = find_client(idisconnect.client_id);
396 if (client == NULL)
397 return got_error(GOT_ERR_CLIENT_ID);
399 return disconnect(client);
402 static void
403 listen_dispatch(int fd, short event, void *arg)
405 const struct got_error *err = NULL;
406 struct gotd_imsgev *iev = arg;
407 struct imsgbuf *ibuf = &iev->ibuf;
408 struct imsg imsg;
409 ssize_t n;
410 int shut = 0;
412 if (event & EV_READ) {
413 if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN)
414 fatal("imsg_read error");
415 if (n == 0) /* Connection closed. */
416 shut = 1;
419 if (event & EV_WRITE) {
420 n = msgbuf_write(&ibuf->w);
421 if (n == -1 && errno != EAGAIN)
422 fatal("msgbuf_write");
423 if (n == 0) /* Connection closed. */
424 shut = 1;
427 for (;;) {
428 if ((n = imsg_get(ibuf, &imsg)) == -1)
429 fatal("%s: imsg_get", __func__);
430 if (n == 0) /* No more messages. */
431 break;
433 switch (imsg.hdr.type) {
434 case GOTD_IMSG_DISCONNECT:
435 err = recv_disconnect(&imsg);
436 if (err)
437 log_warnx("disconnect: %s", err->msg);
438 break;
439 default:
440 log_debug("unexpected imsg %d", imsg.hdr.type);
441 break;
444 imsg_free(&imsg);
447 if (!shut) {
448 gotd_imsg_event_add(iev);
449 } else {
450 /* This pipe is dead. Remove its event handler */
451 event_del(&iev->ev);
452 event_loopexit(NULL);
456 void
457 listen_main(const char *title, int gotd_socket,
458 struct gotd_uid_connection_limit *connection_limits,
459 size_t nconnection_limits)
461 struct gotd_imsgev iev;
462 struct event evsigint, evsigterm, evsighup, evsigusr1;
464 arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
465 arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
467 gotd_listen.title = title;
468 gotd_listen.pid = getpid();
469 gotd_listen.fd = gotd_socket;
470 gotd_listen.connection_limits = connection_limits;
471 gotd_listen.nconnection_limits = nconnection_limits;
473 signal_set(&evsigint, SIGINT, listen_sighdlr, NULL);
474 signal_set(&evsigterm, SIGTERM, listen_sighdlr, NULL);
475 signal_set(&evsighup, SIGHUP, listen_sighdlr, NULL);
476 signal_set(&evsigusr1, SIGUSR1, listen_sighdlr, NULL);
477 signal(SIGPIPE, SIG_IGN);
479 signal_add(&evsigint, NULL);
480 signal_add(&evsigterm, NULL);
481 signal_add(&evsighup, NULL);
482 signal_add(&evsigusr1, NULL);
484 imsg_init(&iev.ibuf, GOTD_FILENO_MSG_PIPE);
485 iev.handler = listen_dispatch;
486 iev.events = EV_READ;
487 iev.handler_arg = NULL;
488 event_set(&iev.ev, iev.ibuf.fd, EV_READ, listen_dispatch, &iev);
489 if (event_add(&iev.ev, NULL) == -1)
490 fatalx("event add");
492 event_set(&gotd_listen.iev.ev, gotd_listen.fd, EV_READ | EV_PERSIST,
493 gotd_accept, &iev);
494 if (event_add(&gotd_listen.iev.ev, NULL))
495 fatalx("event add");
496 evtimer_set(&gotd_listen.pause.ev, gotd_accept_paused, NULL);
498 event_dispatch();
500 listen_shutdown();
503 static void
504 listen_shutdown(void)
506 log_debug("shutting down");
508 free(gotd_listen.connection_limits);
509 if (gotd_listen.fd != -1)
510 close(gotd_listen.fd);
512 exit(0);