* starttls.c (main): Use poll() instead of select() if available.
[elisp/starttls.git] / starttls.c
index 778ab48..15fc5a0 100644 (file)
@@ -1,9 +1,8 @@
-/* TLSv1 filter for STARTTLS extension.
+/* simple wrapper program for STARTTLS
 
-   Copyright (C) 1999, 2000 Daiki Ueno <ueno@unixuser.org>
+   Copyright (C) 1999, 2000 Free Software Foundation, Inc.
 
    Author: Daiki Ueno <ueno@unixuser.org>
-       Kenichi OKADA <okada@opaopa.org>
    Created: 1999-11-19
    Keywords: TLS, OpenSSL
 
 
 */
 
-/*
-  How to compile: (OpenSSL is required)
-  
-  gcc -I/usr/local/ssl/include -o starttls starttls.c \
-    -L/usr/local/ssl/lib -lssl -lcrypto
-
-*/
-
 #include <sys/types.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -41,8 +32,6 @@
 
 #include <unistd.h>
 
-/* OpenSSL library. */
-
 #include <openssl/lhash.h>
 #include <openssl/bn.h>
 #include <openssl/err.h>
@@ -54,9 +43,9 @@
 #include <socks.h>
 #endif 
 
-#ifndef HAVE_GETADDRINFO
-#include "getaddrinfo.h"
-#endif /* !HAVE_GETADDRINFO */
+#ifdef NEED_ADDRINFO_H
+#include "addrinfo.h"
+#endif
 
 #include <sys/time.h>
 #include <sys/socket.h>
 #include <signal.h>
 #include <fcntl.h>
 #include <netinet/in.h>
+#ifdef HAVE_POLL_H
+#include <sys/poll.h>
+#endif
 #define _GNU_SOURCE
 #include <getopt.h>
 
-#ifdef HAVE_BASENAME
-# ifdef HAVE_LIBGEN_H
-#  include <libgen.h>
-#  ifdef basename
-#   undef basename
-#  endif
-# endif
-# include <string.h>
-#else
-inline char *
-basename(path) 
-     const char *path;
-{ 
-  char *p = rindex((path), '/');
-  return p ? p + 1 : (path);
-}
-#endif
-
-#define true 1
-
 static SSL_CTX *tls_ctx = NULL;
 static SSL *tls_conn = NULL;
 static int tls_fd;
 
 static char *opt_cert_file = NULL, *opt_key_file = NULL;
 static int  opt_verify = 0;
-static int opt_force;
 
 static int
 tls_ssl_ctx_new (cert_file, key_file)
@@ -223,8 +194,7 @@ usage (progname)
          "Options:\n\n"
          " --cert-file [file]      specify certificate file\n"
          " --key-file [file]       specify private key file\n"
-         " --verify [level]        set verification level\n"
-         " --force                 force negotiate\n",
+         " --verify [level]        set verification level\n",
          progname, PACKAGE, VERSION, progname);
 }
      
@@ -235,7 +205,11 @@ main (argc, argv)
 {
   int in = fileno (stdin), out = fileno (stdout), 
     nbuffer, wrote;
+#ifdef HAVE_POLL
+  struct pollfd readfds[2], writefds[1];
+#else
   fd_set readfds, writefds;
+#endif
   char buffer[BUFSIZ], *retry;
   struct sigaction act;
 
@@ -246,7 +220,6 @@ main (argc, argv)
       {"cert-file", 1, 0, 'c'},
       {"key-file", 1, 0, 'k'},
       {"verify", 1, 0, 'v'},
-      {"force", 0, 0, 'f'},
       {0, 0, 0, 0}
     };
 
@@ -267,9 +240,6 @@ main (argc, argv)
        case 'v':
          opt_verify = atoi (optarg);
          break;
-       case 'f':
-         opt_force = true;
-         break;
        default:
          usage (basename (argv[0]));
          return 1;
@@ -295,31 +265,51 @@ main (argc, argv)
   act.sa_flags = SA_RESTART|SA_RESETHAND;
   sigaction (SIGALRM, &act, NULL);
 
-  if (opt_force == true)
-    tls_negotiate();
+#ifdef HAVE_POLL
+  readfds[0].fd = in;
+  readfds[1].fd = tls_fd;
+  readfds[0].events = POLLIN;
+  readfds[1].events = POLLIN;
+  writefds[0].events = POLLOUT;
+#endif
 
   while (1)
     {
+#ifdef HAVE_POLL
+      if (poll (readfds, 2, -1) == -1 && errno != EINTR)
+#else
+      FD_ZERO (&readfds);
       FD_SET (tls_fd, &readfds);
       FD_SET (in, &readfds);
       if (select (tls_fd+1, &readfds, NULL, NULL, NULL) == -1
          && errno != EINTR )
+#endif
        {
-         perror ("select");
+         perror ("poll");
          return 1;
        }
+#ifdef HAVE_POLL
+      if (readfds[0].revents & POLLIN)
+#else
       if (FD_ISSET (in, &readfds))
+#endif
        {
-         nbuffer = read (in, buffer, BUFSIZ/2);
+         nbuffer = read (in, buffer, sizeof buffer -1);
 
          if (nbuffer == 0)
            goto finish;
          for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
            {
+#ifdef HAVE_POLL
+             writefds[0].fd = tls_fd;
+             if (poll (writefds, 1, -1) == -1)
+#else
+             FD_ZERO (&writefds);
              FD_SET (tls_fd, &writefds);
              if (select (tls_fd+1, NULL, &writefds, NULL, NULL) == -1)
+#endif
                {
-                 perror ("select");
+                 perror ("poll");
                  return 1;
                }
              if (tls_conn) 
@@ -329,20 +319,30 @@ main (argc, argv)
              if (wrote < 0) goto finish;
            }
        }
+#ifdef HAVE_POLL
+      if (readfds[1].revents & POLLIN)
+#else
       if (FD_ISSET (tls_fd, &readfds))
+#endif
        {
          if (tls_conn)
-           nbuffer = SSL_read (tls_conn, buffer, BUFSIZ/8);
+           nbuffer = SSL_read (tls_conn, buffer, sizeof buffer -1);
          else
-           nbuffer = read (tls_fd, buffer, BUFSIZ/2);
+           nbuffer = read (tls_fd, buffer, sizeof buffer -1);
          if (nbuffer == 0)
            goto finish;
          for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
            {
+#ifdef HAVE_POLL
+             writefds[0].fd = out;
+             if (poll (writefds, 1, -1) == -1)
+#else
+             FD_ZERO (&writefds);
              FD_SET (out, &writefds);
              if (select (out+1, NULL, &writefds, NULL, NULL) == -1)
+#endif
                {
-                 perror ("select");
+                 perror ("poll");
                  return 1;
                }
              wrote = write (out, retry, nbuffer);