]> ruderich.org/simon Gitweb - tlsproxy/tlsproxy.git/blobdiff - tlsproxy.c
tlsproxy.c: Free resources when receiving SIGINT.
[tlsproxy/tlsproxy.git] / tlsproxy.c
index b417a9871c6f84829b7ada913f1461d58a89fac0..157633bf20e348ea1613879b0c3584479aefa652 100644 (file)
 #include <netdb.h>
 /* strncmp() */
 #include <string.h>
+/* sigaction() */
+#include <signal.h>
 /* poll() */
 #include <poll.h>
 
 
+/* Maximum line of the request line. Longer request lines are aborted with an
+ * error. The standard doesn't specify a maximum line length but this should
+ * be a good limit to make processing simpler. */
+#define MAX_REQUEST_LINE 4096
+
+
+/* Server should shut down. Set by SIGINT handler. */
+static volatile int done;
+
+/* Proxy hostname and port if specified on the command line. */
+static char *use_proxy_host;
+static char *use_proxy_port;
+
+
+static void sigint_handler(int signal);
+
+static void parse_arguments(int argc, char **argv);
+static void print_usage(const char *argv);
+
 static void handle_connection(int socket);
+static int read_http_request(FILE *client_fd, char *request, size_t length);
 static void send_close_bad_request(FILE *client_fd);
 static void send_close_forwarding_failure(FILE *client_fd);
 
@@ -52,19 +74,24 @@ int main(int argc, char **argv) {
     int client_socket, server_socket;
     struct sockaddr_in6 server_in;
 
-    if (2 != argc) {
-        printf("Usage: %s port\n", argv[0]);
-        return EXIT_FAILURE;
-    }
+    struct sigaction action;
 
-    port = atoi(argv[1]);
+    parse_arguments(argc, argv);
+
+    port = atoi(argv[argc - 1]);
     if (0 >= port || 0xffff < port) {
-        printf("Usage: %s port\n", argv[0]);
-        printf("\n");
-        printf("Invalid port: %s!\n", argv[1]);
+        print_usage(argv[0]);
+        fprintf(stderr, "\ninvalid port");
         return EXIT_FAILURE;
     }
 
+    /* Setup our SIGINT signal handler which allows a "normal" termination of
+     * the server. */
+    sigemptyset(&action.sa_mask);
+    action.sa_handler = sigint_handler;
+    action.sa_flags   = 0;
+    sigaction(SIGINT, &action, NULL);
+
     server_socket = socket(PF_INET6, SOCK_STREAM, 0);
     if (-1 == server_socket) {
         perror("socket()");
@@ -96,29 +123,102 @@ int main(int argc, char **argv) {
         return EXIT_FAILURE;
     }
 
-    for (;;) {
+#ifdef DEBUG
+    printf("Listening for connections on port %d.\n", port);
+
+    if (NULL != use_proxy_host && NULL != use_proxy_port) {
+        printf("Using proxy: %s:%s.\n", use_proxy_host, use_proxy_port);
+    }
+#endif
+
+    while (!done) {
         /* Accept new connection. */
         client_socket = accept(server_socket, NULL, NULL);
         if (-1 == client_socket) {
             perror("accept()");
-            return EXIT_FAILURE;
+            break;
         }
 
         handle_connection(client_socket);
     }
 
-    return EXIT_SUCCESS;
+    close(server_socket);
+
+    free(use_proxy_host);
+    free(use_proxy_port);
+
+    return EXIT_FAILURE;
+}
+
+static void sigint_handler(int signal_number) {
+    (void)signal_number;
+
+    done = 1;
+}
+
+static void parse_arguments(int argc, char **argv) {
+    int option;
+
+    while (-1 != (option = getopt(argc, argv, "p:h?"))) {
+        switch (option) {
+            case 'p': {
+                char *position;
+
+                /* -p must have the format host:port. */
+                if (NULL == (position = strchr(optarg, ':'))
+                        || position == optarg
+                        || 0 == strlen(position + 1)
+                        || 0 >= atoi(position + 1)
+                        || 0xffff < atoi(position + 1)) {
+                    fprintf(stderr, "-p host:port\n");
+                    exit(EXIT_FAILURE);
+                }
+
+                use_proxy_host = malloc((size_t)(position - optarg) + 1);
+                if (NULL == use_proxy_host) {
+                    perror("malloc()");
+                    exit(EXIT_FAILURE);
+                }
+                memcpy(use_proxy_host, optarg, (size_t)(position - optarg));
+                use_proxy_host[position - optarg] = '\0';
+
+                use_proxy_port = malloc(strlen(position + 1) + 1);
+                if (NULL == use_proxy_port) {
+                    perror("malloc()");
+                    exit(EXIT_FAILURE);
+                }
+                strcpy(use_proxy_port, position + 1);
+
+                break;
+            }
+            case 'h':
+            default: /* '?' */
+                print_usage(argv[0]);
+                exit(EXIT_FAILURE);
+        }
+    }
+
+    if (optind >= argc) {
+        print_usage(argv[0]);
+        exit(EXIT_FAILURE);
+    }
+}
+static void print_usage(const char *argv) {
+    fprintf(stderr, "Usage: %s [-p host:port] port\n", argv);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "-p proxy hostname and port\n");
 }
 
 static void handle_connection(int client_socket) {
     int server_socket;
     FILE *client_fd, *server_fd;
 
-    char buffer[4096];
-    char host[4096];
+    char buffer[MAX_REQUEST_LINE];
+    char host[MAX_REQUEST_LINE];
     char port[5 + 1];
 
     int version_minor;
+    int result;
 
     client_fd = fdopen(client_socket, "a+");
     if (NULL == client_fd) {
@@ -131,13 +231,13 @@ static void handle_connection(int client_socket) {
     printf("New connection:\n");
 #endif
 
-    if (NULL == fgets(buffer, sizeof(buffer), client_fd)) {
-        if (ferror(client_fd)) {
-            perror("fgets(), request");
-            fclose(client_fd);
-            return;
-        }
-
+    /* Read request line (CONNECT ..) and headers (they are discarded). */
+    result = read_http_request(client_fd, buffer, sizeof(buffer));
+    if (result == -1) {
+        /* Read error. */
+        return;
+    } else if (result == -2) {
+        /* EOF */
         send_close_bad_request(client_fd);
         return;
     }
@@ -154,23 +254,17 @@ static void handle_connection(int client_socket) {
         return;
     }
 
-    while (NULL != fgets(buffer, sizeof(buffer), client_fd)) {
-        /* End of header. */
-        if (0 == strcmp(buffer, "\n") || 0 == strcmp(buffer, "\r\n")) {
-            break;
-        }
-    }
-    if (ferror(client_fd)) {
-        perror("fgets(), header");
-        fclose(client_fd);
-        return;
-    }
-
 #ifdef DEBUG
     printf("  %s:%s (HTTP 1.%d)\n", host, port, version_minor);
 #endif
 
-    server_socket = connect_to_host(host, port);
+    /* Connect to proxy server or directly to server. */
+    if (NULL != use_proxy_host && NULL != use_proxy_port) {
+        server_socket = connect_to_host(use_proxy_host, use_proxy_port);
+    } else {
+        server_socket = connect_to_host(host, port);
+    }
+
     if (-1 == server_socket) {
         send_close_forwarding_failure(client_fd);
         return;
@@ -181,6 +275,39 @@ static void handle_connection(int client_socket) {
         return;
     }
 
+    /* Connect to proxy if requested (command line option). */
+    if (NULL != use_proxy_host && NULL != use_proxy_port) {
+        fprintf(server_fd, "CONNECT %s:%s HTTP/1.0\r\n", host, port);
+        fprintf(server_fd, "\r\n");
+
+        /* Read response line from proxy server. */
+        result = read_http_request(server_fd, buffer, sizeof(buffer));
+        if (result == -1) {
+            /* Read error. */
+            send_close_forwarding_failure(client_fd);
+            return;
+        } else if (result == -2) {
+            /* EOF */
+            fclose(server_fd);
+            send_close_forwarding_failure(client_fd);
+            return;
+        }
+
+        /* Check response of proxy server. */
+        if (0 != strncmp(buffer, "HTTP/1.0 200", 12)) {
+#ifdef DEBUG
+            printf("  bad proxy response\n");
+#endif
+            fclose(server_fd);
+            send_close_forwarding_failure(client_fd);
+            return;
+        }
+    }
+
+#ifdef DEBUG
+    printf("  connection to server established\n");
+#endif
+
     /* We've established a connection, tell the client. */
     fprintf(client_fd, "HTTP/1.0 200 Connection established\r\n");
     fprintf(client_fd, "\r\n");
@@ -193,6 +320,39 @@ static void handle_connection(int client_socket) {
     fclose(server_fd);
 }
 
+/* Read HTTP request line and headers (ignored).
+ *
+ * On success 0 is returned, -1 on client error (we close client descriptor in
+ * this case), -2 on unexpected EOF.
+ */
+static int read_http_request(FILE *client_fd, char *request, size_t length) {
+    char buffer[MAX_REQUEST_LINE];
+
+    if (NULL == fgets(request, (int)length, client_fd)) {
+        if (ferror(client_fd)) {
+            perror("fgets(), request");
+            fclose(client_fd);
+            return -1;
+        }
+
+        return -2;
+    }
+
+    while (NULL != fgets(buffer, MAX_REQUEST_LINE, client_fd)) {
+        /* End of header. */
+        if (0 == strcmp(buffer, "\n") || 0 == strcmp(buffer, "\r\n")) {
+            break;
+        }
+    }
+    if (ferror(client_fd)) {
+        perror("fgets(), header");
+        fclose(client_fd);
+        return -1;
+    }
+
+    return 0;
+}
+
 static void send_close_bad_request(FILE *client_fd) {
     fprintf(client_fd, "HTTP/1.0 400 Bad Request\r\n");
     fprintf(client_fd, "\r\n");