]> ruderich.org/simon Gitweb - tlsproxy/tlsproxy.git/blobdiff - src/connection.c
tlsproxy.h: Sort includes.
[tlsproxy/tlsproxy.git] / src / connection.c
index 710963c8f0bb96b4ab37d626b2d1bf0f0f04db1f..82b13bf0d04f5a316968bb18577378a2d47f8c47 100644 (file)
 #include "connection.h"
 #include "verify.h"
 
-/* close() */
-#include <unistd.h>
-/* getaddrinfo() */
+#include <errno.h>
 #include <netdb.h>
-/* poll() */
 #include <poll.h>
-/* errno */
-#include <errno.h>
+#include <unistd.h>
 
 
 /* Maximum length of a HTTP request line. Longer request lines are aborted
@@ -89,7 +85,7 @@ void handle_connection(int client_socket) {
     char host[MAX_REQUEST_LINE];
     char port[5 + 1];
 
-    int version_minor; /* HTTP/1.x */
+    int version_minor; /* x in HTTP/1.x */
     int result;
 
     /* client_x509_cred is used when talking to the client (acting as a TSL
@@ -126,11 +122,9 @@ void handle_connection(int client_socket) {
     /* Read request line (CONNECT ..) and headers (they are discarded). */
     result = read_http_request(client_fd, buffer, sizeof(buffer));
     if (result == -1) {
-        /* Read error. */
         LOG(LOG_WARNING, "read_http_request(): client read error");
         goto out;
     } else if (result == -2) {
-        /* EOF */
         LOG(LOG_WARNING, "read_http_request(): client EOF");
         send_bad_request(client_fd);
         goto out;
@@ -170,23 +164,22 @@ void handle_connection(int client_socket) {
     if (global_proxy_host != NULL && global_proxy_port != NULL) {
         fprintf(server_fd, "CONNECT %s:%s HTTP/1.0\r\n", host, port);
         fprintf(server_fd, "\r\n");
+        fflush(server_fd);
 
         /* Read response line from proxy server. */
         result = read_http_request(server_fd, buffer, sizeof(buffer));
         if (result == -1) {
-            /* Read error. */
             LOG(LOG_WARNING, "read_http_request(): proxy read error");
             send_forwarding_failure(client_fd);
             goto out;
         } else if (result == -2) {
-            /* EOF */
             LOG(LOG_WARNING, "read_http_request(): proxy EOF");
             send_forwarding_failure(client_fd);
             goto out;
         }
 
         /* Check response of proxy server. */
-        if (strncmp(buffer, "HTTP/1.0 200", 12) != 0) {
+        if (strncmp(buffer, "HTTP/1.0 200", 12)) {
             LOG(LOG_WARNING, "bad proxy response: %s", buffer);
             send_forwarding_failure(client_fd);
             goto out;
@@ -251,7 +244,7 @@ void handle_connection(int client_socket) {
     /* Make sure the server certificate is valid and known. */
     if (verify_tls_connection(server_session, host) != 0) {
         LOG(LOG_ERROR, "server certificate validation failed!");
-        /* We send the error message over our TLS connection to the client,
+        /* We'll send the error message over our TLS connection to the client,
          * but with an invalid certificate. No data is transfered from/to the
          * target server. */
         validation_failed = 1;
@@ -311,17 +304,17 @@ void handle_connection(int client_socket) {
 out:
     /* Close TLS sessions if necessary. Use GNUTLS_SHUT_RDWR so the data is
      * reliable transmitted. */
-    if (server_session_started != 0) {
+    if (server_session_started) {
         gnutls_bye(server_session, GNUTLS_SHUT_RDWR);
     }
-    if (client_session_started != 0) {
+    if (client_session_started) {
         gnutls_bye(client_session, GNUTLS_SHUT_RDWR);
     }
-    if (server_session_init != 0) {
+    if (server_session_init) {
         gnutls_deinit(server_session);
         gnutls_certificate_free_credentials(server_x509_cred);
     }
-    if (client_session_init != 0) {
+    if (client_session_init) {
         gnutls_deinit(client_session);
         gnutls_certificate_free_cas(client_x509_cred);
         gnutls_certificate_free_keys(client_x509_cred);
@@ -357,7 +350,7 @@ static int initialize_tls_session_client(int peer_socket,
 
     /* The "invalid" hostname is special. If it's used we send an invalid
      * certificate to let the client know something is wrong. */
-    use_invalid_cert = (strcmp(hostname, "invalid") == 0);
+    use_invalid_cert = (!strcmp(hostname, "invalid"));
 
     if (proxy_certificate_path(hostname, path, sizeof(path)) != 0) {
         LOG(LOG_ERROR,
@@ -508,13 +501,13 @@ static int read_http_request(FILE *client_fd, char *request, size_t length) {
             LOG_PERROR(LOG_WARNING, "read_http_request(): fgets()");
             return -1;
         }
-
+        /* EOF */
         return -2;
     }
 
     while (fgets(buffer, sizeof(buffer), client_fd) != NULL) {
         /* End of header. */
-        if (strcmp(buffer, "\n") == 0 || strcmp(buffer, "\r\n") == 0) {
+        if (!strcmp(buffer, "\n") || !strcmp(buffer, "\r\n")) {
             break;
         }
     }
@@ -532,6 +525,7 @@ static void send_bad_request(FILE *client_fd) {
     fprintf(client_fd, HTTP_RESPONSE_FORMAT,
                        RESPONSE_ERROR, RESPONSE_ERROR, RESPONSE_ERROR,
                        RESPONSE_MSG);
+    fflush(client_fd);
 #undef RESPONSE_ERROR
 #undef RESPONSE_MSG
 }
@@ -541,6 +535,7 @@ static void send_forwarding_failure(FILE *client_fd) {
     fprintf(client_fd, HTTP_RESPONSE_FORMAT,
                        RESPONSE_ERROR, RESPONSE_ERROR, RESPONSE_ERROR,
                        RESPONSE_MSG);
+    fflush(client_fd);
 #undef RESPONSE_ERROR
 #undef RESPONSE_MSG
 }
@@ -588,7 +583,7 @@ static void transfer_data(int client, int server) {
     fds[1].revents = 0;
 
     for (;;) {
-        int result = poll(fds, 2, -1 /* no timeout */);
+        int result = poll(fds, 2 /* fd count */, -1 /* no timeout */);
         if (result < 0) {
             LOG_PERROR(LOG_ERROR, "transfer_data(): poll()");
             return;
@@ -633,9 +628,8 @@ static int read_from_write_to(int from, int to) {
     if (size_read < 0) {
         LOG_PERROR(LOG_WARNING, "read_from_write_to(): read()");
         return -1;
-    }
     /* EOF */
-    if (size_read == 0) {
+    } else if (size_read == 0) {
         return -1;
     }
 
@@ -677,7 +671,7 @@ static void transfer_data_tls(int client, int server,
                    (long int)buffer_size);
 
     for (;;) {
-        int result = poll(fds, 2, -1 /* no timeout */);
+        int result = poll(fds, 2 /* fd count */, -1 /* no timeout */);
         if (result < 0) {
             LOG_PERROR(LOG_ERROR, "transfer_data(): poll()");
             return;
@@ -730,9 +724,8 @@ static int read_from_write_to_tls(gnutls_session_t from,
         LOG(LOG_WARNING, "read_from_write_to_tls(): gnutls_record_recv(): %s",
                          gnutls_strerror((int)size_read));
         return -1;
-    }
     /* EOF */
-    if (size_read == 0) {
+    } else if (size_read == 0) {
         return -1;
     }
 
@@ -774,7 +767,12 @@ static int connect_to_host(const char *hostname, const char *port) {
                           | AI_V4MAPPED;   /* support IPv4 through IPv6 */
     gai_return = getaddrinfo(hostname, port, &gai_hints, &gai_result);
     if (gai_return != 0) {
-        LOG_PERROR(LOG_WARNING, "connect_to_host(): getaddrinfo()");
+        if (gai_return == EAI_SYSTEM) {
+            LOG_PERROR(LOG_WARNING, "connect_to_host(): getaddrinfo()");
+        } else {
+            LOG(LOG_WARNING, "connect_to_host(): getaddrinfo(): %s",
+                             gai_strerror(gai_return));
+        }
         return -1;
     }
 
@@ -821,7 +819,7 @@ static int parse_request(const char *request, char *host, char *port,
     char *position;
 
     /* scanf() doesn't check spaces. */
-    if (strncmp(request, "CONNECT ", 8) != 0) {
+    if (strncmp(request, "CONNECT ", 8)) {
         return -1;
     }
     /* Check request and extract data, "host:port" is not yet separated. */