hoge.
[elisp/starttls.git] / starttls.c
index 27ef839..ea60698 100644 (file)
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
-
 #include <unistd.h>
-
-#include <openssl/lhash.h>
-#include <openssl/bn.h>
-#include <openssl/err.h>
-#include <openssl/pem.h>
-#include <openssl/x509.h>
-#include <openssl/ssl.h>
+#include <gnutls.h>
+#include <errno.h>
 
 #ifdef HAVE_SOCKS_H
 #include <socks.h>
 #include <signal.h>
 #include <fcntl.h>
 #include <netinet/in.h>
-#define _GNU_SOURCE
-#include <getopt.h>
-
-static SSL_CTX *tls_ctx = NULL;
-static SSL *tls_conn = NULL;
-static int tls_fd;
-
-static char *opt_cert_file = NULL, *opt_key_file = NULL;
-static int  opt_verify = 0;
-
-static int
-tls_ssl_ctx_new (cert_file, key_file)
-  const char *cert_file, *key_file;
-{
-  SSL_load_error_strings ();
-  SSLeay_add_ssl_algorithms ();
-
-  tls_ctx = SSL_CTX_new (TLSv1_client_method());
-  if (!tls_ctx)
-    return -1;
-
-  SSL_CTX_set_options (tls_ctx, SSL_OP_ALL /* Work around all known bugs */); 
-
-  if (cert_file)
-    {
-      if (SSL_CTX_use_certificate_file (tls_ctx, cert_file, 
-                                       SSL_FILETYPE_PEM) <= 0)
-       return -1;
-      if (!key_file)
-       key_file = cert_file;
-      if (SSL_CTX_use_PrivateKey_file (tls_ctx, key_file, 
-                                      SSL_FILETYPE_PEM) <= 0)
-       return -1;
-      if (!SSL_CTX_check_private_key (tls_ctx))
-       return -1;
-    }
-
-  SSL_CTX_set_verify (tls_ctx, SSL_VERIFY_NONE, NULL);
 
-  return 0;
-}
+static int starttls_connect (const char *, const char *);
+static void tls_negotiate (int);
+static ssize_t tls_write (int, const void *, size_t);
+static ssize_t tls_read (int, void *, size_t);
+static int tls_close (int);
+static int raw_is_fatal_error (int);
+static void usage (char *);
 
-static int
-tls_ssl_new(ctx, s)
-  SSL_CTX *ctx;
-  int s;
+static GNUTLS_STATE tls_state;
+typedef struct
 {
-  SSL_SESSION *session;
-  SSL_CIPHER *cipher;
-  X509   *peer;
-
-  tls_conn = (SSL *) SSL_new (ctx);
-  if (!tls_conn)
-    return -1;
-  SSL_clear(tls_conn);
-
-  if (!SSL_set_fd (tls_conn, s))
-    return -1;
-
-  SSL_set_connect_state (tls_conn);
-
-  if (SSL_connect (tls_conn) <= 0)
-    {
-      session = SSL_get_session (tls_conn);
-      if (session)
-       SSL_CTX_remove_session (ctx, session);
-      if (tls_conn!=NULL)
-       SSL_free (tls_conn);
-      return -1;
-    }
-
-  return 0;
-}
+  ssize_t (*write) (int, const void *, size_t);
+  ssize_t (*read) (int, void *, size_t);
+  int (*close) (int);
+  int (*is_fatal_error) (int);
+} starttls_functions_t;
+
+static starttls_functions_t tls_functions =
+  {
+    tls_write,
+    tls_read,
+    tls_close,
+    gnutls_is_fatal_error
+  };
+
+static starttls_functions_t raw_functions =
+  {
+    write,
+    read,
+    close,
+    raw_is_fatal_error
+  };
+
+static starttls_functions_t *starttls_functions;
+static int starttls_fd;
 
 static int
-tls_connect (hostname, service)
+starttls_connect (hostname, service)
      const char *hostname, *service;
 {
   struct protoent *proto;
@@ -168,30 +125,75 @@ tls_connect (hostname, service)
   return server;
 }
 
+static ssize_t
+tls_write (fd, buf, count)
+     int fd;
+     const void *buf;
+     size_t count;
+{
+  return gnutls_write (fd, tls_state, (char *)buf, count);
+}
+
+static ssize_t
+tls_read (fd, buf, count)
+     int fd;
+     void *buf;
+     size_t count;
+{
+  return gnutls_read (fd, tls_state, buf, count);
+}
+
+static int
+tls_close(fd)
+     int fd;
+{
+  gnutls_close(fd, tls_state);
+  gnutls_deinit(&tls_state);
+}
+       
+static int
+raw_is_fatal_error (error)
+     int error;
+{
+  return (error < 0 ? 1 : 0);
+}
+
 static void
 tls_negotiate (sig)
      int sig;
 {
-  if (tls_ssl_ctx_new (opt_cert_file, opt_key_file) == -1)
-    return;
-
-  (void) tls_ssl_new (tls_ctx, tls_fd); /* Negotiation has done. */
+  int error;
+
+  gnutls_init (&tls_state, GNUTLS_CLIENT);
+  gnutls_set_current_version (tls_state, GNUTLS_TLS1);
+  gnutls_set_cipher_priority (tls_state, 4, GNUTLS_3DES, GNUTLS_TWOFISH,
+                             GNUTLS_RIJNDAEL, GNUTLS_ARCFOUR);
+  gnutls_set_compression_priority (tls_state, 2, GNUTLS_ZLIB, GNUTLS_NULL_COMPRESSION);
+  gnutls_set_kx_priority (tls_state, 3, GNUTLS_KX_ANON_DH, GNUTLS_KX_DHE_DSS,
+                         GNUTLS_KX_DHE_RSA);
+  gnutls_set_mac_priority (tls_state, 2, GNUTLS_MAC_SHA, GNUTLS_MAC_MD5);
+  error = gnutls_handshake(starttls_fd, tls_state);
+
+  if (error < 0)
+    {
+      starttls_functions->close (starttls_fd);
+      gnutls_perror (error);
+      exit (-1);
+    }
+  else
+    starttls_functions = &tls_functions;
 }
 
 static void
 usage (progname)
-     const char *progname;
+     char *progname;
 {
   printf ("%s (%s) %s\n"
-         "Copyright (C) 1999 Free Software Foundation, Inc.\n"
+         "Copyright (C) 2001 Free Software Foundation, Inc.\n"
          "This program comes with ABSOLUTELY NO WARRANTY.\n"
          "This is free software, and you are welcome to redistribute it\n"
          "under certain conditions. See the file COPYING for details.\n\n"
-         "Usage: %s [options] host port\n\n"
-         "Options:\n\n"
-         " --cert-file [file]      specify certificate file\n"
-         " --key-file [file]       specify private key file\n"
-         " --verify [level]        set verification level\n",
+         "Usage: %s host port\n\n",
          progname, PACKAGE, VERSION, progname);
 }
      
@@ -206,49 +208,17 @@ main (argc, argv)
   char buffer[BUFSIZ], *retry;
   struct sigaction act;
 
-  int this_option_optind = optind ? optind : 1;
-  int option_index = 0, c;
-  static struct option long_options[] =
-    {
-      {"cert-file", 1, 0, 'c'},
-      {"key-file", 1, 0, 'k'},
-      {"verify", 1, 0, 'v'},
-      {0, 0, 0, 0}
-    };
-
-  while (1)
-    {
-      c = getopt_long (argc, argv, "c:k:v:f", long_options, &option_index);
-      if (c == -1)
-       break;
-    
-      switch (c)
-       {
-       case 'c':
-         opt_cert_file = optarg;
-         break;
-       case 'k':
-         opt_key_file = optarg;
-         break;
-       case 'v':
-         opt_verify = atoi (optarg);
-         break;
-       default:
-         usage (basename (argv[0]));
-         return 1;
-       }
-    }
-
-  if (optind+2 != argc)
+  if (argc != 3)
     {
-      usage (basename (argv[0]));
+      usage ((char *)basename (argv[0]));
       return 1;
     }
 
-  tls_fd = tls_connect (argv[optind], argv[optind+1]);
-  if (tls_fd < 0)
+  starttls_functions = &raw_functions;
+  starttls_fd = starttls_connect (argv[1], argv[2]);
+  if (starttls_fd < 0)
     {
-      perror ("tls_connect");
+      perror ("starttls_connect");
       return 1;
     }
 
@@ -260,9 +230,9 @@ main (argc, argv)
 
   while (1)
     {
-      FD_SET (tls_fd, &readfds);
+      FD_SET (starttls_fd, &readfds);
       FD_SET (in, &readfds);
-      if (select (tls_fd+1, &readfds, NULL, NULL, NULL) == -1
+      if (select (starttls_fd+1, &readfds, NULL, NULL, NULL) == -1
          && errno != EINTR )
        {
          perror ("select");
@@ -276,26 +246,21 @@ main (argc, argv)
            goto finish;
          for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
            {
-             FD_SET (tls_fd, &writefds);
-             if (select (tls_fd+1, NULL, &writefds, NULL, NULL) == -1)
+             FD_SET (starttls_fd, &writefds);
+             if (select (starttls_fd+1, NULL, &writefds, NULL, NULL) == -1)
                {
                  perror ("select");
                  return 1;
                }
-             if (tls_conn) 
-               wrote = SSL_write (tls_conn, retry, nbuffer);
-             else
-               wrote = write (tls_fd, retry, nbuffer);
-             if (wrote < 0) goto finish;
+             wrote = starttls_functions->write (starttls_fd, retry, nbuffer);
+             if (starttls_functions->is_fatal_error (wrote))
+               goto finish;
            }
        }
-      if (FD_ISSET (tls_fd, &readfds))
+      if (FD_ISSET (starttls_fd, &readfds))
        {
-         if (tls_conn)
-           nbuffer = SSL_read (tls_conn, buffer, sizeof buffer -1);
-         else
-           nbuffer = read (tls_fd, buffer, sizeof buffer -1);
-         if (nbuffer == 0)
+         nbuffer = starttls_functions->read (starttls_fd, buffer, sizeof buffer -1);
+         if (starttls_functions->is_fatal_error (nbuffer))
            goto finish;
          for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
            {
@@ -312,8 +277,8 @@ main (argc, argv)
     }
 
  finish:
-  close (in);
-  close (out);
+  shutdown(starttls_fd, SHUT_RDWR);
+  starttls_functions->close (starttls_fd);
   
   return 0;
 }