]> ruderich.org/simon Gitweb - tlsproxy/tlsproxy.git/blobdiff - tests/client.c
Update copyright year.
[tlsproxy/tlsproxy.git] / tests / client.c
index d4edf7f8ba6f911fe674c1d1a697e21eb8bb908b..7ed4f05a1519de553a18ef46d381ec5c763d9951 100644 (file)
@@ -1,7 +1,7 @@
 /*
  * Simple GnuTLS client used for testing.
  *
- * Copyright (C) 2011-2013  Simon Ruderich
+ * Copyright (C) 2011-2014  Simon Ruderich
  *
  * This program is free software: you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
@@ -20,7 +20,9 @@
 #include <config.h>
 
 #include <arpa/inet.h>
+#include <assert.h>
 #include <errno.h>
+#include <limits.h>
 #include <netdb.h>
 #include <stdio.h>
 #include <stdlib.h>
 
 #define MAX_REQUEST_LINE 4096
 
+static int fdopen_read_write(int socket, FILE **read_fd, FILE **write_fd);
 static int connect_to_host(const char *hostname, const char *port);
 static int read_http_request(FILE *client_fd, char *request, size_t length);
 
+#if 0
+static void log_function_gnutls(int level, const char *string) {
+    (void)level;
+    fprintf(stderr, "    => %s", string);
+}
+#endif
 
 int main (int argc, char *argv[]) {
     int result, response;
     unsigned int status;
     char buffer[MAX_REQUEST_LINE];
     int server;
-    FILE *fd;
+    FILE *fd_read, *fd_write;
 
     gnutls_session_t session;
     gnutls_certificate_credentials_t xcred;
@@ -53,9 +62,10 @@ int main (int argc, char *argv[]) {
     const gnutls_datum_t *cert_list;
     unsigned int cert_list_size;
 
-    if (argc != 5) {
+    if (argc != 5 && argc != 6) {
         fprintf(stderr,
-                "Usage: %s <ca-file> <hostname> <port> <hostname-verify>\n",
+                "Usage: %s <ca-file> <hostname> <port> <hostname-verify> "
+                          "[<digest-authentication>]\n",
                 argv[0]);
         return EXIT_FAILURE;
     }
@@ -63,6 +73,11 @@ int main (int argc, char *argv[]) {
     gnutls_global_init();
     gnutls_certificate_allocate_credentials(&xcred);
 
+#if 0
+    gnutls_global_set_log_level(10);
+    gnutls_global_set_log_function(log_function_gnutls);
+#endif
+
     gnutls_certificate_set_x509_trust_file(xcred,
                                            argv[1], GNUTLS_X509_FMT_PEM);
 
@@ -74,17 +89,18 @@ int main (int argc, char *argv[]) {
     if (server == -1) {
         return EXIT_FAILURE;
     }
-    fd = fdopen(server, "a+");
-    if (fd == NULL) {
-        perror("fdopen()");
+    if (fdopen_read_write(server, &fd_read, &fd_write) != 0) {
         return EXIT_FAILURE;
     }
 
     /* Talk to tlsproxy. */
-    fprintf(fd, "CONNECT %s:%s HTTP/1.0\r\n", argv[2], argv[3]);
-    fprintf(fd, "\r\n");
-    fflush(fd);
-    if (read_http_request(fd, buffer, sizeof(buffer)) == -1) {
+    fprintf(fd_write, "CONNECT %s:%s HTTP/1.0\r\n", argv[2], argv[3]);
+    if (argc == 6) {
+        fprintf(fd_write, "Proxy-Authorization: Basic %s\r\n", argv[5]);
+    }
+    fprintf(fd_write, "\r\n");
+    fflush(fd_write);
+    if (read_http_request(fd_read, buffer, sizeof(buffer)) != 0) {
         fprintf(stderr, "invalid proxy response\n");
         return EXIT_FAILURE;
     }
@@ -101,7 +117,12 @@ int main (int argc, char *argv[]) {
         return EXIT_FAILURE;
     }
 
+#ifdef HAVE_GNUTLS_TRANSPORT_SET_INT2
+    /* gnutls_transport_set_int() is a macro. */
+    gnutls_transport_set_int(session, server);
+#else
     gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)server);
+#endif
 
     result = gnutls_handshake(session);
     if (result != GNUTLS_E_SUCCESS) {
@@ -150,8 +171,14 @@ int main (int argc, char *argv[]) {
 
     gnutls_x509_crt_deinit(cert);
 
+    /* Send a bogus request to the server. Otherwise recent gnutls-serv won't
+     * terminate the connection when gnutls_bye() is used. */
+    gnutls_record_send(session, "GET / HTTP/1.0\r\n\r\n",
+                                strlen("GET / HTTP/1.0\r\n\r\n"));
+
     gnutls_bye(session, GNUTLS_SHUT_RDWR);
-    fclose(fd);
+    fclose(fd_read);
+    fclose(fd_write);
 
     gnutls_deinit(session);
     gnutls_certificate_free_credentials(xcred);
@@ -163,6 +190,24 @@ int main (int argc, char *argv[]) {
 
 /* Copied from src/connection.c (and removed LOG_* stuff)! Don't modify. */
 
+static int fdopen_read_write(int socket, FILE **read_fd, FILE **write_fd) {
+    *read_fd = fdopen(socket, "r");
+    if (*read_fd == NULL) {
+        perror("fdopen_read_write(): fdopen(\"r\") failed");
+        return -1;
+    }
+
+    *write_fd = fdopen(dup(socket), "w");
+    if (*write_fd == NULL) {
+        perror("fdopen_read_write(): fdopen(\"w\") failed");
+        fclose(*read_fd);
+        *read_fd = NULL; /* "tell" caller read_fd is already closed */
+        return -1;
+    }
+
+    return 0;
+}
+
 static int connect_to_host(const char *hostname, const char *port) {
     struct addrinfo gai_hints;
     struct addrinfo *gai_result;
@@ -181,8 +226,10 @@ static int connect_to_host(const char *hostname, const char *port) {
     gai_hints.ai_socktype = SOCK_STREAM;
     gai_hints.ai_protocol = 0;
     gai_hints.ai_flags    = AI_NUMERICSERV /* given port is numeric */
+#ifdef AI_ADDRCONFIG
                           | AI_ADDRCONFIG  /* supported by this computer */
-                          | AI_V4MAPPED;   /* support IPv4 through IPv6 */
+#endif
+                          ;
     gai_return = getaddrinfo(hostname, port, &gai_hints, &gai_result);
     if (gai_return != 0) {
         if (gai_return == EAI_SYSTEM) {
@@ -200,12 +247,12 @@ static int connect_to_host(const char *hostname, const char *port) {
         server_socket = socket(server->ai_family,
                                server->ai_socktype,
                                server->ai_protocol);
-        if (server_socket == -1) {
+        if (server_socket < 0) {
             perror("connect_to_host(): socket(), trying next");
             continue;
         }
 
-        if (connect(server_socket, server->ai_addr, server->ai_addrlen) != -1) {
+        if (connect(server_socket, server->ai_addr, server->ai_addrlen) == 0) {
             break;
         }
         perror("connect_to_host(): connect(), trying next");
@@ -226,16 +273,17 @@ static int connect_to_host(const char *hostname, const char *port) {
 static int read_http_request(FILE *client_fd, char *request, size_t length) {
     char buffer[MAX_REQUEST_LINE];
 
+    assert(length <= INT_MAX);
     if (fgets(request, (int)length, client_fd) == NULL) {
         if (ferror(client_fd)) {
             perror("read_http_request(): fgets()");
             return -1;
         }
-
+        /* EOF */
         return -2;
     }
 
-    while (fgets(buffer, MAX_REQUEST_LINE, client_fd) != NULL) {
+    while (fgets(buffer, sizeof(buffer), client_fd) != NULL) {
         /* End of header. */
         if (!strcmp(buffer, "\n") || !strcmp(buffer, "\r\n")) {
             break;
@@ -244,6 +292,8 @@ static int read_http_request(FILE *client_fd, char *request, size_t length) {
     if (ferror(client_fd)) {
         perror("read_http_request(): fgets()");
         return -1;
+    } else if (feof(client_fd)) {
+        return -2;
     }
 
     return 0;