update.
[elisp/starttls.git] / starttls.c
1 /* simple wrapper program for STARTTLS
2
3    Copyright (C) 1999, 2000 Free Software Foundation, Inc.
4
5    Author: Daiki Ueno <ueno@unixuser.org>
6    Created: 1999-11-19
7    Keywords: TLS, OpenSSL
8
9    This file is not part of any package.
10
11    This program is free software; you can redistribute it and/or modify 
12    it under the terms of the GNU General Public License as published by 
13    the Free Software Foundation; either version 2, or (at your option)  
14    any later version.                                                   
15
16    This program is distributed in the hope that it will be useful,      
17    but WITHOUT ANY WARRANTY; without even the implied warranty of       
18    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        
19    GNU General Public License for more details.                         
20
21    You should have received a copy of the GNU General Public License    
22    along with GNU Emacs; see the file COPYING.  If not, write to the    
23    Free Software Foundation, Inc., 59 Temple Place - Suite 330,         
24    Boston, MA 02111-1307, USA.                                          
25
26 */
27
28 #include <sys/types.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32
33 #include <unistd.h>
34
35 #include <openssl/lhash.h>
36 #include <openssl/bn.h>
37 #include <openssl/err.h>
38 #include <openssl/pem.h>
39 #include <openssl/x509.h>
40 #include <openssl/ssl.h>
41
42 #include <sys/time.h>
43 #include <sys/socket.h>
44 #include <sys/file.h>
45 #include <sys/ioctl.h>
46 #include <sys/stat.h>
47 #include <netdb.h>
48 #include <stdio.h>
49 #include <signal.h>
50 #include <fcntl.h>
51 #include <netinet/in.h>
52 #ifdef HAVE_POLL_H
53 #include <sys/poll.h>
54 #endif
55 #define _GNU_SOURCE
56 #include "getopt.h"
57
58 static SSL_CTX *tls_ctx = NULL;
59 static SSL *tls_conn = NULL;
60 static int tls_fd;
61
62 static char *opt_cert_file = NULL, *opt_key_file = NULL;
63 static int  opt_verify = 0;
64
65 static int
66 tls_ssl_ctx_new (cert_file, key_file)
67   const char *cert_file, *key_file;
68 {
69   SSL_load_error_strings ();
70   SSLeay_add_ssl_algorithms ();
71
72   tls_ctx = SSL_CTX_new (TLSv1_client_method());
73   if (!tls_ctx)
74     return -1;
75
76   SSL_CTX_set_options (tls_ctx, SSL_OP_ALL /* Work around all known bugs */); 
77
78   if (cert_file)
79     {
80       if (SSL_CTX_use_certificate_file (tls_ctx, cert_file, 
81                                         SSL_FILETYPE_PEM) <= 0)
82         return -1;
83       if (!key_file)
84         key_file = cert_file;
85       if (SSL_CTX_use_PrivateKey_file (tls_ctx, key_file, 
86                                        SSL_FILETYPE_PEM) <= 0)
87         return -1;
88       if (!SSL_CTX_check_private_key (tls_ctx))
89         return -1;
90     }
91
92   SSL_CTX_set_verify (tls_ctx, SSL_VERIFY_NONE, NULL);
93
94   return 0;
95 }
96
97 static int
98 tls_ssl_new(ctx, s)
99   SSL_CTX *ctx;
100   int s;
101 {
102   SSL_SESSION *session;
103   SSL_CIPHER *cipher;
104   X509   *peer;
105
106   tls_conn = (SSL *) SSL_new (ctx);
107   if (!tls_conn)
108     return -1;
109   SSL_clear(tls_conn);
110
111   if (!SSL_set_fd (tls_conn, s))
112     return -1;
113
114   SSL_set_connect_state (tls_conn);
115
116   if (SSL_connect (tls_conn) <= 0)
117     {
118       session = SSL_get_session (tls_conn);
119       if (session)
120         SSL_CTX_remove_session (ctx, session);
121       if (tls_conn!=NULL)
122         SSL_free (tls_conn);
123       return -1;
124     }
125
126   return 0;
127 }
128
129 static int
130 tls_connect (hostname, service)
131      const char *hostname, *service;
132 {
133   int server, false = 0;
134 #ifdef HAVE_ADDRINFO
135   struct addrinfo *in, *in0, hints;
136 #else
137   struct hostent *host;
138   struct servent *serv;
139   struct sockaddr_in sin;
140 #endif
141
142 #ifdef HAVE_ADDRINFO
143   memset (&hints, 0, sizeof (hints));
144   hints.ai_family = AF_UNSPEC;
145   hints.ai_socktype = SOCK_STREAM;
146   if (getaddrinfo (hostname, service, &hints, &in))
147     return -1;
148
149   for (in = in0; in; in = in->ai_next)
150     {
151       server = socket (in->ai_family, in->ai_socktype, in->ai_protocol);
152       if (server < 0)
153         continue;
154       if (connect (server, in->ai_addr, in->ai_addrlen) < 0)
155         {
156           server = -1;
157           continue;
158         }
159       break;
160   }
161
162   freeaddrinfo (in0);
163   if (server < 0)
164     return -1;
165 #else
166   memset (&sin, 0, sizeof (sin));
167   host = gethostbyname (hostname);
168   if (!host)
169     return -1;
170   memcpy (&sin.sin_addr, host->h_addr, host->h_length);
171   serv = getservbyname (service, "tcp");
172   if (serv)
173     sin.sin_port = serv->s_port;
174   else if (isdigit (service[0]))
175     sin.sin_port = htons (atoi (service));
176   sin.sin_family = AF_INET;
177   server = socket (sin.sin_family, SOCK_STREAM, 0);
178   if (server == -1)
179     return -1;
180
181   if (connect (server, (struct sockaddr *)&sin, sizeof (sin)) < 0)
182     {
183       close (server);
184       return -1;
185     }
186 #endif
187
188   setsockopt (server, SOL_SOCKET, SO_KEEPALIVE, (const char *) &false,
189               sizeof (false));
190
191   return server;
192 }
193
194 static void
195 tls_negotiate (sig)
196      int sig;
197 {
198   if (tls_ssl_ctx_new (opt_cert_file, opt_key_file) == -1)
199     return;
200
201   (void) tls_ssl_new (tls_ctx, tls_fd); /* Negotiation has done. */
202 }
203
204 static void
205 usage (progname)
206      const char *progname;
207 {
208   printf ("%s (%s) %s\n"
209           "Copyright (C) 1999 Free Software Foundation, Inc.\n"
210           "This program comes with ABSOLUTELY NO WARRANTY.\n"
211           "This is free software, and you are welcome to redistribute it\n"
212           "under certain conditions. See the file COPYING for details.\n\n"
213           "Usage: %s [options] host port\n\n"
214           "Options:\n\n"
215           " --cert-file [file]      specify certificate file\n"
216           " --key-file [file]       specify private key file\n"
217           " --verify [level]        set verification level\n",
218           progname, PACKAGE, VERSION, progname);
219 }
220      
221 int
222 main (argc, argv) 
223   int argc;
224   char **argv;
225 {
226   int in = fileno (stdin), out = fileno (stdout), 
227     nbuffer, wrote;
228 #ifdef HAVE_POLL
229   struct pollfd readfds[2], writefds[1];
230 #else
231   fd_set readfds, writefds;
232 #endif
233   char buffer[BUFSIZ], *retry;
234   struct sigaction act;
235
236   int this_option_optind = optind ? optind : 1;
237   int option_index = 0, c;
238   static struct option long_options[] =
239     {
240       {"cert-file", 1, 0, 'c'},
241       {"key-file", 1, 0, 'k'},
242       {"verify", 1, 0, 'v'},
243       {0, 0, 0, 0}
244     };
245
246   while (1)
247     {
248       c = getopt_long (argc, argv, "c:k:v:", long_options, &option_index);
249       if (c == -1)
250         break;
251     
252       switch (c)
253         {
254         case 'c':
255           opt_cert_file = optarg;
256           break;
257         case 'k':
258           opt_key_file = optarg;
259           break;
260         case 'v':
261           opt_verify = atoi (optarg);
262           break;
263         default:
264           usage (basename (argv[0]));
265           return 1;
266         }
267     }
268
269   if (optind+2 != argc)
270     {
271       usage (basename (argv[0]));
272       return 1;
273     }
274
275   tls_fd = tls_connect (argv[optind], argv[optind+1]);
276   if (tls_fd < 0)
277     {
278       perror ("tls_connect");
279       return 1;
280     }
281
282   memset (&act, 0, sizeof (act));
283   act.sa_handler = tls_negotiate;
284   sigemptyset (&act.sa_mask);
285   act.sa_flags = SA_RESTART|SA_RESETHAND;
286   sigaction (SIGALRM, &act, NULL);
287
288 #ifdef HAVE_POLL
289   readfds[0].fd = in;
290   readfds[1].fd = tls_fd;
291   readfds[0].events = POLLIN;
292   readfds[1].events = POLLIN;
293   writefds[0].events = POLLOUT;
294 #endif
295
296   while (1)
297     {
298 #ifdef HAVE_POLL
299       if (poll (readfds, 2, -1) == -1 && errno != EINTR)
300 #else
301       FD_ZERO (&readfds);
302       FD_SET (tls_fd, &readfds);
303       FD_SET (in, &readfds);
304       if (select (tls_fd+1, &readfds, NULL, NULL, NULL) == -1
305           && errno != EINTR )
306 #endif
307         {
308           perror ("poll");
309           return 1;
310         }
311 #ifdef HAVE_POLL
312       if (readfds[0].revents & POLLIN)
313 #else
314       if (FD_ISSET (in, &readfds))
315 #endif
316         {
317           nbuffer = read (in, buffer, sizeof buffer -1);
318
319           if (nbuffer == 0)
320             goto finish;
321           for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
322             {
323 #ifdef HAVE_POLL
324               writefds[0].fd = tls_fd;
325               if (poll (writefds, 1, -1) == -1)
326 #else
327               FD_ZERO (&writefds);
328               FD_SET (tls_fd, &writefds);
329               if (select (tls_fd+1, NULL, &writefds, NULL, NULL) == -1)
330 #endif
331                 {
332                   perror ("poll");
333                   return 1;
334                 }
335               if (tls_conn) 
336                 wrote = SSL_write (tls_conn, retry, nbuffer);
337               else
338                 wrote = write (tls_fd, retry, nbuffer);
339               if (wrote < 0) goto finish;
340             }
341         }
342 #ifdef HAVE_POLL
343       if (readfds[1].revents & POLLIN)
344 #else
345       if (FD_ISSET (tls_fd, &readfds))
346 #endif
347         {
348           if (tls_conn)
349             nbuffer = SSL_read (tls_conn, buffer, sizeof buffer -1);
350           else
351             nbuffer = read (tls_fd, buffer, sizeof buffer -1);
352           if (nbuffer == 0)
353             goto finish;
354           for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
355             {
356 #ifdef HAVE_POLL
357               writefds[0].fd = out;
358               if (poll (writefds, 1, -1) == -1)
359 #else
360               FD_ZERO (&writefds);
361               FD_SET (out, &writefds);
362               if (select (out+1, NULL, &writefds, NULL, NULL) == -1)
363 #endif
364                 {
365                   perror ("poll");
366                   return 1;
367                 }
368               wrote = write (out, retry, nbuffer);
369               if (wrote < 0) goto finish;
370             }
371         }
372     }
373
374  finish:
375   close (in);
376   close (out);
377   
378   return 0;
379 }