X-Git-Url: https://ruderich.org/simon/gitweb/?p=tlsproxy%2Ftlsproxy.git;a=blobdiff_plain;f=tests%2Fclient.c;h=7ed4f05a1519de553a18ef46d381ec5c763d9951;hp=3f2dc4b66f7ccaddaaff185b23d59410e96ee5a7;hb=7eba49d24d56288d83746f3f0ce383d7c0c36552;hpb=c4343157f93bfeb4e6de858fdd61b8fb4eddafc2 diff --git a/tests/client.c b/tests/client.c index 3f2dc4b..7ed4f05 100644 --- a/tests/client.c +++ b/tests/client.c @@ -1,7 +1,7 @@ /* * Simple GnuTLS client used for testing. * - * Copyright (C) 2011 Simon Ruderich + * Copyright (C) 2011-2014 Simon Ruderich * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -19,38 +19,41 @@ #include -#include +#include +#include +#include +#include +#include #include +#include #include -/* socket(), connect() */ -#include #include -/* close() */ +#include #include -/* getaddrinfo() */ -#include -/* htons() */ -#include -/* errno */ -#include -/* GnuTLS */ #include #include #define MAX_REQUEST_LINE 4096 +static int fdopen_read_write(int socket, FILE **read_fd, FILE **write_fd); static int connect_to_host(const char *hostname, const char *port); static int read_http_request(FILE *client_fd, char *request, size_t length); +#if 0 +static void log_function_gnutls(int level, const char *string) { + (void)level; + fprintf(stderr, " => %s", string); +} +#endif int main (int argc, char *argv[]) { int result, response; unsigned int status; char buffer[MAX_REQUEST_LINE]; int server; - FILE *fd; + FILE *fd_read, *fd_write; gnutls_session_t session; gnutls_certificate_credentials_t xcred; @@ -59,9 +62,10 @@ int main (int argc, char *argv[]) { const gnutls_datum_t *cert_list; unsigned int cert_list_size; - if (5 != argc) { + if (argc != 5 && argc != 6) { fprintf(stderr, - "Usage: %s \n", + "Usage: %s " + "[]\n", argv[0]); return EXIT_FAILURE; } @@ -69,6 +73,11 @@ int main (int argc, char *argv[]) { gnutls_global_init(); gnutls_certificate_allocate_credentials(&xcred); +#if 0 + gnutls_global_set_log_level(10); + gnutls_global_set_log_function(log_function_gnutls); +#endif + gnutls_certificate_set_x509_trust_file(xcred, argv[1], GNUTLS_X509_FMT_PEM); @@ -77,40 +86,46 @@ int main (int argc, char *argv[]) { gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, xcred); server = connect_to_host("localhost", "4711"); - if (-1 == server) { + if (server == -1) { return EXIT_FAILURE; } - fd = fdopen(server, "a+"); - if (NULL == fd) { - perror("fdopen()"); + if (fdopen_read_write(server, &fd_read, &fd_write) != 0) { return EXIT_FAILURE; } /* Talk to tlsproxy. */ - fprintf(fd, "CONNECT %s:%s HTTP/1.0\r\n", argv[2], argv[3]); - fprintf(fd, "\r\n"); - fflush(fd); - if (-1 == read_http_request(fd, buffer, sizeof(buffer))) { + fprintf(fd_write, "CONNECT %s:%s HTTP/1.0\r\n", argv[2], argv[3]); + if (argc == 6) { + fprintf(fd_write, "Proxy-Authorization: Basic %s\r\n", argv[5]); + } + fprintf(fd_write, "\r\n"); + fflush(fd_write); + if (read_http_request(fd_read, buffer, sizeof(buffer)) != 0) { fprintf(stderr, "invalid proxy response\n"); return EXIT_FAILURE; } printf("response: %s\n", buffer); - if (1 != sscanf(buffer, "HTTP/1.0 %d", &response)) { + if (sscanf(buffer, "HTTP/1.0 %d", &response) != 1) { fprintf(stderr, "invalid proxy response: %s\n", buffer); return EXIT_FAILURE; } - if (200 != response) { + if (response != 200) { fprintf(stderr, "proxy failure\n"); return EXIT_FAILURE; } +#ifdef HAVE_GNUTLS_TRANSPORT_SET_INT2 + /* gnutls_transport_set_int() is a macro. */ + gnutls_transport_set_int(session, server); +#else gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)server); +#endif result = gnutls_handshake(session); - if (GNUTLS_E_SUCCESS != result) { + if (result != GNUTLS_E_SUCCESS) { fprintf(stderr, "gnutls_handshake() failed\n"); gnutls_perror(result); return EXIT_FAILURE; @@ -118,7 +133,7 @@ int main (int argc, char *argv[]) { /* Verify the proxy certificate. */ result = gnutls_certificate_verify_peers2(session, &status); - if (0 > result) { + if (result < 0) { fprintf(stderr, "gnutls_certificate_verify_peers2() failed\n"); gnutls_perror(result); return EXIT_FAILURE; @@ -129,20 +144,20 @@ int main (int argc, char *argv[]) { } /* Get proxy certificate. */ - if (0 > (result = gnutls_x509_crt_init(&cert))) { + if ((result = gnutls_x509_crt_init(&cert)) < 0) { fprintf(stderr, "gnutls_x509_crt_init() failed"); gnutls_perror(result); return EXIT_FAILURE; } cert_list = gnutls_certificate_get_peers(session, &cert_list_size); - if (NULL == cert_list) { + if (cert_list == NULL) { fprintf(stderr, "gnutls_certificate_get_peers() failed"); return EXIT_FAILURE; } - if (0 > (result = gnutls_x509_crt_import(cert, &cert_list[0], - GNUTLS_X509_FMT_DER))) { + if ((result = gnutls_x509_crt_import(cert, &cert_list[0], + GNUTLS_X509_FMT_DER)) < 0) { fprintf(stderr, "gnutls_x509_crt_import() failed"); gnutls_perror(result); return EXIT_FAILURE; @@ -151,12 +166,19 @@ int main (int argc, char *argv[]) { /* Check hostname. */ if (!gnutls_x509_crt_check_hostname(cert, argv[4])) { fprintf(stderr, "hostname didn't match '%s'\n", argv[4]); + return EXIT_FAILURE; } gnutls_x509_crt_deinit(cert); + /* Send a bogus request to the server. Otherwise recent gnutls-serv won't + * terminate the connection when gnutls_bye() is used. */ + gnutls_record_send(session, "GET / HTTP/1.0\r\n\r\n", + strlen("GET / HTTP/1.0\r\n\r\n")); + gnutls_bye(session, GNUTLS_SHUT_RDWR); - fclose(fd); + fclose(fd_read); + fclose(fd_write); gnutls_deinit(session); gnutls_certificate_free_credentials(xcred); @@ -168,6 +190,24 @@ int main (int argc, char *argv[]) { /* Copied from src/connection.c (and removed LOG_* stuff)! Don't modify. */ +static int fdopen_read_write(int socket, FILE **read_fd, FILE **write_fd) { + *read_fd = fdopen(socket, "r"); + if (*read_fd == NULL) { + perror("fdopen_read_write(): fdopen(\"r\") failed"); + return -1; + } + + *write_fd = fdopen(dup(socket), "w"); + if (*write_fd == NULL) { + perror("fdopen_read_write(): fdopen(\"w\") failed"); + fclose(*read_fd); + *read_fd = NULL; /* "tell" caller read_fd is already closed */ + return -1; + } + + return 0; +} + static int connect_to_host(const char *hostname, const char *port) { struct addrinfo gai_hints; struct addrinfo *gai_result; @@ -176,7 +216,7 @@ static int connect_to_host(const char *hostname, const char *port) { int server_socket; struct addrinfo *server; - if (NULL == hostname || NULL == port) { + if (hostname == NULL || port == NULL) { return -1; } @@ -186,27 +226,33 @@ static int connect_to_host(const char *hostname, const char *port) { gai_hints.ai_socktype = SOCK_STREAM; gai_hints.ai_protocol = 0; gai_hints.ai_flags = AI_NUMERICSERV /* given port is numeric */ +#ifdef AI_ADDRCONFIG | AI_ADDRCONFIG /* supported by this computer */ - | AI_V4MAPPED; /* support IPv4 through IPv6 */ +#endif + ; gai_return = getaddrinfo(hostname, port, &gai_hints, &gai_result); - if (0 != gai_return) { - perror("connect_to_host(): getaddrinfo()"); + if (gai_return != 0) { + if (gai_return == EAI_SYSTEM) { + perror("connect_to_host(): getaddrinfo()"); + } else { + fprintf(stderr, "connect_to_host(): getaddrinfo(): %s", + gai_strerror(gai_return)); + } return -1; } /* Now try to connect to each server returned by getaddrinfo(), use the * first successful connect. */ - for (server = gai_result; NULL != server; server = server->ai_next) { + for (server = gai_result; server != NULL; server = server->ai_next) { server_socket = socket(server->ai_family, server->ai_socktype, server->ai_protocol); - if (-1 == server_socket) { + if (server_socket < 0) { perror("connect_to_host(): socket(), trying next"); continue; } - if (-1 != connect(server_socket, server->ai_addr, - server->ai_addrlen)) { + if (connect(server_socket, server->ai_addr, server->ai_addrlen) == 0) { break; } perror("connect_to_host(): connect(), trying next"); @@ -216,7 +262,7 @@ static int connect_to_host(const char *hostname, const char *port) { /* Make sure we free the result from getaddrinfo(). */ freeaddrinfo(gai_result); - if (NULL == server) { + if (server == NULL) { perror("connect_to_host(): no server found, abort"); return -1; } @@ -227,24 +273,27 @@ static int connect_to_host(const char *hostname, const char *port) { static int read_http_request(FILE *client_fd, char *request, size_t length) { char buffer[MAX_REQUEST_LINE]; - if (NULL == fgets(request, (int)length, client_fd)) { + assert(length <= INT_MAX); + if (fgets(request, (int)length, client_fd) == NULL) { if (ferror(client_fd)) { perror("read_http_request(): fgets()"); return -1; } - + /* EOF */ return -2; } - while (NULL != fgets(buffer, MAX_REQUEST_LINE, client_fd)) { + while (fgets(buffer, sizeof(buffer), client_fd) != NULL) { /* End of header. */ - if (0 == strcmp(buffer, "\n") || 0 == strcmp(buffer, "\r\n")) { + if (!strcmp(buffer, "\n") || !strcmp(buffer, "\r\n")) { break; } } if (ferror(client_fd)) { perror("read_http_request(): fgets()"); return -1; + } else if (feof(client_fd)) { + return -2; } return 0;