Blob


/*
 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
17 #include "got_compat.h"
19 #include <sys/types.h>
20 #include <sys/queue.h>
21 #include <sys/socket.h>
22 #include <sys/uio.h>
24 #include <errno.h>
25 #include <event.h>
26 #include <stdint.h>
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <string.h>
30 #include <imsg.h>
31 #include <limits.h>
32 #include <signal.h>
33 #include <unistd.h>
35 #include "got_error.h"
36 #include "got_path.h"
38 #include "got_compat.h"
40 #include "gotd.h"
41 #include "log.h"
42 #include "listen.h"
/*
 * Element count of a true array (not a pointer).  Only defined here when
 * the system headers do not already provide it.
 */
#ifndef nitems
#define nitems(_a) (sizeof(_a) / sizeof(*(_a)))
#endif
48 struct gotd_listen_client {
49 STAILQ_ENTRY(gotd_listen_client) entry;
50 uint32_t id;
51 int fd;
52 uid_t euid;
53 };
54 STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
56 static struct gotd_listen_clients gotd_listen_clients[GOTD_CLIENT_TABLE_SIZE];
57 static SIPHASH_KEY clients_hash_key;
58 static volatile int listen_client_cnt;
59 static int inflight;
61 struct gotd_uid_connection_counter {
62 STAILQ_ENTRY(gotd_uid_connection_counter) entry;
63 uid_t euid;
64 int nconnections;
65 };
66 STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
67 static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
68 static SIPHASH_KEY uid_hash_key;
70 static struct {
71 pid_t pid;
72 const char *title;
73 int fd;
74 struct gotd_imsgev iev;
75 struct gotd_imsgev pause;
76 struct gotd_uid_connection_limit *connection_limits;
77 size_t nconnection_limits;
78 } gotd_listen;
80 static int inflight;
82 static void listen_shutdown(void);
/*
 * libevent signal callback.  SIGTERM and SIGINT shut the process down via
 * listen_shutdown(); SIGHUP and SIGUSR1 are accepted and ignored; any other
 * signal is fatal.  libevent delivers signals from its event loop rather
 * than from signal context, so the usual signal-handler restrictions do
 * not apply here.
 */
static void
listen_sighdlr(int sig, short event, void *arg)
{
	if (sig == SIGTERM || sig == SIGINT) {
		listen_shutdown();
		/* NOTREACHED */
	} else if (sig != SIGHUP && sig != SIGUSR1) {
		fatalx("unexpected signal");
	}
}
107 static uint64_t
108 client_hash(uint32_t client_id)
110 return SipHash24(&clients_hash_key, &client_id, sizeof(client_id));
113 static void
114 add_client(struct gotd_listen_client *client)
116 uint64_t slot = client_hash(client->id) % nitems(gotd_listen_clients);
117 STAILQ_INSERT_HEAD(&gotd_listen_clients[slot], client, entry);
118 listen_client_cnt++;
121 static struct gotd_listen_client *
122 find_client(uint32_t client_id)
124 uint64_t slot;
125 struct gotd_listen_client *c;
127 slot = client_hash(client_id) % nitems(gotd_listen_clients);
128 STAILQ_FOREACH(c, &gotd_listen_clients[slot], entry) {
129 if (c->id == client_id)
130 return c;
133 return NULL;
/*
 * Draw random client ids until one is found that is nonzero and not
 * already used by a tracked client.
 */
static uint32_t
get_client_id(void)
{
	uint32_t candidate;

	for (;;) {
		candidate = arc4random();
		if (candidate != 0 && find_client(candidate) == NULL)
			return candidate;
	}
}
150 static uint64_t
151 uid_hash(uid_t euid)
153 return SipHash24(&uid_hash_key, &euid, sizeof(euid));
156 static void
157 add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
159 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
160 STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
163 static void
164 remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
166 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
167 STAILQ_REMOVE(&gotd_client_uids[slot], counter,
168 gotd_uid_connection_counter, entry);
171 static struct gotd_uid_connection_counter *
172 find_uid_connection_counter(uid_t euid)
174 uint64_t slot;
175 struct gotd_uid_connection_counter *c;
177 slot = uid_hash(euid) % nitems(gotd_client_uids);
178 STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
179 if (c->euid == euid)
180 return c;
183 return NULL;
186 static const struct got_error *
187 disconnect(struct gotd_listen_client *client)
189 struct gotd_uid_connection_counter *counter;
190 uint64_t slot;
191 int client_fd;
193 log_debug("client on fd %d disconnecting", client->fd);
195 slot = client_hash(client->id) % nitems(gotd_listen_clients);
196 STAILQ_REMOVE(&gotd_listen_clients[slot], client,
197 gotd_listen_client, entry);
199 counter = find_uid_connection_counter(client->euid);
200 if (counter) {
201 if (counter->nconnections > 0)
202 counter->nconnections--;
203 if (counter->nconnections == 0) {
204 remove_uid_connection_counter(counter);
205 free(counter);
209 client_fd = client->fd;
210 free(client);
211 inflight--;
212 listen_client_cnt--;
213 if (close(client_fd) == -1)
214 return got_error_from_errno("close");
216 return NULL;
219 static int
220 accept_reserve(int fd, struct sockaddr *addr, socklen_t *addrlen,
221 int reserve, volatile int *counter)
223 int ret;
224 int sock_flags = SOCK_NONBLOCK;
226 #ifdef SOCK_CLOEXEC
227 sock_flags |= SOCK_CLOEXEC;
228 #endif
230 if (getdtablecount() + reserve +
231 ((*counter + 1) * GOTD_FD_NEEDED) >= getdtablesize()) {
232 log_debug("inflight fds exceeded");
233 errno = EMFILE;
234 return -1;
236 #ifdef __APPLE__
237 /* TA: silence warning from GCC. */
238 (void)sock_flags;
239 ret = accept(fd, addr, addrlen);
240 #else
241 ret = accept4(fd, addr, addrlen, sock_flags);
242 #endif
244 if (ret > -1) {
245 (*counter)++;
248 return ret;
251 static void
252 gotd_accept_paused(int fd, short event, void *arg)
254 event_add(&gotd_listen.iev.ev, NULL);
257 static void
258 gotd_accept(int fd, short event, void *arg)
260 struct gotd_imsgev *iev = arg;
261 struct sockaddr_storage ss;
262 struct timeval backoff;
263 socklen_t len;
264 int s = -1;
265 struct gotd_listen_client *client = NULL;
266 struct gotd_uid_connection_counter *counter = NULL;
267 struct gotd_imsg_connect iconn;
268 uid_t euid;
269 gid_t egid;
271 backoff.tv_sec = 1;
272 backoff.tv_usec = 0;
274 if (event_add(&gotd_listen.iev.ev, NULL) == -1) {
275 log_warn("event_add");
276 return;
278 if (event & EV_TIMEOUT)
279 return;
281 len = sizeof(ss);
283 /* Other backoff conditions apart from EMFILE/ENFILE? */
284 s = accept_reserve(fd, (struct sockaddr *)&ss, &len, GOTD_FD_RESERVE,
285 &inflight);
286 if (s == -1) {
287 switch (errno) {
288 case EINTR:
289 case EWOULDBLOCK:
290 case ECONNABORTED:
291 return;
292 case EMFILE:
293 case ENFILE:
294 event_del(&gotd_listen.iev.ev);
295 evtimer_add(&gotd_listen.pause.ev, &backoff);
296 return;
297 default:
298 log_warn("accept");
299 return;
303 if (listen_client_cnt >= GOTD_MAXCLIENTS)
304 goto err;
306 if (getpeereid(s, &euid, &egid) == -1) {
307 log_warn("getpeerid");
308 goto err;
311 counter = find_uid_connection_counter(euid);
312 if (counter == NULL) {
313 counter = calloc(1, sizeof(*counter));
314 if (counter == NULL) {
315 log_warn("%s: calloc", __func__);
316 goto err;
318 counter->euid = euid;
319 counter->nconnections = 1;
320 add_uid_connection_counter(counter);
321 } else {
322 int max_connections = GOTD_MAX_CONN_PER_UID;
323 struct gotd_uid_connection_limit *limit;
325 limit = gotd_find_uid_connection_limit(
326 gotd_listen.connection_limits,
327 gotd_listen.nconnection_limits, euid);
328 if (limit)
329 max_connections = limit->max_connections;
331 if (counter->nconnections >= max_connections) {
332 log_warnx("maximum connections exceeded for uid %d",
333 euid);
334 goto err;
336 counter->nconnections++;
339 client = calloc(1, sizeof(*client));
340 if (client == NULL) {
341 log_warn("%s: calloc", __func__);
342 goto err;
344 client->id = get_client_id();
345 client->fd = s;
346 client->euid = euid;
347 s = -1;
348 add_client(client);
349 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
350 client->fd, euid, egid);
352 memset(&iconn, 0, sizeof(iconn));
353 iconn.client_id = client->id;
354 iconn.euid = euid;
355 iconn.egid = egid;
356 s = dup(client->fd);
357 if (s == -1) {
358 log_warn("%s: dup", __func__);
359 goto err;
361 if (gotd_imsg_compose_event(iev, GOTD_IMSG_CONNECT, PROC_LISTEN, s,
362 &iconn, sizeof(iconn)) == -1) {
363 log_warn("imsg compose CONNECT");
364 goto err;
367 return;
368 err:
369 inflight--;
370 if (client)
371 disconnect(client);
372 if (s != -1)
373 close(s);
376 static const struct got_error *
377 recv_disconnect(struct imsg *imsg)
379 struct gotd_imsg_disconnect idisconnect;
380 size_t datalen;
381 struct gotd_listen_client *client = NULL;
383 datalen = imsg->hdr.len - IMSG_HEADER_SIZE;
384 if (datalen != sizeof(idisconnect))
385 return got_error(GOT_ERR_PRIVSEP_LEN);
386 memcpy(&idisconnect, imsg->data, sizeof(idisconnect));
388 log_debug("client disconnecting");
390 client = find_client(idisconnect.client_id);
391 if (client == NULL)
392 return got_error(GOT_ERR_CLIENT_ID);
394 return disconnect(client);
397 static void
398 listen_dispatch(int fd, short event, void *arg)
400 const struct got_error *err = NULL;
401 struct gotd_imsgev *iev = arg;
402 struct imsgbuf *ibuf = &iev->ibuf;
403 struct imsg imsg;
404 ssize_t n;
405 int shut = 0;
407 if (event & EV_READ) {
408 if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN)
409 fatal("imsg_read error");
410 if (n == 0) /* Connection closed. */
411 shut = 1;
414 if (event & EV_WRITE) {
415 n = msgbuf_write(&ibuf->w);
416 if (n == -1 && errno != EAGAIN)
417 fatal("msgbuf_write");
418 if (n == 0) /* Connection closed. */
419 shut = 1;
422 for (;;) {
423 if ((n = imsg_get(ibuf, &imsg)) == -1)
424 fatal("%s: imsg_get", __func__);
425 if (n == 0) /* No more messages. */
426 break;
428 switch (imsg.hdr.type) {
429 case GOTD_IMSG_DISCONNECT:
430 err = recv_disconnect(&imsg);
431 if (err)
432 log_warnx("disconnect: %s", err->msg);
433 break;
434 default:
435 log_debug("unexpected imsg %d", imsg.hdr.type);
436 break;
439 imsg_free(&imsg);
442 if (!shut) {
443 gotd_imsg_event_add(iev);
444 } else {
445 /* This pipe is dead. Remove its event handler */
446 event_del(&iev->ev);
447 event_loopexit(NULL);
451 void
452 listen_main(const char *title, int gotd_socket,
453 struct gotd_uid_connection_limit *connection_limits,
454 size_t nconnection_limits)
456 struct gotd_imsgev iev;
457 struct event evsigint, evsigterm, evsighup, evsigusr1;
459 arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
460 arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
462 gotd_listen.title = title;
463 gotd_listen.pid = getpid();
464 gotd_listen.fd = gotd_socket;
465 gotd_listen.connection_limits = connection_limits;
466 gotd_listen.nconnection_limits = nconnection_limits;
468 signal_set(&evsigint, SIGINT, listen_sighdlr, NULL);
469 signal_set(&evsigterm, SIGTERM, listen_sighdlr, NULL);
470 signal_set(&evsighup, SIGHUP, listen_sighdlr, NULL);
471 signal_set(&evsigusr1, SIGUSR1, listen_sighdlr, NULL);
472 signal(SIGPIPE, SIG_IGN);
474 signal_add(&evsigint, NULL);
475 signal_add(&evsigterm, NULL);
476 signal_add(&evsighup, NULL);
477 signal_add(&evsigusr1, NULL);
479 imsg_init(&iev.ibuf, GOTD_FILENO_MSG_PIPE);
480 iev.handler = listen_dispatch;
481 iev.events = EV_READ;
482 iev.handler_arg = NULL;
483 event_set(&iev.ev, iev.ibuf.fd, EV_READ, listen_dispatch, &iev);
484 if (event_add(&iev.ev, NULL) == -1)
485 fatalx("event add");
487 event_set(&gotd_listen.iev.ev, gotd_listen.fd, EV_READ | EV_PERSIST,
488 gotd_accept, &iev);
489 if (event_add(&gotd_listen.iev.ev, NULL))
490 fatalx("event add");
491 evtimer_set(&gotd_listen.pause.ev, gotd_accept_paused, NULL);
493 event_dispatch();
495 listen_shutdown();
498 static void
499 listen_shutdown(void)
501 log_debug("%s: shutting down", gotd_listen.title);
503 free(gotd_listen.connection_limits);
504 if (gotd_listen.fd != -1)
505 close(gotd_listen.fd);
507 exit(0);