* starttls.c (main): Clear fd sets before select().
[elisp/starttls.git] / starttls.c
1 /* simple wrapper program for STARTTLS
2
3    Copyright (C) 1999, 2000 Free Software Foundation, Inc.
4
5    Author: Daiki Ueno <ueno@unixuser.org>
6    Created: 1999-11-19
7    Keywords: TLS, OpenSSL
8
9    This file is not part of any package.
10
11    This program is free software; you can redistribute it and/or modify 
12    it under the terms of the GNU General Public License as published by 
13    the Free Software Foundation; either version 2, or (at your option)  
14    any later version.                                                   
15
16    This program is distributed in the hope that it will be useful,      
17    but WITHOUT ANY WARRANTY; without even the implied warranty of       
18    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        
19    GNU General Public License for more details.                         
20
21    You should have received a copy of the GNU General Public License    
22    along with GNU Emacs; see the file COPYING.  If not, write to the    
23    Free Software Foundation, Inc., 59 Temple Place - Suite 330,         
24    Boston, MA 02111-1307, USA.                                          
25
26 */
27
28 #include <sys/types.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32
33 #include <unistd.h>
34
35 #include <openssl/lhash.h>
36 #include <openssl/bn.h>
37 #include <openssl/err.h>
38 #include <openssl/pem.h>
39 #include <openssl/x509.h>
40 #include <openssl/ssl.h>
41
42 #ifdef HAVE_SOCKS_H
43 #include <socks.h>
44 #endif 
45
46 #ifdef NEED_ADDRINFO_H
47 #include "addrinfo.h"
48 #endif
49
50 #include <sys/time.h>
51 #include <sys/socket.h>
52 #include <sys/file.h>
53 #include <sys/ioctl.h>
54 #include <sys/stat.h>
55 #include <netdb.h>
56 #include <stdio.h>
57 #include <signal.h>
58 #include <fcntl.h>
59 #include <netinet/in.h>
60 #define _GNU_SOURCE
61 #include <getopt.h>
62
63 static SSL_CTX *tls_ctx = NULL;
64 static SSL *tls_conn = NULL;
65 static int tls_fd;
66
67 static char *opt_cert_file = NULL, *opt_key_file = NULL;
68 static int  opt_verify = 0;
69
70 static int
71 tls_ssl_ctx_new (cert_file, key_file)
72   const char *cert_file, *key_file;
73 {
74   SSL_load_error_strings ();
75   SSLeay_add_ssl_algorithms ();
76
77   tls_ctx = SSL_CTX_new (TLSv1_client_method());
78   if (!tls_ctx)
79     return -1;
80
81   SSL_CTX_set_options (tls_ctx, SSL_OP_ALL /* Work around all known bugs */); 
82
83   if (cert_file)
84     {
85       if (SSL_CTX_use_certificate_file (tls_ctx, cert_file, 
86                                         SSL_FILETYPE_PEM) <= 0)
87         return -1;
88       if (!key_file)
89         key_file = cert_file;
90       if (SSL_CTX_use_PrivateKey_file (tls_ctx, key_file, 
91                                        SSL_FILETYPE_PEM) <= 0)
92         return -1;
93       if (!SSL_CTX_check_private_key (tls_ctx))
94         return -1;
95     }
96
97   SSL_CTX_set_verify (tls_ctx, SSL_VERIFY_NONE, NULL);
98
99   return 0;
100 }
101
102 static int
103 tls_ssl_new(ctx, s)
104   SSL_CTX *ctx;
105   int s;
106 {
107   SSL_SESSION *session;
108   SSL_CIPHER *cipher;
109   X509   *peer;
110
111   tls_conn = (SSL *) SSL_new (ctx);
112   if (!tls_conn)
113     return -1;
114   SSL_clear(tls_conn);
115
116   if (!SSL_set_fd (tls_conn, s))
117     return -1;
118
119   SSL_set_connect_state (tls_conn);
120
121   if (SSL_connect (tls_conn) <= 0)
122     {
123       session = SSL_get_session (tls_conn);
124       if (session)
125         SSL_CTX_remove_session (ctx, session);
126       if (tls_conn!=NULL)
127         SSL_free (tls_conn);
128       return -1;
129     }
130
131   return 0;
132 }
133
134 static int
135 tls_connect (hostname, service)
136      const char *hostname, *service;
137 {
138   struct protoent *proto;
139   struct addrinfo *in, hints;
140   int server, false = 0;
141
142   proto = getprotobyname ("tcp");
143   if (!proto)
144     return -1;
145
146   memset (&hints, 0, sizeof (hints));
147   hints.ai_family = AF_UNSPEC;
148   hints.ai_socktype = SOCK_STREAM;
149   hints.ai_protocol = proto->p_proto;
150
151   if (getaddrinfo (hostname, service, &hints, &in) < 0) 
152     return -1;
153
154   server = socket (in->ai_family, in->ai_socktype, 0);
155   if (server < 0)
156     return -1;
157
158   if (setsockopt (server, SOL_SOCKET, SO_KEEPALIVE,
159                   (const char *) &false, sizeof (false))) 
160     return -1;
161
162   if (connect (server, in->ai_addr, in->ai_addrlen) < 0)
163     {
164       close (server);
165       return -1;
166     }
167
168   return server;
169 }
170
171 static void
172 tls_negotiate (sig)
173      int sig;
174 {
175   if (tls_ssl_ctx_new (opt_cert_file, opt_key_file) == -1)
176     return;
177
178   (void) tls_ssl_new (tls_ctx, tls_fd); /* Negotiation has done. */
179 }
180
181 static void
182 usage (progname)
183      const char *progname;
184 {
185   printf ("%s (%s) %s\n"
186           "Copyright (C) 1999 Free Software Foundation, Inc.\n"
187           "This program comes with ABSOLUTELY NO WARRANTY.\n"
188           "This is free software, and you are welcome to redistribute it\n"
189           "under certain conditions. See the file COPYING for details.\n\n"
190           "Usage: %s [options] host port\n\n"
191           "Options:\n\n"
192           " --cert-file [file]      specify certificate file\n"
193           " --key-file [file]       specify private key file\n"
194           " --verify [level]        set verification level\n",
195           progname, PACKAGE, VERSION, progname);
196 }
197      
198 int
199 main (argc, argv) 
200   int argc;
201   char **argv;
202 {
203   int in = fileno (stdin), out = fileno (stdout), 
204     nbuffer, wrote;
205   fd_set readfds, writefds;
206   char buffer[BUFSIZ], *retry;
207   struct sigaction act;
208
209   int this_option_optind = optind ? optind : 1;
210   int option_index = 0, c;
211   static struct option long_options[] =
212     {
213       {"cert-file", 1, 0, 'c'},
214       {"key-file", 1, 0, 'k'},
215       {"verify", 1, 0, 'v'},
216       {0, 0, 0, 0}
217     };
218
219   while (1)
220     {
221       c = getopt_long (argc, argv, "c:k:v:f", long_options, &option_index);
222       if (c == -1)
223         break;
224     
225       switch (c)
226         {
227         case 'c':
228           opt_cert_file = optarg;
229           break;
230         case 'k':
231           opt_key_file = optarg;
232           break;
233         case 'v':
234           opt_verify = atoi (optarg);
235           break;
236         default:
237           usage (basename (argv[0]));
238           return 1;
239         }
240     }
241
242   if (optind+2 != argc)
243     {
244       usage (basename (argv[0]));
245       return 1;
246     }
247
248   tls_fd = tls_connect (argv[optind], argv[optind+1]);
249   if (tls_fd < 0)
250     {
251       perror ("tls_connect");
252       return 1;
253     }
254
255   memset (&act, 0, sizeof (act));
256   act.sa_handler = tls_negotiate;
257   sigemptyset (&act.sa_mask);
258   act.sa_flags = SA_RESTART|SA_RESETHAND;
259   sigaction (SIGALRM, &act, NULL);
260
261   while (1)
262     {
263       FD_ZERO (&readfds);
264       FD_SET (tls_fd, &readfds);
265       FD_SET (in, &readfds);
266       if (select (tls_fd+1, &readfds, NULL, NULL, NULL) == -1
267           && errno != EINTR )
268         {
269           perror ("select");
270           return 1;
271         }
272       if (FD_ISSET (in, &readfds))
273         {
274           nbuffer = read (in, buffer, sizeof buffer -1);
275
276           if (nbuffer == 0)
277             goto finish;
278           for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
279             {
280               FD_ZERO (&writefds);
281               FD_SET (tls_fd, &writefds);
282               if (select (tls_fd+1, NULL, &writefds, NULL, NULL) == -1)
283                 {
284                   perror ("select");
285                   return 1;
286                 }
287               if (tls_conn) 
288                 wrote = SSL_write (tls_conn, retry, nbuffer);
289               else
290                 wrote = write (tls_fd, retry, nbuffer);
291               if (wrote < 0) goto finish;
292             }
293         }
294       if (FD_ISSET (tls_fd, &readfds))
295         {
296           if (tls_conn)
297             nbuffer = SSL_read (tls_conn, buffer, sizeof buffer -1);
298           else
299             nbuffer = read (tls_fd, buffer, sizeof buffer -1);
300           if (nbuffer == 0)
301             goto finish;
302           for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
303             {
304               FD_ZERO (&writefds);
305               FD_SET (out, &writefds);
306               if (select (out+1, NULL, &writefds, NULL, NULL) == -1)
307                 {
308                   perror ("select");
309                   return 1;
310                 }
311               wrote = write (out, retry, nbuffer);
312               if (wrote < 0) goto finish;
313             }
314         }
315     }
316
317  finish:
318   close (in);
319   close (out);
320   
321   return 0;
322 }