]> ruderich.org/simon Gitweb - socket2unix/socket2unix.git/commitdiff
Allow disabling of hooks for client or server functionality.
authorSimon Ruderich <simon@ruderich.org>
Mon, 2 Dec 2013 22:25:49 +0000 (23:25 +0100)
committerSimon Ruderich <simon@ruderich.org>
Mon, 2 Dec 2013 22:25:49 +0000 (23:25 +0100)
Add new SOCKET2UNIX_OPTIONS environment variable with options
'client_only' and 'server_only'.

README
src/socket2unix.c

diff --git a/README b/README
index ba408bd3a9b8220fde53126493821437b5a54343..da69397ca75d01bbc8cf7d34c5609d5c0fc83f40 100644 (file)
--- a/README
+++ b/README
@@ -71,6 +71,13 @@ The following additional environment variables are available:
 - 'SOCKET2UNIX_DEBUG':
   Control debug level. 1 = errors only, 2 = warnings only, 3 = debug messages.
   Default: 2
+- 'SOCKET2UNIX_OPTIONS':
+  Comma separated list of options for socket2unix. Valid options are (without
+  quotes):
+  - 'client_only': Don't intercept calls to `listen()` and `accept()`.
+  - 'server_only': Don't intercept calls to `connect()`.
+    These options are useful if a program has both client and server
+    functionality but only one part should be redirected.
 
 
 BUGS
index c90c864ebecf79b805e96a29add3412efcb7596e..1f741b5f5680e9b97d0f09be17540f718523d2cd 100644 (file)
 
 #define LOG_LEVEL_PERROR  42
 
+#define OPTION_PARSED                                   (1 << 1)
+/* Don't intercept listen(), accept(). */
+#define OPTION_CLIENT_ONLY                              (1 << 2)
+/* Don't intercept connect(). */
+#define OPTION_SERVER_ONLY                              (1 << 3)
+
 
 /* GLOBAL VARIABLES */
 
 struct list {
-    int unix_sockfd;
-
+    int orig_sockfd;
     int orig_domain;
     int orig_type;
 
+    /* Used by listen(). */
+    struct sockaddr *orig_addr;
+    socklen_t orig_addrlen;
+
     struct list *next;
 };
 
 static struct list socket_list = {
-    .unix_sockfd = -1, /* must not match a valid sockfd */
+    .orig_sockfd = -1, /* must not match a valid sockfd */
 };
 
+static int global_options;
+
 
 /* LOG FUNCTIONS/MACROS */
 
@@ -155,15 +166,23 @@ static void log_(int level, const char *file, int line, const char *format, ...)
 
 /* OTHER FUNCTIONS */
 
+static void *xmalloc(size_t size) {
+    void *x = malloc(size);
+    if (!x) {
+        DIE("malloc(%zu)", size);
+    }
+    return x;
+}
+
 static struct list *find_sockfd(int sockfd) {
     struct list *e;
 
-    if (sockfd == socket_list.unix_sockfd) {
+    if (sockfd == socket_list.orig_sockfd) {
         return NULL;
     }
 
     for (e = &socket_list; e != NULL; e = e->next) {
-        if (e->unix_sockfd == sockfd) {
+        if (e->orig_sockfd == sockfd) {
             return e;
         }
     }
@@ -172,12 +191,12 @@ static struct list *find_sockfd(int sockfd) {
 static struct list *remove_sockfd(int sockfd) {
     struct list *e, *p;
 
-    if (sockfd == socket_list.unix_sockfd) {
+    if (sockfd == socket_list.orig_sockfd) {
         return NULL;
     }
 
     for (e = &socket_list, p = NULL; e != NULL; p = e, e = e->next) {
-        if (e->unix_sockfd == sockfd) {
+        if (e->orig_sockfd == sockfd) {
             p->next = e->next;
             return e;
         }
@@ -210,6 +229,49 @@ static int get_log_level(void) {
     }
     return number;
 }
+static int get_options(void) {
+    const char *pos = getenv("SOCKET2UNIX_OPTIONS");
+    if (!pos) {
+        return OPTION_PARSED;
+    }
+
+    int options = OPTION_PARSED;
+
+    while (*pos != '\0') {
+        size_t length;
+
+        const char *end = strchr(pos, ',');
+        if (end == NULL) {
+            length = strlen(pos);
+        } else {
+            length = (size_t)(end - pos);
+        }
+
+        if (!strncmp("client_only", pos, length)) {
+            options |= OPTION_CLIENT_ONLY;
+        } else if (!strncmp("server_only", pos, length)) {
+            options |= OPTION_SERVER_ONLY;
+        } else {
+            char option[length + 1];
+            strncpy(option, pos, length);
+            option[length] = '\0';
+            ERROR("unknown option '%s' in SOCKET2UNIX_OPTIONS\n",
+                  option);
+        }
+
+        if (end == NULL) {
+            break;
+        }
+        pos = end + 1;
+    }
+
+    if ((options & OPTION_CLIENT_ONLY) && (options & OPTION_SERVER_ONLY)) {
+        ERROR("conflicting options 'client_only', 'server_only' "
+              "in SOCKET2UNIX_OPTIONS\n");
+    }
+
+    return options;
+}
 
 static const char *af_to_name(int af) {
     if (af == AF_UNIX) {
@@ -332,6 +394,22 @@ static int set_sockaddr_un(struct sockaddr_un *sockaddr,
     return 0;
 }
 
+static int replace_socket(int replaceefd, int replacerfd) {
+    static int (*real_close)(int);
+    LOAD_FUNCTION(real_close, "close");
+
+    /* Replace socket replaceefd with replacerfd. After dup2() both socket fds
+     * point to the same socket (replacerfd). */
+    if (dup2(replacerfd, replaceefd) < 0) {
+        return -1;
+    }
+    /* We don't need replacerfd anymore. The program will use our replacement
+     * and we don't need it for anything else. Use real_close() to prevent
+     * unnecessary debug messages. */
+    real_close(replacerfd);
+    return 0;
+}
+
 
 /* FUNCTIONS OVERWRITTEN BY LD_PRELOAD */
 
@@ -339,35 +417,31 @@ int socket(int domain, int type, int protocol) {
     static int (*real_socket)(int, int, int);
     LOAD_FUNCTION(real_socket, "socket");
 
-    if (domain == AF_UNIX || domain == AF_LOCAL) {
-        return real_socket(domain, type, protocol);
+    /* We return the normal socket because we don't know yet if it's a client
+     * or a listen socket and therefore if we should replace it or not. This
+     * happens in listen() and connect(), see below. */
+
+    int sockfd = real_socket(domain, type, protocol);
+    if (sockfd < 0
+            || domain == AF_UNIX
+            || domain == AF_LOCAL) {
+        return sockfd;
     }
 
     DBG("socket(%s, %s, %d)\n",
         af_to_name(domain), sock_to_name(type), protocol);
 
-    /* We must return the replacement socket in case the program uses select()
-     * or similar on it. */
-
-    int unix_sockfd = real_socket(AF_UNIX, type, 0);
-    if (unix_sockfd < 0) {
-        DIE("bind(): failed to create UNIX socket");
-    }
-
-    struct list *entry = malloc(sizeof(*entry));
-    if (!entry) {
-        DIE("socket(): malloc");
-    }
+    struct list *entry = xmalloc(sizeof(*entry));
     memset(entry, 0, sizeof(*entry));
 
-    entry->unix_sockfd = unix_sockfd;
+    entry->orig_sockfd = sockfd;
     entry->orig_domain = domain;
     entry->orig_type   = type;
 
     entry->next = socket_list.next;
     socket_list.next = entry;
 
-    return unix_sockfd;
+    return sockfd;
 }
 
 int close(int fd) {
@@ -381,7 +455,8 @@ int close(int fd) {
         DBG("close(%d): sockfd not found\n", fd);
         return real_close(fd);
     }
-    assert(fd == entry->unix_sockfd);
+    assert(fd == entry->orig_sockfd);
+    free(entry->orig_addr);
     free(entry);
 
     return real_close(fd);
@@ -404,27 +479,63 @@ int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
         DBG("bind(%d, ..): sockfd not found\n", sockfd);
         return real_bind(sockfd, addr, addrlen);
     }
-    assert(sockfd == entry->unix_sockfd);
+    assert(sockfd == entry->orig_sockfd);
     DBG("bind(%d, ..): %s %s\n",
         sockfd,
         af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
 
+    /* Copy struct sockaddr, we need it later in listen(). */
+    entry->orig_addr = xmalloc(addrlen);
+    memcpy(entry->orig_addr, addr, addrlen);
+    entry->orig_addrlen = addrlen;
+
+    return real_bind(sockfd, addr, addrlen);
+}
+
+int listen(int sockfd, int backlog) {
+    static int (*real_listen)(int, int);
+    LOAD_FUNCTION(real_listen, "listen");
+
+    if (!global_options) {
+        global_options = get_options();
+    }
+
+    if (global_options & OPTION_CLIENT_ONLY) {
+        DBG("listen(%d, %d): server hooking disabled\n", sockfd, backlog);
+        return real_listen(sockfd, backlog);
+    }
+
+    struct list *entry = find_sockfd(sockfd);
+    if (!entry) {
+        DBG("listen(%d, %d): sockfd not found\n", sockfd, backlog);
+        return real_listen(sockfd, backlog);
+    }
+    assert(sockfd == entry->orig_sockfd);
+    DBG("listen(%d, %d): %s %s\n",
+        sockfd, backlog,
+        af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
+
+    int unix_sockfd = socket(AF_UNIX, entry->orig_type, 0);
+    if (unix_sockfd < 0) {
+        DIE("listen(): failed to create UNIX socket");
+    }
+
     struct sockaddr_un sockaddr;
-    if (set_sockaddr_un(&sockaddr, addr, addrlen) != 0) {
-        ERROR("connect(%d, ..) failed\n", sockfd);
+    if (set_sockaddr_un(&sockaddr, entry->orig_addr,
+                                   entry->orig_addrlen) != 0) {
+        ERROR("listen(%d, ..) failed\n", sockfd);
     }
 
-    DBG("bind(%d, ..): using path '%s'\n", sockfd, sockaddr.sun_path);
+    DBG("listen(%d, ..): using path '%s'\n", sockfd, sockaddr.sun_path);
 
     int attempts = 0;
     while (attempts < 10) {
-        if (real_bind(entry->unix_sockfd, (struct sockaddr *)&sockaddr,
-                                          sizeof(sockaddr)) == 0) {
-            /* Success. */
-            return 0;
+        if (bind(unix_sockfd, (struct sockaddr *)&sockaddr,
+                              sizeof(sockaddr)) == 0) {
+            break;
         }
         if (errno != EADDRINUSE) {
-            DIE("bind(%d, ..): failed to bind to '%s'",
+            DIE("listen(%d, ..): failed to bind to '%s'",
                 sockfd, sockaddr.sun_path);
         }
 
@@ -435,45 +546,34 @@ int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
         struct stat buf;
         if (lstat(sockaddr.sun_path, &buf) != 0) {
             /* Looks like a race, better abort. */
-            DIE("bind(%d, ..): lstat on UNIX socket '%s' failed",
+            DIE("listen(%d, ..): lstat on UNIX socket '%s' failed",
                 sockfd, sockaddr.sun_path);
         }
 
         if (!S_ISSOCK(buf.st_mode)) {
-            ERROR("bind(%d, ..): path '%s' exits and is no socket\n",
+            ERROR("listen(%d, ..): path '%s' exits and is no socket\n",
                   sockfd, sockaddr.sun_path);
         }
 
-        WARN("bind(%d, ..): unlinking '%s'\n", sockfd, sockaddr.sun_path);
+        WARN("listen(%d, ..): unlinking '%s'\n", sockfd, sockaddr.sun_path);
         if (unlink(sockaddr.sun_path) != 0) {
-            DIE("bind(%d, ..): unlink '%s' failed",
+            DIE("listen(%d, ..): unlink '%s' failed",
                 sockfd, sockaddr.sun_path);
         }
 
         attempts++;
     }
 
-    ERROR("bind(%d, ..): failed to create UNIX socket file\n", sockfd);
-    return -1; /* never reached */
-}
-
-int listen(int sockfd, int backlog) {
-    static int (*real_listen)(int, int);
-    LOAD_FUNCTION(real_listen, "listen");
-
-    DBG("listen(%d, %d)\n", sockfd, backlog);
+    if (attempts == 10) {
+        ERROR("listen(%d, ..): failed to create UNIX socket file\n", sockfd);
+    }
 
-    struct list *entry = find_sockfd(sockfd);
-    if (!entry) {
-        DBG("listen(%d, %d): sockfd not found\n", sockfd, backlog);
-        return real_listen(sockfd, backlog);
+    /* Replace the original socket of the program with our socket. */
+    if (replace_socket(entry->orig_sockfd, unix_sockfd)) {
+        DIE("listen(): failed to replace socket");
     }
-    assert(sockfd == entry->unix_sockfd);
-    DBG("listen(%d, %d): %s %s\n",
-        sockfd, backlog,
-        af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
 
-    if (real_listen(entry->unix_sockfd, backlog) != 0) {
+    if (real_listen(entry->orig_sockfd, backlog) != 0) {
         DIE("listen(): failed to listen");
     }
 
@@ -484,6 +584,15 @@ int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
     static int (*real_accept)(int, struct sockaddr *, socklen_t *);
     LOAD_FUNCTION(real_accept, "accept");
 
+    if (!global_options) {
+        global_options = get_options();
+    }
+
+    if (global_options & OPTION_CLIENT_ONLY) {
+        DBG("accept(%d, ..): server hooking disabled\n", sockfd);
+        return real_accept(sockfd, addr, addrlen);
+    }
+
     DBG("accept(%d, ..)\n", sockfd);
 
     struct list *entry = find_sockfd(sockfd);
@@ -491,14 +600,14 @@ int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
         DBG("accept(%d, ..): sockfd not found\n", sockfd);
         return real_accept(sockfd, addr, addrlen);
     }
-    assert(sockfd == entry->unix_sockfd);
+    assert(sockfd == entry->orig_sockfd);
     DBG("accept(%d, ..): %s %s\n",
         sockfd,
         af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
 
     struct sockaddr_un sockaddr;
     socklen_t size = sizeof(sockaddr);
-    int sock = real_accept(entry->unix_sockfd, (struct sockaddr *)&sockaddr,
+    int sock = real_accept(entry->orig_sockfd, (struct sockaddr *)&sockaddr,
                                                &size);
     if (sock < 0) {
         DIE("accept(%d, ..): failed to accept", sockfd);
@@ -529,6 +638,15 @@ int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
     static int (*real_connect)(int, const struct sockaddr *, socklen_t);
     LOAD_FUNCTION(real_connect, "connect");
 
+    if (!global_options) {
+        global_options = get_options();
+    }
+
+    if (global_options & OPTION_SERVER_ONLY) {
+        DBG("connect(%d, ..): client hooking disabled\n", sockfd);
+        return real_connect(sockfd, addr, addrlen);
+    }
+
     DBG("connect(%d, ..)\n", sockfd);
 
     if (addr == NULL || addrlen < sizeof(addr->sa_family)
@@ -542,11 +660,21 @@ int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
         DBG("connect(%d, ..): sockfd not found\n", sockfd);
         return real_connect(sockfd, addr, addrlen);
     }
-    assert(sockfd == entry->unix_sockfd);
+    assert(sockfd == entry->orig_sockfd);
     DBG("connect(%d, ..): %s %s\n",
         sockfd,
         af_to_name(entry->orig_domain), sock_to_name(entry->orig_type));
 
+    int unix_sockfd = socket(AF_UNIX, entry->orig_type, 0);
+    if (unix_sockfd < 0) {
+        DIE("bind(): failed to create UNIX socket");
+    }
+
+    /* Replace the original socket of the program with our socket. */
+    if (replace_socket(entry->orig_sockfd, unix_sockfd)) {
+        DIE("connect(): failed to replace socket");
+    }
+
     struct sockaddr_un sockaddr;
     if (set_sockaddr_un(&sockaddr, addr, addrlen) != 0) {
         ERROR("connect(%d, ..) failed\n", sockfd);
@@ -554,7 +682,7 @@ int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
 
     DBG("connect(%d, ..): using path '%s'\n", sockfd, sockaddr.sun_path);
 
-    if (real_connect(entry->unix_sockfd, (struct sockaddr *)&sockaddr,
+    if (real_connect(entry->orig_sockfd, (struct sockaddr *)&sockaddr,
                                          sizeof(sockaddr)) != 0) {
         DIE("connect(%d, ..): failed to connect", sockfd);
     }