2 * Copyright (c) 2019 Ori Bernstein <ori@openbsd.org>
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 #include <sys/types.h>
18 #include <sys/queue.h>
22 #include <sys/syslimits.h>
38 #include "got_error.h"
39 #include "got_object.h"
41 #include "got_version.h"
43 #include "got_lib_sha1.h"
44 #include "got_lib_delta.h"
45 #include "got_lib_object.h"
46 #include "got_lib_object_parse.h"
47 #include "got_lib_privsep.h"
48 #include "got_lib_pack.h"
/* Element count of a true array; the argument must not be a pointer. */
51 #define nitems(_a) (sizeof((_a)) / sizeof((_a)[0]))
/* Size in bytes of the packet buffer that fetch_pack() declares locally. */
54 #define GOT_PKTMAX 65536
/*
 * NOTE(review): non-static global with no visible user in this excerpt;
 * confirm whether it is still needed.
 */
56 struct got_object *indexed;
/* If non-NULL, fetch_pack() tests each remote ref with match_branch(). */
58 static char *fetchbranch;
/* All-zero object ID; fetch_pack() compares have[] entries against it. */
59 static struct got_object_id zhash = {.sha1={0}};
/*
 * Read from fd into buf starting at byte offset *off, asking read(2) for
 * the n - *off bytes that are still missing. A failing read(2) is turned
 * into an errno-based got_error.
 *
 * NOTE(review): the loop and termination lines are not visible in this
 * excerpt. Callers compare *off with the expected length afterwards, so
 * *off presumably ends up holding the total number of bytes read; confirm.
 * NOTE(review): buf + *off is arithmetic on a void pointer, which is a
 * compiler extension rather than standard C.
 */
61 static const struct got_error *
62 readn(ssize_t *off, int fd, void *buf, size_t n)
68 r = read(fd, buf + *off, n - *off);
70 return got_error_from_errno("read");
/*
 * Write the literal four bytes "0000" to fd. Parsed as a packet header by
 * read_pkthdr() this is a length of zero. A diagnostic line is also
 * printed to stderr; the condition guarding it is not visible in this
 * excerpt, and neither is the line that names this function.
 *
 * Returns an errno-based error when write(2) fails. GOT_ERR_IO is
 * returned on a second path, presumably a short write; confirm.
 */
78 static const struct got_error *
84 fprintf(stderr, "writepkt: 0000\n");
86 w = write(fd, "0000", 4);
88 return got_error_from_errno("write");
90 return got_error(GOT_ERR_IO);
95 * Packet header contains a 4-byte hexstring which specifies the length
96 * of data which follows.
/*
 * Read and validate the 4-character hexadecimal length header of a packet
 * from fd. The parsed length is presumably reported through *datalen; the
 * assignment itself is not visible in this excerpt.
 *
 * Visible failure cases, all GOT_ERR_BAD_PACKET: header of the wrong
 * length, non-hex characters, trailing garbage after the number, a value
 * out of range, and "packet too short". A read of zero bytes is treated
 * as an implicit "0000" header.
 *
 * NOTE(review): err is declared static although it is assigned from
 * readn() before any visible use; the static storage looks unintended.
 * NOTE(review): isxdigit() is handed a plain char; fetch_progress() casts
 * to unsigned char before calling ctype functions and this should match.
 * NOTE(review): errno is inspected after strtol() but no reset of errno to
 * zero is visible before the call; confirm it is cleared.
 */
98 static const struct got_error *
99 read_pkthdr(int *datalen, int fd)
101 static const struct got_error *err = NULL;
110 err = readn(&r, fd, lenstr, 4);
113 if (r == 0) /* implicit "0000" */
116 return got_error_msg(GOT_ERR_BAD_PACKET,
117 "wrong packet header length");
/*
 * Insist on hex digits only: strtol() on its own would also accept
 * leading blanks, a sign, or a 0x prefix.
 */
120 for (i = 0; i < 4; i++) {
121 if (!isxdigit(lenstr[i]))
122 return got_error_msg(GOT_ERR_BAD_PACKET,
123 "packet length not specified in hex");
126 len = strtol(lenstr, &e, 16);
/* The whole string must have been consumed by the conversion. */
127 if (lenstr[0] == '\0' || *e != '\0')
128 return got_error(GOT_ERR_BAD_PACKET);
129 if (errno == ERANGE && (len == LONG_MAX || len == LONG_MIN))
130 return got_error_msg(GOT_ERR_BAD_PACKET, "bad packet length");
/* The result is handed back as an int, so it has to fit in one. */
131 if (len > INT_MAX || len < INT_MIN)
132 return got_error_msg(GOT_ERR_BAD_PACKET, "bad packet length");
137 return got_error_msg(GOT_ERR_BAD_PACKET, "packet too short");
/*
 * Read one packet from fd: parse its header with read_pkthdr(), then read
 * datalen payload bytes into buf with readn().
 *
 * Fails with GOT_ERR_NO_SPACE when the payload is larger than buflen.
 * A "short packet" error is returned on another path, presumably when
 * fewer bytes arrive than the header announced; that condition and the
 * store to *outlen are not visible in this excerpt.
 */
144 static const struct got_error *
145 readpkt(int *outlen, int fd, char *buf, int buflen)
147 const struct got_error *err = NULL;
151 err = read_pkthdr(&datalen, fd);
155 if (datalen > buflen)
156 return got_error(GOT_ERR_NO_SPACE);
158 err = readn(&n, fd, buf, datalen);
162 return got_error_msg(GOT_ERR_BAD_PACKET, "short packet");
/*
 * Send one packet on fd: a 4-digit hexadecimal header followed by nbuf
 * payload bytes taken from buf. The header value is nbuf + 4, so the
 * announced length includes the header itself.
 *
 * Returns GOT_ERR_NO_SPACE when the formatted header does not fit,
 * an errno-based error when write(2) fails, and GOT_ERR_IO on a further
 * path after each write, presumably a short write; confirm.
 *
 * NOTE(review): the debug output below shows both an fwrite() of the whole
 * payload and a per-byte fputc() loop. The conditions around them are not
 * visible in this excerpt; confirm the payload is not echoed twice.
 */
168 static const struct got_error *
169 writepkt(int fd, char *buf, int nbuf)
175 if (snprintf(len, sizeof(len), "%04x", nbuf + 4) >= sizeof(len))
176 return got_error(GOT_ERR_NO_SPACE);
/* Header first ... */
177 w = write(fd, len, 4);
179 return got_error_from_errno("write");
181 return got_error(GOT_ERR_IO);
/* ... then the payload. */
182 w = write(fd, buf, nbuf);
184 return got_error_from_errno("write");
186 return got_error(GOT_ERR_IO);
/* Diagnostic echo of the packet to stderr. */
188 fprintf(stderr, "writepkt: %s:\t", len);
189 fwrite(buf, 1, nbuf, stderr);
190 for (i = 0; i < nbuf; i++) {
192 fputc(buf[i], stderr);
/*
 * Clear *id, then look through have_refs for an entry whose path equals
 * refname. For a matching entry, id_str is parsed as a SHA1 digest into
 * id->sha1; a string that does not parse yields GOT_ERR_BAD_OBJ_ID_STR.
 * When nothing matches, *id is left all-zero by the memset() below.
 *
 * NOTE(review): the visible call in fetch_pack() passes its arguments as
 * (have_refs, &have[nref], id_str, refname), i.e. with id_str and refname
 * in the opposite order to this parameter list. Confirm which is right.
 */
199 static const struct got_error *
200 match_remote_ref(struct got_pathlist_head *have_refs, struct got_object_id *id,
201 char *refname, char *id_str)
203 struct got_pathlist_entry *pe;
205 memset(id, 0, sizeof(*id));
207 TAILQ_FOREACH(pe, have_refs, entry) {
208 if (strcmp(pe->path, refname) == 0) {
209 if (!got_parse_sha1_digest(id->sha1, id_str))
210 return got_error(GOT_ERR_BAD_OBJ_ID_STR);
/*
 * Verify the checksum of a pack file of sz bytes readable from fd.
 * Everything except the final 20 bytes is fed through SHA1 in chunks of
 * at most sizeof(buf) bytes; the digest is stored in hcomp. The trailing
 * bytes are then read as the expected digest and compared with hcomp.
 *
 * Returns GOT_ERR_BAD_PACKFILE when sz is smaller than a pack header plus
 * one digest, when the trailing digest cannot be read in full, or when the
 * digests differ. A further GOT_ERR_BAD_PACKFILE return inside the loop is
 * presumably a short read; its condition is not visible in this excerpt.
 *
 * NOTE(review): the loop uses the literal 20 where the rest of the
 * function uses SHA1_DIGEST_LENGTH; presumably the same value, confirm.
 */
217 static const struct got_error *
218 check_pack_hash(int fd, size_t sz, uint8_t *hcomp)
220 const struct got_error *err = NULL;
222 uint8_t hexpect[SHA1_DIGEST_LENGTH];
223 uint8_t buf[32 * 1024];
/* Guarding sz here also keeps the unsigned sz - 20 below from wrapping. */
226 if (sz < sizeof(struct got_packfile_hdr) + SHA1_DIGEST_LENGTH)
227 return got_error(GOT_ERR_BAD_PACKFILE);
231 while (n < sz - 20) {
/* Last chunk: fewer than sizeof(buf) bytes remain before the digest. */
233 if (sz - n - 20 < sizeof(buf))
235 err = readn(&r, fd, buf, nr);
239 return got_error(GOT_ERR_BAD_PACKFILE);
240 SHA1Update(&ctx, buf, nr);
243 SHA1Final(hcomp, &ctx);
/* The stored digest follows the hashed data. */
245 err = readn(&r, fd, hexpect, sizeof(hexpect));
248 if (r != sizeof(hexpect))
249 return got_error(GOT_ERR_BAD_PACKFILE);
250 if (memcmp(hcomp, hexpect, SHA1_DIGEST_LENGTH) != 0)
251 return got_error(GOT_ERR_BAD_PACKFILE);
/*
 * Report whether the reference name br is the branch described by pat.
 * pat is expanded into a full reference name first:
 *   - pat beginning with "refs/heads" is used unchanged;
 *   - pat containing "heads" gets "refs/" prepended;
 *   - anything else gets "refs/heads/" prepended.
 * The result is nonzero when br equals the expanded name. Each snprintf()
 * result is checked against the size of name; what happens on truncation
 * is not visible in this excerpt.
 *
 * NOTE(review): the second test uses strstr() without anchoring, so a
 * pattern that merely contains the substring "heads" is expanded with
 * "refs/" instead of "refs/heads/". Confirm this is intended.
 */
256 match_branch(char *br, char *pat)
260 if (strstr(pat, "refs/heads") == pat) {
261 if (snprintf(name, sizeof(name), "%s", pat) >= sizeof(name))
263 } else if (strstr(pat, "heads")) {
264 if (snprintf(name, sizeof(name), "refs/%s", pat)
268 if (snprintf(name, sizeof(name), "refs/heads/%s", pat)
272 return strcmp(br, name) == 0;
/*
 * Split the len bytes at line into at most maxtokens tokens. Each token is
 * a fresh strndup() copy stored in tokens[]. Whitespace before a token is
 * skipped. Tokens end at whitespace or at a NUL byte, except the final
 * slot (index maxtokens - 1), which may contain whitespace and only ends
 * at a NUL byte.
 *
 * Returns an errno-based error when strndup() fails and GOT_ERR_NOT_REF on
 * a condition that is not visible in this excerpt. The closing loop over
 * the first i tokens presumably frees them on error; confirm.
 *
 * NOTE(review): isspace() is handed a plain char; fetch_progress() casts
 * to unsigned char before calling ctype functions and this should match.
 */
275 static const struct got_error *
276 tokenize_refline(char **tokens, char *line, int len, int maxtokens)
278 const struct got_error *err = NULL;
282 for (i = 0; i < maxtokens; i++)
285 for (i = 0; n < len && i < maxtokens; i++) {
286 while (isspace(*line)) {
/* Scan to the end of the current token. */
291 while (*line != '\0' &&
292 (!isspace(*line) || i == maxtokens - 1)) {
296 tokens[i] = strndup(p, line - p);
297 if (tokens[i] == NULL) {
298 err = got_error_from_errno("strndup");
301 /* Skip \0 field-delimiter at end of token. */
302 while (line[0] == '\0' && n < len) {
308 err = got_error(GOT_ERR_NOT_REF);
312 for (j = 0; j < i; j++)
/*
 * Parse one reference advertisement line. The line is split with
 * tokenize_refline() into a local tokens array; the second token is
 * returned in *refname and the third in *server_capabilities. The first
 * token is presumably returned in *id_str; that assignment is not visible
 * in this excerpt.
 *
 * The returned strings come from strndup() inside tokenize_refline(), so
 * the caller presumably takes ownership; fetch_pack() is seen freeing
 * server_capabilities.
 */
319 static const struct got_error *
320 parse_refline(char **id_str, char **refname, char **server_capabilities,
323 const struct got_error *err = NULL;
326 err = tokenize_refline(tokens, line, len, nitems(tokens));
333 *refname = tokens[1];
335 *server_capabilities = tokens[2];
/* Capability names compared against the server capability list. */
340 #define GOT_CAPA_AGENT "agent"
341 #define GOT_CAPA_OFS_DELTA "ofs-delta"
342 #define GOT_CAPA_SIDE_BAND_64K "side-band-64k"
/*
 * Side-band channel IDs. fetch_pack() compares the first byte of each
 * side-band packet against these values.
 */
344 #define GOT_SIDEBAND_PACKFILE_DATA 1
345 #define GOT_SIDEBAND_PROGRESS_INFO 2
346 #define GOT_SIDEBAND_ERROR_INFO 3
/*
 * A capability as a key with an optional value; match_capability() reads
 * the members key and value. The member declarations are not visible in
 * this excerpt.
 */
349 struct got_capability {
/*
 * Capabilities this client is willing to request. A NULL value means the
 * capability is sent as a bare key without "=value".
 */
353 static const struct got_capability got_capabilities[] = {
354 { GOT_CAPA_AGENT, "got/" GOT_VERSION_STR },
355 { GOT_CAPA_OFS_DELTA, NULL },
356 { GOT_CAPA_SIDE_BAND_64K, NULL },
/*
 * Compare one server capability string capa with one client capability
 * mycapa. If capa contains an equals sign only the part before it is
 * compared with mycapa->key, otherwise the whole string is compared.
 *
 * On a match, *my_capabilities is replaced by a newly allocated string
 * made of the old contents (if any), a separating space, the key and,
 * when mycapa->value is set, "=" plus that value. The old string is freed.
 * Returns an errno-based error if asprintf() fails.
 *
 * NOTE(review): strncmp() is limited to the length of the text before the
 * equals sign, so a server string whose name part is only a prefix of the
 * key (or is empty) also compares equal. Confirm whether the lengths
 * should be required to match as well.
 */
359 static const struct got_error *
360 match_capability(char **my_capabilities, const char *capa,
361 const struct got_capability *mycapa)
366 equalsign = strchr(capa, '=');
368 if (strncmp(capa, mycapa->key, equalsign - capa) != 0)
371 if (strcmp(capa, mycapa->key) != 0)
/* Append to the space-separated list built so far. */
375 if (asprintf(&s, "%s%s%s%s%s",
376 *my_capabilities != NULL ? *my_capabilities : "",
377 *my_capabilities != NULL ? " " : "",
379 mycapa->value != NULL ? "=" : "",
380 mycapa->value != NULL? mycapa->value : "") == -1)
381 return got_error_from_errno("asprintf");
383 free(*my_capabilities);
384 *my_capabilities = s;
/*
 * Parse a symbolic reference announcement of the form "A:B" and append it
 * to symrefs. The text after the first colon is duplicated as the target;
 * the name is allocated on a line not visible in this excerpt, presumably
 * from the text before the colon. Both are handed to
 * got_pathlist_append() as path and data.
 *
 * Strings shorter than three characters are rejected early; the value
 * returned in that case is not visible here. Allocation failures yield an
 * errno-based error.
 */
388 static const struct got_error *
389 add_symref(struct got_pathlist_head *symrefs, char *capa)
391 const struct got_error *err = NULL;
392 char *colon, *name = NULL, *target = NULL;
394 /* Need at least "A:B" */
395 if (strlen(capa) < 3)
398 colon = strchr(capa, ':');
405 return got_error_from_errno("strdup");
407 target = strdup(colon + 1);
408 if (target == NULL) {
409 err = got_error_from_errno("strdup");
413 /* We can't validate the ref itself here. The main process will. */
414 err = got_pathlist_append(symrefs, name, target);
/*
 * Walk the space-separated server_capabilities string. Entries of the
 * form "symref=..." are passed to add_symref(); each entry is also offered
 * to match_capability() for every element of got_capabilities, which
 * builds up *my_capabilities. *my_capabilities starts out as NULL.
 *
 * strsep() modifies server_capabilities in place, so the caller's string
 * is cut into pieces by this function.
 *
 * NOTE(review): the "symref" test uses strncmp() limited to the length of
 * the text before the equals sign, so a shorter prefix of "symref" (or an
 * empty name) is also accepted. Confirm whether that is intended.
 */
423 static const struct got_error *
424 match_capabilities(char **my_capabilities, struct got_pathlist_head *symrefs,
425 char *server_capabilities)
427 const struct got_error *err = NULL;
428 char *capa, *equalsign;
431 *my_capabilities = NULL;
433 capa = strsep(&server_capabilities, " ");
437 equalsign = strchr(capa, '=');
438 if (equalsign != NULL &&
439 strncmp(capa, "symref", equalsign - capa) == 0) {
440 err = add_symref(symrefs, equalsign + 1);
446 for (i = 0; i < nitems(got_capabilities); i++) {
447 err = match_capability(my_capabilities,
448 capa, &got_capabilities[i]);
/*
 * Forward a progress message received from the server to the parent
 * process through ibuf. The message is first cut down to at most
 * MAX_IMSGSIZE - IMSG_HEADER_SIZE bytes. Every remaining byte must be
 * printable or whitespace, otherwise GOT_ERR_BAD_PACKET is returned and
 * nothing is sent.
 */
457 static const struct got_error *
458 fetch_progress(struct imsgbuf *ibuf, const char *buf, size_t len)
467 * Truncate messages which exceed the maximum imsg payload size.
468 * Server may send up to 64k.
470 if (len > MAX_IMSGSIZE - IMSG_HEADER_SIZE)
471 len = MAX_IMSGSIZE - IMSG_HEADER_SIZE;
473 /* Only allow printable ASCII. */
474 for (i = 0; i < len; i++) {
475 if (isprint((unsigned char)buf[i]) ||
476 isspace((unsigned char)buf[i]))
478 return got_error_msg(GOT_ERR_BAD_PACKET,
479 "non-printable progress message received from server");
482 return got_privsep_send_fetch_server_progress(ibuf, buf, len);
/*
 * Turn an error message sent by the server into a GOT_ERR_FETCH_FAILED
 * error. At most sizeof(msg) - 1 bytes of buf are examined; a byte that
 * is not printable yields GOT_ERR_BAD_PACKET instead.
 *
 * msg is a static buffer and is passed to got_error_msg(), so a later
 * call overwrites the text of an earlier one. Presumably it is static
 * because the returned error keeps pointing at it; confirm.
 *
 * NOTE(review): isprint() is handed a plain char; fetch_progress() casts
 * to unsigned char before calling ctype functions and this should match.
 */
485 static const struct got_error *
486 fetch_error(const char *buf, size_t len)
488 static char msg[1024];
491 for (i = 0; i < len && i < sizeof(msg) - 1; i++) {
492 if (!isprint(buf[i]))
493 return got_error_msg(GOT_ERR_BAD_PACKET,
494 "non-printable error message received from server");
498 return got_error_msg(GOT_ERR_FETCH_FAILED, msg);
/*
 * Run the client side of a fetch over fd and store the received pack data
 * in packfd.
 *
 * Visible phases:
 *   1. Read reference advertisement packets, record wanted object IDs and
 *      report symrefs and refs to the parent through ibuf.
 *   2. Send "want" lines, "have" lines and "done" to the server.
 *   3. Expect a "NAK" reply.
 *   4. Read pack data, either raw or through side-band packets, check the
 *      pack header and write the data to packfd.
 *   5. Rewind packfd and verify the pack checksum with check_pack_hash(),
 *      which stores the computed digest in packid->sha1.
 *
 * Many control-flow lines (loops, gotos, cleanup labels) are not visible
 * in this excerpt, so the comments below describe only the statements
 * that can be seen.
 */
501 static const struct got_error *
502 fetch_pack(int fd, int packfd, struct got_object_id *packid,
503 struct got_pathlist_head *have_refs, struct imsgbuf *ibuf)
505 const struct got_error *err = NULL;
506 char buf[GOT_PKTMAX];
507 char hashstr[SHA1_DIGEST_STRING_LENGTH];
508 struct got_object_id *have, *want;
509 int is_firstpkt = 1, nref = 0, refsz = 16;
512 char *id_str = NULL, *refname = NULL;
513 char *server_capabilities = NULL, *my_capabilities = NULL;
514 struct got_pathlist_head symrefs;
515 struct got_pathlist_entry *pe;
516 int have_sidebands = 0;
517 uint32_t nobjects = 0;
519 TAILQ_INIT(&symrefs);
/* have[] and want[] are parallel arrays with refsz slots each. */
521 have = malloc(refsz * sizeof(have[0]));
523 return got_error_from_errno("malloc");
524 want = malloc(refsz * sizeof(want[0]));
526 err = got_error_from_errno("malloc");
530 fprintf(stderr, "starting fetch\n");
/* Phase 1: reference advertisement. */
532 err = readpkt(&n, fd, buf, sizeof(buf));
/* A packet starting with "ERR " carries a server error message. */
537 if (n >= 4 && strncmp(buf, "ERR ", 4) == 0) {
538 err = fetch_error(&buf[4], n - 4);
541 err = parse_refline(&id_str, &refname, &server_capabilities,
545 if (chattygit && server_capabilities[0] != '\0')
546 fprintf(stderr, "server capabilities: %s\n",
547 server_capabilities);
549 err = match_capabilities(&my_capabilities, &symrefs,
550 server_capabilities);
553 if (chattygit && my_capabilities)
554 fprintf(stderr, "my matched capabilities: %s\n",
556 err = got_privsep_send_fetch_symrefs(ibuf, &symrefs);
/* Reference names containing "^{}" get special handling (not visible). */
561 if (strstr(refname, "^{}"))
563 if (fetchbranch && !match_branch(refname, fetchbranch))
/*
 * Grow both arrays when they are about to fill up.
 * NOTE(review): each reallocarray() result is assigned straight back to
 * the pointer it was given, so the old block is lost if the call fails.
 * Confirm the error path does not rely on the old pointer.
 */
565 if (refsz == nref + 1) {
567 have = reallocarray(have, refsz, sizeof(have[0]));
569 err = got_error_from_errno("reallocarray");
572 want = reallocarray(want, refsz, sizeof(want[0]));
574 err = got_error_from_errno("reallocarray");
578 if (!got_parse_sha1_digest(want[nref].sha1, id_str)) {
579 err = got_error(GOT_ERR_BAD_OBJ_ID_STR);
/*
 * NOTE(review): match_remote_ref() declares its last two parameters as
 * (refname, id_str) but is called here with (id_str, refname). Confirm
 * which order is intended.
 */
583 err = match_remote_ref(have_refs, &have[nref], id_str, refname);
587 err = got_privsep_send_fetch_ref(ibuf, &want[nref],
592 fprintf(stderr, "remote %s\n", refname);
/*
 * Phase 2: one "want" line per ref whose have and want IDs differ. The
 * matched capabilities are appended to the line built for i == 0.
 */
597 for (i = 0; i < nref; i++) {
598 if (got_object_id_cmp(&have[i], &want[i]) == 0)
600 got_sha1_digest_to_str(want[i].sha1, hashstr, sizeof(hashstr));
601 n = snprintf(buf, sizeof(buf), "want %s%s%s\n", hashstr,
602 i == 0 && my_capabilities ? " " : "",
603 i == 0 && my_capabilities ? my_capabilities : "");
604 if (n >= sizeof(buf)) {
605 err = got_error(GOT_ERR_NO_SPACE);
608 err = writepkt(fd, buf, n);
/*
 * One "have" line per ref whose have ID is not all-zero.
 * NOTE(review): the line is formatted from want[i].sha1 although the loop
 * tests have[i]; confirm have[i].sha1 was not meant.
 * NOTE(review): writepkt() is given n + 1 here, which also sends the
 * terminating NUL, unlike the "want" and "done" packets. Confirm.
 */
616 for (i = 0; i < nref; i++) {
617 if (got_object_id_cmp(&have[i], &zhash) == 0)
619 got_sha1_digest_to_str(want[i].sha1, hashstr, sizeof(hashstr));
620 n = snprintf(buf, sizeof(buf), "have %s\n", hashstr);
621 if (n >= sizeof(buf)) {
622 err = got_error(GOT_ERR_NO_SPACE);
625 err = writepkt(fd, buf, n + 1);
631 fprintf(stderr, "up to date\n");
636 n = snprintf(buf, sizeof(buf), "done\n");
637 err = writepkt(fd, buf, n);
/* Phase 3: server reply. */
643 err = readpkt(&n, fd, buf, sizeof(buf));
647 * For now, we only support a full clone, in which case the server
648 * will now send a "NAK" (meaning no common objects were found).
650 if (n != 4 || strncmp(buf, "NAK\n", n) != 0) {
651 err = got_error_msg(GOT_ERR_BAD_PACKET,
652 "unexpected message from server");
657 fprintf(stderr, "fetching...\n");
/* Side-band mode depends on side-band-64k being among our capabilities. */
659 if (my_capabilities != NULL &&
660 strstr(my_capabilities, GOT_CAPA_SIDE_BAND_64K) != NULL)
/* Phase 4: pack data. */
668 if (have_sidebands) {
669 err = read_pkthdr(&datalen, fd);
/*
 * NOTE(review): the channel ID is fetched with a single read(2) call
 * rather than readn(); the handling of a zero-byte result is not
 * visible in this excerpt.
 */
675 /* Read sideband channel ID (one byte). */
676 r = read(fd, buf, 1);
678 err = got_error_from_errno("read");
682 err = got_error_msg(GOT_ERR_BAD_PACKET,
686 if (datalen > sizeof(buf) - 5) {
687 err = got_error_msg(GOT_ERR_BAD_PACKET,
688 "bad packet length");
691 datalen--; /* sideband ID has been read */
/* The payload read below overwrites the channel ID held in buf[0]. */
692 if (buf[0] == GOT_SIDEBAND_PACKFILE_DATA) {
693 /* Read packfile data. */
694 err = readn(&r, fd, buf, datalen);
698 err = got_error_msg(GOT_ERR_BAD_PACKET,
702 } else if (buf[0] == GOT_SIDEBAND_PROGRESS_INFO) {
703 err = readn(&r, fd, buf, datalen);
707 err = got_error_msg(GOT_ERR_BAD_PACKET,
711 err = fetch_progress(ibuf, buf, r);
715 } else if (buf[0] == GOT_SIDEBAND_ERROR_INFO) {
716 err = readn(&r, fd, buf, datalen);
720 err = got_error_msg(GOT_ERR_BAD_PACKET,
724 err = fetch_error(buf, r);
727 err = got_error_msg(GOT_ERR_BAD_PACKET,
728 "unknown side-band received from server");
732 /* No sideband channel. Every byte is packfile data. */
733 err = readn(&r, fd, buf, sizeof buf);
/*
 * The header fields are compared in big-endian form via htobe32() and
 * the object count is converted with betoh32(). The condition that
 * selects which chunk is checked is not visible in this excerpt.
 */
740 /* Check pack file header. */
742 struct got_packfile_hdr *hdr = (void *)buf;
743 if (r < sizeof(*hdr)) {
744 err = got_error_msg(GOT_ERR_BAD_PACKFILE,
745 "short packfile header");
748 if (hdr->signature != htobe32(GOT_PACKFILE_SIGNATURE)) {
749 err = got_error_msg(GOT_ERR_BAD_PACKFILE,
750 "bad packfile signature");
753 if (hdr->version != htobe32(GOT_PACKFILE_VERSION)) {
754 err = got_error_msg(GOT_ERR_BAD_PACKFILE,
755 "bad packfile version");
758 nobjects = betoh32(hdr->nobjects);
760 err = got_error_msg(GOT_ERR_BAD_PACKFILE,
761 "bad packfile with zero objects");
766 /* Write packfile data to temporary pack file. */
767 w = write(packfd, buf, r);
769 err = got_error_from_errno("write");
773 err = got_error(GOT_ERR_IO);
/* Phase 5: rewind and verify the checksum of what was written. */
778 if (lseek(packfd, 0, SEEK_SET) == -1) {
779 err = got_error_from_errno("lseek");
782 err = check_pack_hash(packfd, packsz, packid->sha1);
/* Cleanup: symref paths are freed here before the list itself. */
784 TAILQ_FOREACH(pe, &symrefs, entry) {
785 free((void *)pe->path);
788 got_pathlist_free(&symrefs);
793 free(server_capabilities);
/*
 * Entry point of the fetch-pack helper process. It talks to its parent
 * over the imsg channel on GOT_IMSG_FD_CHILD:
 *   1. restrict itself with pledge("stdio recvfd");
 *   2. receive a GOT_IMSG_FETCH_REQUEST and validate its length;
 *   3. receive a GOT_IMSG_TMPFD message with an empty payload;
 *   4. run fetch_pack() and report either an error or completion.
 *
 * Setting the GOT_DEBUG environment variable enables diagnostics on
 * stderr. Several lines (descriptor extraction, gotos, the end of the
 * function) are not visible in this excerpt.
 */
799 main(int argc, char **argv)
801 const struct got_error *err = NULL;
802 int fetchfd, packfd = -1;
803 struct got_object_id packid;
806 struct got_pathlist_head have_refs;
807 struct got_imsg_fetch_have_refs *fetch_have_refs = NULL;
810 TAILQ_INIT(&have_refs);
812 if (getenv("GOT_DEBUG") != NULL) {
813 fprintf(stderr, "fetch-pack being chatty!\n");
817 imsg_init(&ibuf, GOT_IMSG_FD_CHILD);
819 /* revoke access to most system calls */
820 if (pledge("stdio recvfd", NULL) == -1) {
821 err = got_error_from_errno("pledge");
822 got_privsep_send_error(&ibuf, err);
/* First message: the fetch request. */
826 if ((err = got_privsep_recv_imsg(&imsg, &ibuf, 0)) != 0) {
827 if (err->code == GOT_ERR_PRIVSEP_PIPE)
831 if (imsg.hdr.type == GOT_IMSG_STOP)
833 if (imsg.hdr.type != GOT_IMSG_FETCH_REQUEST) {
834 err = got_error(GOT_ERR_PRIVSEP_MSG);
/* The payload must hold the fixed header before n_have_refs is read. */
837 datalen = imsg.hdr.len - IMSG_HEADER_SIZE;
838 if (datalen < sizeof(struct got_imsg_fetch_have_refs)) {
839 err = got_error(GOT_ERR_PRIVSEP_LEN);
842 fetch_have_refs = (struct got_imsg_fetch_have_refs *)imsg.data;
/* The payload length must match the announced number of entries exactly. */
843 if (datalen != sizeof(struct got_imsg_fetch_have_refs) +
844 sizeof(struct got_imsg_fetch_have_ref) *
845 fetch_have_refs->n_have_refs) {
846 err = got_error(GOT_ERR_PRIVSEP_LEN);
849 if (fetch_have_refs->n_have_refs != 0) {
850 /* TODO: Incremental fetch support */
851 err = got_error(GOT_ERR_NOT_IMPL);
/* Second message: the temporary file descriptor for the pack. */
856 if ((err = got_privsep_recv_imsg(&imsg, &ibuf, 0)) != 0) {
857 if (err->code == GOT_ERR_PRIVSEP_PIPE)
861 if (imsg.hdr.type == GOT_IMSG_STOP)
863 if (imsg.hdr.type != GOT_IMSG_TMPFD) {
864 err = got_error(GOT_ERR_PRIVSEP_MSG);
867 if (imsg.hdr.len - IMSG_HEADER_SIZE != 0) {
868 err = got_error(GOT_ERR_PRIVSEP_LEN);
873 err = fetch_pack(fetchfd, packfd, &packid, &have_refs, &ibuf);
/* A close(2) failure is only reported when no earlier error is pending. */
875 if (packfd != -1 && close(packfd) == -1 && err == NULL)
876 err = got_error_from_errno("close");
878 got_privsep_send_error(&ibuf, err);
880 err = got_privsep_send_fetch_done(&ibuf, packid);
882 fprintf(stderr, "%s: %s\n", getprogname(), err->msg);
883 got_privsep_send_error(&ibuf, err);