]> ruderich.org/simon Gitweb - socket2unix/socket2unix.git/blob - src/socket2unix.c
Rename DEBUG() macro to DBG() to fix name clash.
[socket2unix/socket2unix.git] / src / socket2unix.c
1 /*
2  * Simple LD_PRELOAD wrapper to "convert" network sockets to UNIX sockets;
3  * works for clients and servers. See README for details.
4  *
5  * Copyright (C) 2013  Simon Ruderich
6  *
7  * This program is free software: you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation, either version 3 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
19  */
20
21
22 /* Necessary for RTLD_NEXT. */
23 #define _GNU_SOURCE
24
25 #include <assert.h>
26 #include <dlfcn.h>
27 #include <errno.h>
28 #include <netinet/in.h>
29 #include <netinet/ip.h>
30 #include <stdarg.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <sys/socket.h>
35 #include <sys/stat.h>
36 #include <sys/types.h>
37 #include <sys/un.h>
38 #include <unistd.h>
39
40
41 /* CONSTANTS */
42
43 #define LOG_LEVEL_ERROR   1
44 #define LOG_LEVEL_WARNING 2
45 #define LOG_LEVEL_DEBUG   3
46 #define LOG_LEVEL_MASK    LOG_LEVEL_DEBUG
47
48 #define LOG_LEVEL_PERROR  42
49
50
51 /* GLOBAL VARIABLES */
52
53 struct list {
54     int unix_sockfd;
55
56     int orig_domain;
57     int orig_type;
58
59     struct list *next;
60 };
61
62 static struct list socket_list = {
63     .unix_sockfd = -1, /* must not match a valid sockfd */
64 };
65
66
67 /* LOG FUNCTIONS/MACROS */
68
69 static int get_log_level(void);
70
71 static void log_helper(int action, const char *file, int line, const char *format, va_list ap) {
72     int saved_errno = errno;
73
74     static int log_level;
75     if (!log_level) {
76         log_level = get_log_level();
77     }
78
79     int level = action & LOG_LEVEL_MASK;
80     if (level > log_level) {
81         return;
82     }
83
84     const char *prefix;
85     if (level == LOG_LEVEL_DEBUG) {
86         prefix = "DEBUG";
87     } else if (level == LOG_LEVEL_WARNING) {
88         prefix = "WARN ";
89     } else if (level == LOG_LEVEL_ERROR) {
90         prefix = "ERROR";
91     } else {
92         prefix = "UNKNOWN";
93     }
94
95     /* Prevent other threads from interrupting the printf()s. */
96     flockfile(stderr);
97
98     fprintf(stderr, "socket2unix [%s] ", prefix);
99     fprintf(stderr, "[%s:%3d] ", file, line);
100     vfprintf(stderr, format, ap);
101
102     if ((action & ~LOG_LEVEL_MASK) == LOG_LEVEL_PERROR) {
103         fprintf(stderr, ": ");
104
105         errno = saved_errno;
106         perror("");
107     }
108
109     funlockfile(stderr);
110
111     if (level == LOG_LEVEL_ERROR) {
112         fprintf(stderr, "Aborting.\n");
113         exit(EXIT_FAILURE);
114     }
115 }
116
117 static void log_(int level, const char *file, int line, const char *format, ...)
118     __attribute__((format(printf, 4, 5)));
119 static void log_(int level, const char *file, int line, const char *format, ...) {
120     va_list ap;
121
122     va_start(ap, format);
123     log_helper(level, file, line, format, ap);
124     va_end(ap);
125 }
126
127 #define ERROR(...) \
128     log_(LOG_LEVEL_ERROR,   __FILE__, __LINE__, __VA_ARGS__)
129 #define WARN(...) \
130     log_(LOG_LEVEL_WARNING, __FILE__, __LINE__, __VA_ARGS__)
131 #define DBG(...) \
132     log_(LOG_LEVEL_DEBUG,   __FILE__, __LINE__, __VA_ARGS__)
133
134 #define DIE(...) \
135     log_(LOG_LEVEL_ERROR | LOG_LEVEL_PERROR, __FILE__, __LINE__, __VA_ARGS__)
136
137
138 /* LD_PRELOAD */
139
140 /* Load the function name using dlsym() if necessary and store it in pointer.
141  * Terminate program on failure. */
142 #define LOAD_FUNCTION(pointer, name) \
143     if ((pointer) == NULL) { \
144         char *error; \
145         dlerror(); /* Clear possibly existing error. */ \
146         \
147         *(void **) (&(pointer)) = dlsym(RTLD_NEXT, (name)); \
148         \
149         if ((error = dlerror()) != NULL) { \
150             ERROR("%s\n", error); \
151         } \
152     }
153
154
155 /* OTHER FUNCTIONS */
156
157 static struct list *find_sockfd(int sockfd) {
158     struct list *e;
159
160     if (sockfd == socket_list.unix_sockfd) {
161         return NULL;
162     }
163
164     for (e = &socket_list; e != NULL; e = e->next) {
165         if (e->unix_sockfd == sockfd) {
166             return e;
167         }
168     }
169     return NULL;
170 }
171 static struct list *remove_sockfd(int sockfd) {
172     struct list *e, *p;
173
174     if (sockfd == socket_list.unix_sockfd) {
175         return NULL;
176     }
177
178     for (e = &socket_list, p = NULL; e != NULL; p = e, e = e->next) {
179         if (e->unix_sockfd == sockfd) {
180             p->next = e->next;
181             return e;
182         }
183     }
184     return NULL;
185 }
186
187 static const char *get_socket_path(void) {
188     const char *path = getenv("SOCKET2UNIX_PATH");
189     if (!path) {
190         ERROR("SOCKET2UNIX_PATH environment variable not defined\n");
191     }
192     if (path[0] != '/') {
193         ERROR("SOCKET2UNIX_PATH '%s' must be an absolute path\n", path);
194     }
195     return path;
196 }
197 static int get_log_level(void) {
198     const char *level = getenv("SOCKET2UNIX_DEBUG");
199     if (!level) {
200         return LOG_LEVEL_WARNING;
201     }
202     int number = atoi(level);
203     if (number <= 0 || number > LOG_LEVEL_DEBUG) {
204         number = LOG_LEVEL_DEBUG;
205     }
206     return number;
207 }
208
209 static const char *af_to_name(int af) {
210     if (af == AF_UNIX) {
211         return "AF_UNIX";
212     } else if (af == AF_LOCAL) {
213         return "AF_LOCAL";
214     } else if (af == AF_INET) {
215         return "AF_INET";
216     } else if (af == AF_INET6) {
217         return "AF_INET6";
218     } else if (af == AF_IPX) {
219         return "AF_IPX";
220     } else if (af == AF_NETLINK) {
221         return "AF_NETLINK";
222     } else if (af == AF_X25) {
223         return "AF_X25";
224     } else if (af == AF_AX25) {
225         return "AF_AX25";
226     } else if (af == AF_ATMPVC) {
227         return "AF_ATMPVC";
228     } else if (af == AF_APPLETALK) {
229         return "AF_APPLETALK";
230     } else if (af == AF_PACKET) {
231         return "AF_PACKET";
232     } else {
233         return "AF_UNKNOWN";
234     }
235 }
236 static const char *sock_to_name(int sock) {
237     if (sock & SOCK_STREAM) {
238         return "SOCK_STREAM";
239     } else if (sock & SOCK_DGRAM) {
240         return "SOCK_DGRAM";
241     } else if (sock & SOCK_SEQPACKET) {
242         return "SOCK_SEQPACKET";
243     } else if (sock & SOCK_RAW) {
244         return "SOCK_RAW";
245     } else if (sock & SOCK_RDM) {
246         return "SOCK_RDM";
247     } else if (sock & SOCK_PACKET) {
248         return "SOCK_PACKET";
249     } else {
250         return "SOCK_UNKNOWN";
251     }
252 }
253 /* for getsockopt()/setsockopt(). */
254 static const char *level_to_name(int level) {
255     if (level == SOL_SOCKET) {
256         return "SOL_SOCKET";
257     } else if (level == SOL_IP) {
258         return "SOL_IP";
259     } else if (level == SOL_IPV6) {
260         return "SOL_IPV6";
261     } else if (level == IPPROTO_TCP) {
262         return "IPPROTO_TCP";
263     } else if (level == IPPROTO_UDP) {
264         return "IPPROTO_UDP";
265     } else {
266         return "SOL_UNKNOWN";
267     }
268 }
269
270
271 static int set_sockaddr_un(struct sockaddr_un *sockaddr,
272                            const struct sockaddr *addr, socklen_t addrlen) {
273     /* Just in case ... */
274     if ((addr->sa_family == AF_INET
275                 && addrlen < sizeof(struct sockaddr_in))
276             || (addr->sa_family == AF_INET6
277                 && addrlen < sizeof(struct sockaddr_in6))) {
278         WARN("invalid addrlen from program\n");
279         return -1;
280     }
281
282     const char *socket_path = get_socket_path();
283
284     /* The program may open multiple sockets, e.g. IPv4 and IPv6 and on
285      * multiple ports. Create unique paths. */
286     const char *af;
287     int port;
288     if (addr->sa_family == AF_INET) {
289         af   = "v4";
290         port = ntohs(((struct sockaddr_in *)addr)->sin_port);
291     } else if (addr->sa_family == AF_INET6) {
292         af   = "v6";
293         port = ntohs(((struct sockaddr_in6 *)addr)->sin6_port);
294     } else {
295         af   = "unknown";
296         port = 0;
297         WARN("unknown sa_family '%s' (%d)\n",
298              af_to_name(addr->sa_family), addr->sa_family);
299     }
300
301     /* Initialize sockaddr_un. */
302     sockaddr->sun_family = AF_UNIX;
303     int written = snprintf(sockaddr->sun_path, sizeof(sockaddr->sun_path),
304                            "%s-%s-%d", socket_path, af, port);
305     /* The maximum length is quite short, check it. */
306     if (written >= (int)sizeof(sockaddr->sun_path)) {
307         ERROR("path '%s-%s-%d' too long for UNIX socket",
308               socket_path, af, port);
309     }
310
311     return 0;
312 }
313
314
315 /* FUNCTIONS OVERWRITTEN BY LD_PRELOAD */
316
317 int socket(int domain, int type, int protocol) {
318     static int (*real_socket)(int, int, int);
319     LOAD_FUNCTION(real_socket, "socket");
320
321     if (domain == AF_UNIX || domain == AF_LOCAL) {
322         return real_socket(domain, type, protocol);
323     }
324
325     DBG("socket(%s, %s, %d)\n",
326         af_to_name(domain), sock_to_name(type), protocol);
327
328     /* We must return the replacement socket in case the program uses select()
329      * or similar on it. */
330
331     int unix_sockfd = real_socket(AF_UNIX, type, 0);
332     if (unix_sockfd < 0) {
333         DIE("bind(): failed to create UNIX socket");
334     }
335
336     struct list *entry = malloc(sizeof(*entry));
337     if (!entry) {
338         DIE("socket(): malloc");
339     }
340     memset(entry, 0, sizeof(*entry));
341
342     entry->unix_sockfd = unix_sockfd;
343     entry->orig_domain = domain;
344     entry->orig_type   = type;
345
346     entry->next = socket_list.next;
347     socket_list.next = entry;
348
349     return unix_sockfd;
350 }
351
352 int close(int fd) {
353     static int (*real_close)(int);
354     LOAD_FUNCTION(real_close, "close");
355
356     DBG("close(%d)\n", fd);
357
358     struct list *entry = remove_sockfd(fd);
359     if (entry == NULL) {
360         DBG("close(%d): sockfd not found\n", fd);
361         return real_close(fd);
362     }
363     assert(fd == entry->unix_sockfd);
364     free(entry);
365
366     return real_close(fd);
367 }
368
369 int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
370     static int (*real_bind)(int, const struct sockaddr *, socklen_t);
371     LOAD_FUNCTION(real_bind, "bind");
372
373     DBG("bind(%d, ..)\n", sockfd);
374
375     if (addr == NULL || addrlen < sizeof(addr->sa_family)
376             || addr->sa_family == AF_UNIX
377             || addr->sa_family == AF_LOCAL) {
378         return real_bind(sockfd, addr, addrlen);
379     }
380
381     struct list *entry = find_sockfd(sockfd);
382     if (!entry) {
383         DBG("bind(%d, ..): sockfd not found\n", sockfd);
384         return real_bind(sockfd, addr, addrlen);
385     }
386     assert(sockfd == entry->unix_sockfd);
387     DBG("bind(%d, ..): %s %s\n",
388         sockfd,
389         af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
390
391     struct sockaddr_un sockaddr;
392     if (set_sockaddr_un(&sockaddr, addr, addrlen) != 0) {
393         ERROR("connect(%d, ..) failed\n", sockfd);
394     }
395
396     DBG("bind(%d, ..): using path '%s'\n", sockfd, sockaddr.sun_path);
397
398     int attempts = 0;
399     while (attempts < 10) {
400         if (real_bind(entry->unix_sockfd, (struct sockaddr *)&sockaddr,
401                                           sizeof(sockaddr)) == 0) {
402             /* Success. */
403             return 0;
404         }
405         if (errno != EADDRINUSE) {
406             DIE("bind(%d, ..): failed to bind to '%s'",
407                 sockfd, sockaddr.sun_path);
408         }
409
410         /* File already exists, unlink it if it's a socket. This has a race
411          * condition, but the worst case is that we delete a file created by
412          * the user at the path he told us to use. Tough luck .. */
413
414         struct stat buf;
415         if (lstat(sockaddr.sun_path, &buf) != 0) {
416             /* Looks like a race, better abort. */
417             DIE("bind(%d, ..): lstat on UNIX socket '%s' failed",
418                 sockfd, sockaddr.sun_path);
419         }
420
421         if (!S_ISSOCK(buf.st_mode)) {
422             ERROR("bind(%d, ..): path '%s' exits and is no socket\n",
423                   sockfd, sockaddr.sun_path);
424         }
425
426         WARN("bind(%d, ..): unlinking '%s'\n", sockfd, sockaddr.sun_path);
427         if (unlink(sockaddr.sun_path) != 0) {
428             DIE("bind(%d, ..): unlink '%s' failed",
429                 sockfd, sockaddr.sun_path);
430         }
431
432         attempts++;
433     }
434
435     ERROR("bind(%d, ..): failed to create UNIX socket file\n", sockfd);
436     return -1; /* never reached */
437 }
438
439 int listen(int sockfd, int backlog) {
440     static int (*real_listen)(int, int);
441     LOAD_FUNCTION(real_listen, "listen");
442
443     DBG("listen(%d, %d)\n", sockfd, backlog);
444
445     struct list *entry = find_sockfd(sockfd);
446     if (!entry) {
447         DBG("listen(%d, %d): sockfd not found\n", sockfd, backlog);
448         return real_listen(sockfd, backlog);
449     }
450     assert(sockfd == entry->unix_sockfd);
451     DBG("listen(%d, %d): %s %s\n",
452         sockfd, backlog,
453         af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
454
455     if (real_listen(entry->unix_sockfd, backlog) != 0) {
456         DIE("listen(): failed to listen");
457     }
458
459     return 0;
460 }
461
462 int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
463     static int (*real_accept)(int, struct sockaddr *, socklen_t *);
464     LOAD_FUNCTION(real_accept, "accept");
465
466     DBG("accept(%d, ..)\n", sockfd);
467
468     struct list *entry = find_sockfd(sockfd);
469     if (!entry) {
470         DBG("accept(%d, ..): sockfd not found\n", sockfd);
471         return real_accept(sockfd, addr, addrlen);
472     }
473     assert(sockfd == entry->unix_sockfd);
474     DBG("accept(%d, ..): %s %s\n",
475         sockfd,
476         af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
477
478     struct sockaddr_un sockaddr;
479     socklen_t size = sizeof(sockaddr);
480     int sock = real_accept(entry->unix_sockfd, (struct sockaddr *)&sockaddr,
481                                                &size);
482     if (sock < 0) {
483         DIE("accept(%d, ..): failed to accept", sockfd);
484     }
485
486     if (addr == NULL || addrlen == NULL) {
487         return sock;
488     }
489     DBG("accept(%d, ..): caller requested sockaddr\n", sockfd);
490
491     if (*addrlen < size) {
492         WARN("accept(%d, ..): invalid addrlen from program", sockfd);
493         errno = EINVAL;
494         return -1;
495     }
496
497     /* This is not the protocol the program asked for (AF_* vs. AF_UNIX), but
498      * it should work most of the time. */
499     memcpy(addr, &sockaddr, size);
500     *addrlen = size;
501
502     /* TODO: is this enough? */
503
504     return sock;
505 }
506
507 int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
508     static int (*real_connect)(int, const struct sockaddr *, socklen_t);
509     LOAD_FUNCTION(real_connect, "connect");
510
511     DBG("connect(%d, ..)\n", sockfd);
512
513     if (addr == NULL || addrlen < sizeof(addr->sa_family)
514             || addr->sa_family == AF_UNIX
515             || addr->sa_family == AF_LOCAL) {
516         return real_connect(sockfd, addr, addrlen);
517     }
518
519     struct list *entry = find_sockfd(sockfd);
520     if (!entry) {
521         DBG("connect(%d, ..): sockfd not found\n", sockfd);
522         return real_connect(sockfd, addr, addrlen);
523     }
524     assert(sockfd == entry->unix_sockfd);
525     DBG("connect(%d, ..): %s %s\n",
526         sockfd,
527         af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
528
529     struct sockaddr_un sockaddr;
530     if (set_sockaddr_un(&sockaddr, addr, addrlen) != 0) {
531         ERROR("connect(%d, ..) failed\n", sockfd);
532     }
533
534     DBG("connect(%d, ..): using path '%s'\n", sockfd, sockaddr.sun_path);
535
536     if (real_connect(entry->unix_sockfd, (struct sockaddr *)&sockaddr,
537                                          sizeof(sockaddr)) != 0) {
538         DIE("connect(%d, ..): failed to connect", sockfd);
539     }
540
541     return 0;
542 }
543
544
545 int getsockname(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
546     static int (*real_getsockname)(int, struct sockaddr *, socklen_t *);
547     LOAD_FUNCTION(real_getsockname, "getsockname");
548
549     DBG("getsockname(%d, ..)\n", sockfd);
550
551     return real_getsockname(sockfd, addr, addrlen);
552 }
553
554 int getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
555     static int (*real_getpeername)(int, struct sockaddr *, socklen_t *);
556     LOAD_FUNCTION(real_getpeername, "getpeername");
557
558     DBG("getpeername(%d, ..)\n", sockfd);
559
560     return real_getpeername(sockfd, addr, addrlen);
561 }
562
563 int getsockopt(int sockfd, int level, int optname, void *optval, socklen_t *optlen) {
564     static int (*real_getsockopt)(int, int, int, void *, socklen_t *);
565     LOAD_FUNCTION(real_getsockopt, "getsockopt");
566
567     DBG("getsockopt(%d, %d %s, %d, ..)\n",
568         sockfd, level, level_to_name(level), optname);
569
570     return real_getsockopt(sockfd, level, optname, optval, optlen);
571 }
572 int setsockopt(int sockfd, int level, int optname, const void *optval, socklen_t optlen) {
573     static int (*real_setsockopt)(int, int, int, const void *, socklen_t);
574     LOAD_FUNCTION(real_setsockopt, "setsockopt");
575
576     DBG("setsockopt(%d, %d %s, %d, ..)\n",
577         sockfd, level, level_to_name(level), optname);
578
579     return real_setsockopt(sockfd, level, optname, optval, optlen);
580 }