]> ruderich.org/simon Gitweb - tlsproxy/tlsproxy.git/blobdiff - src/connection.c
src/connection.c: Minor cleanup.
[tlsproxy/tlsproxy.git] / src / connection.c
index ed8a59411892bb7f89b12642c9482b0f5bb30850..700a3548345398a436aaf9a26b5144714e85f230 100644 (file)
 #include <errno.h>
 
 
-/* Maximum line of a HTTP request line. Longer request lines are aborted with
- * an error. The standard doesn't specify a maximum line length but this
- * should be a good limit to make processing simpler. */
+/* Maximum length of a HTTP request line. Longer request lines are aborted
+ * with an error. The standard doesn't specify a maximum line length but this
+ * should be a good limit to make processing simpler. As HTTPS is used this
+ * doesn't limit long GET requests. */
 #define MAX_REQUEST_LINE 4096
 
 /* Format string used to send HTTP/1.0 error responses to the client.
@@ -71,7 +72,8 @@ static int read_from_write_to(int from, int to);
 static void transfer_data_tls(int client, int server,
                               gnutls_session_t client_session,
                               gnutls_session_t server_session);
-static int read_from_write_to_tls(gnutls_session_t from, gnutls_session_t to);
+static int read_from_write_to_tls(gnutls_session_t from, gnutls_session_t to,
+                                  size_t buffer_size);
 
 static int connect_to_host(const char *hostname, const char *port);
 
@@ -200,7 +202,7 @@ void handle_connection(int client_socket) {
         char path[1024];
         FILE *file = NULL;
 
-        if (-2 == server_certificate_path(&file, host, path, sizeof(path))) {
+        if (-2 == server_certificate_file(&file, host, path, sizeof(path))) {
             /* We've established a connection, tell the client. */
             fprintf(client_fd, "HTTP/1.0 200 Connection established\r\n");
             fprintf(client_fd, "\r\n");
@@ -216,7 +218,7 @@ void handle_connection(int client_socket) {
 
             goto out;
         }
-        /* server_certificate_path() may open the file, close it. */
+        /* server_certificate_file() may have opened the file, close it. */
         if (NULL != file) {
             fclose(file);
         }
@@ -257,7 +259,7 @@ void handle_connection(int client_socket) {
 
     /* Initialize TLS server credentials to talk to the client. */
     result = initialize_tls_session_client(client_socket,
-                                           /* use special host if the server
+                                           /* use special host if the server
                                             * certificate was invalid */
                                            (validation_failed) ? "invalid"
                                                                : host,
@@ -355,30 +357,12 @@ static int initialize_tls_session_client(int peer_socket,
 
     /* The "invalid" hostname is special. If it's used we send an invalid
      * certificate to let the client know something is wrong. */
-    use_invalid_cert = 0 == strcmp(hostname, "invalid");
+    use_invalid_cert = (0 == strcmp(hostname, "invalid"));
 
-    /* Hostname too long. */
-    if (sizeof(path) - strlen(PROXY_SERVER_CERT_FORMAT) <= strlen(hostname)) {
-        LOG(LOG_WARNING,
-            "initialize_tls_session_client(): hostname too long: '%s'",
-            hostname);
-        return -1;
-    }
-    /* Try to prevent path traversals in hostnames. */
-    if (NULL != strstr(hostname, "..")) {
-        LOG(LOG_WARNING,
-            "initialize_tls_session_client(): possible path traversal: '%s'",
-            hostname);
-        return -1;
-    }
-    result = snprintf(path, sizeof(path), PROXY_SERVER_CERT_FORMAT, hostname);
-    if (result < 0) {
-        LOG_PERROR(LOG_ERROR,
-                   "initialize_tls_session_client(): snprintf failed");
-        return -1;
-    } else if ((size_t)result >= sizeof(path)) {
+    if (0 != proxy_certificate_path(hostname, path, sizeof(path))) {
         LOG(LOG_ERROR,
-            "initialize_tls_session_client(): snprintf buffer too short");
+            "initialize_tls_session_client(): \
+failed to get proxy certificate path");
         return -1;
     }
 
@@ -528,7 +512,7 @@ static int read_http_request(FILE *client_fd, char *request, size_t length) {
         return -2;
     }
 
-    while (NULL != fgets(buffer, MAX_REQUEST_LINE, client_fd)) {
+    while (NULL != fgets(buffer, sizeof(buffer), client_fd)) {
         /* End of header. */
         if (0 == strcmp(buffer, "\n") || 0 == strcmp(buffer, "\r\n")) {
             break;
@@ -674,6 +658,8 @@ static int read_from_write_to(int from, int to) {
 static void transfer_data_tls(int client, int server,
                               gnutls_session_t client_session,
                               gnutls_session_t server_session) {
+    size_t buffer_size;
+
     struct pollfd fds[2];
     fds[0].fd      = client;
     fds[0].events  = POLLIN | POLLPRI | POLLHUP | POLLERR;
@@ -682,6 +668,14 @@ static void transfer_data_tls(int client, int server,
     fds[1].events  = POLLIN | POLLPRI | POLLHUP | POLLERR;
     fds[1].revents = 0;
 
+    /* Get maximum possible buffer size. */
+    buffer_size = gnutls_record_get_max_size(client_session);
+    if (buffer_size > gnutls_record_get_max_size(server_session)) {
+        buffer_size = gnutls_record_get_max_size(server_session);
+    }
+    LOG(LOG_DEBUG, "transfer_data_tls(): suggested buffer size: %ld",
+                   (long int)buffer_size);
+
     for (;;) {
         int result = poll(fds, 2, -1 /* no timeout */);
         if (result < 0) {
@@ -691,14 +685,16 @@ static void transfer_data_tls(int client, int server,
 
         /* Data available from client. */
         if (fds[0].revents & POLLIN || fds[0].revents & POLLPRI) {
-            if (0 != read_from_write_to_tls(client_session, server_session)) {
+            if (0 != read_from_write_to_tls(client_session, server_session,
+                                            buffer_size)) {
                 /* EOF (or other error) */
                 break;
             }
         }
         /* Data available from server. */
         if (fds[1].revents & POLLIN || fds[1].revents & POLLPRI) {
-            if (0 != read_from_write_to_tls(server_session, client_session)) {
+            if (0 != read_from_write_to_tls(server_session, client_session,
+                                            buffer_size)) {
                 /* EOF (or other error) */
                 break;
             }
@@ -717,26 +713,19 @@ static void transfer_data_tls(int client, int server,
 
 /* Read available data from session from and write to session to. */
 static int read_from_write_to_tls(gnutls_session_t from,
-                                  gnutls_session_t to) {
-    size_t size;
+                                  gnutls_session_t to,
+                                  size_t buffer_size) {
     ssize_t size_read;
     ssize_t size_written;
     char buffer[16384];
 
-    /* Get maximum possible buffer size. */
-    size = gnutls_record_get_max_size(from);
-    LOG(LOG_DEBUG, "read_from_write_to_tls(): suggested buffer size: %ld",
-                   (long int)size);
-    if (size > gnutls_record_get_max_size(to)) {
-        size = gnutls_record_get_max_size(to);
-    }
-    if (size > sizeof(buffer)) {
-        size = sizeof(buffer);
+    if (buffer_size > sizeof(buffer)) {
+        buffer_size = sizeof(buffer);
     }
     LOG(LOG_DEBUG, "read_from_write_to_tls(): used buffer size: %ld",
-                   (long int)size);
+                   (long int)buffer_size);
 
-    size_read = gnutls_record_recv(from, buffer, size);
+    size_read = gnutls_record_recv(from, buffer, buffer_size);
     if (0 > size_read) {
         LOG(LOG_WARNING, "read_from_write_to_tls(): gnutls_record_recv(): %s",
                          gnutls_strerror((int)size_read));