734339630b096fb14a4879e481ba5a0e64ba2be6
[elisp/starttls.git] / starttls.c
1 /* simple wrapper program for STARTTLS
2
3    Copyright (C) 1999, 2000 Free Software Foundation, Inc.
4
5    Author: Daiki Ueno <ueno@unixuser.org>
6    Created: 1999-11-19
7    Keywords: TLS, OpenSSL
8
9    This file is not part of any package.
10
11    This program is free software; you can redistribute it and/or modify 
12    it under the terms of the GNU General Public License as published by 
13    the Free Software Foundation; either version 2, or (at your option)  
14    any later version.                                                   
15
16    This program is distributed in the hope that it will be useful,      
17    but WITHOUT ANY WARRANTY; without even the implied warranty of       
18    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        
19    GNU General Public License for more details.                         
20
21    You should have received a copy of the GNU General Public License    
22    along with GNU Emacs; see the file COPYING.  If not, write to the    
23    Free Software Foundation, Inc., 59 Temple Place - Suite 330,         
24    Boston, MA 02111-1307, USA.                                          
25
26 */
27
28 #include <sys/types.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32
33 #include <unistd.h>
34
35 #include <errno.h>
36 #include <sys/time.h>
37 #include <sys/socket.h>
38 #include <sys/file.h>
39 #include <sys/ioctl.h>
40 #include <sys/stat.h>
41 #include <netdb.h>
42 #include <stdio.h>
43 #include <signal.h>
44 #include <fcntl.h>
45 #include <netinet/in.h>
46 #ifdef HAVE_POLL_H
47 #include <sys/poll.h>
48 #endif
49 #define _GNU_SOURCE
50 #include "getopt.h"
51
52 extern void tls_negotiate (int, const char *, const char *);
53 extern int tls_write(int, const char *, int);
54 extern int tls_read(int, char *, int);
55 extern int tls_pending();
56
57 static char *opt_cert_file = NULL, *opt_key_file = NULL;
58 static int tls_fd;
59
60 static void
61 usage (progname)
62      const char *progname;
63 {
64   printf ("%s (%s) %s\n"
65           "Copyright (C) 1999 Free Software Foundation, Inc.\n"
66           "This program comes with ABSOLUTELY NO WARRANTY.\n"
67           "This is free software, and you are welcome to redistribute it\n"
68           "under certain conditions. See the file COPYING for details.\n\n"
69           "Usage: %s [options] host port\n\n"
70           "Options:\n\n"
71           " --cert-file [file]      specify certificate file\n"
72           " --key-file [file]       specify private key file\n",
73           progname, PACKAGE, VERSION, progname);
74 }
75
76 static void
77 do_tls_negotiate(sig)
78   int sig;
79 {
80   tls_negotiate(tls_fd, opt_cert_file, opt_key_file);
81 }
82
83 int
84 tcp_connect (hostname, service)
85      const char *hostname, *service;
86 {
87   int server, false = 0;
88 #ifdef HAVE_ADDRINFO
89   struct addrinfo *in, *in0, hints;
90 #else
91   struct hostent *host;
92   struct servent *serv;
93   struct sockaddr_in sin;
94 #endif
95
96 #ifdef HAVE_ADDRINFO
97   memset (&hints, 0, sizeof (hints));
98   hints.ai_family = AF_UNSPEC;
99   hints.ai_socktype = SOCK_STREAM;
100   if (getaddrinfo (hostname, service, &hints, &in0))
101     return -1;
102
103   for (in = in0; in; in = in->ai_next)
104     {
105       server = socket (in->ai_family, in->ai_socktype, in->ai_protocol);
106       if (server < 0)
107         continue;
108       if (connect (server, in->ai_addr, in->ai_addrlen) < 0)
109         {
110           server = -1;
111           continue;
112         }
113       break;
114   }
115
116   if (server < 0)
117     return -1;
118 #else
119   memset (&sin, 0, sizeof (sin));
120   host = gethostbyname (hostname);
121   if (!host)
122     return -1;
123   memcpy (&sin.sin_addr, host->h_addr, host->h_length);
124   serv = getservbyname (service, "tcp");
125   if (serv)
126     sin.sin_port = serv->s_port;
127   else if (isdigit (service[0]))
128     sin.sin_port = htons (atoi (service));
129   sin.sin_family = AF_INET;
130   server = socket (sin.sin_family, SOCK_STREAM, 0);
131   if (server == -1)
132     return -1;
133
134   if (connect (server, (struct sockaddr *)&sin, sizeof (sin)) < 0)
135     {
136       close (server);
137       return -1;
138     }
139 #endif
140
141   setsockopt (server, SOL_SOCKET, SO_KEEPALIVE, (const char *) &false,
142               sizeof (false));
143
144   return server;
145 }
146
147 int
148 main (argc, argv) 
149   int argc;
150   char **argv;
151 {
152   int in = fileno (stdin), out = fileno (stdout), 
153     nbuffer, wrote;
154 #ifdef HAVE_POLL
155   struct pollfd readfds[2], writefds[1];
156 #else
157   fd_set readfds, writefds;
158 #endif
159   char buffer[BUFSIZ], *retry;
160   struct sigaction act;
161
162   int this_option_optind = optind ? optind : 1;
163   int option_index = 0, c;
164   static struct option long_options[] =
165     {
166       {"cert-file", 1, 0, 'c'},
167       {"key-file", 1, 0, 'k'},
168       {0, 0, 0, 0}
169     };
170
171   while (1)
172     {
173       c = getopt_long (argc, argv, "c:k:", long_options, &option_index);
174       if (c == -1)
175         break;
176     
177       switch (c)
178         {
179         case 'c':
180           opt_cert_file = optarg;
181           break;
182         case 'k':
183           opt_key_file = optarg;
184           break;
185         default:
186           usage (basename (argv[0]));
187           return 1;
188         }
189     }
190
191   if (optind+2 != argc)
192     {
193       usage (basename (argv[0]));
194       return 1;
195     }
196
197   tls_fd = tcp_connect (argv[optind], argv[optind+1]);
198   if (tls_fd < 0)
199     {
200       perror ("tcp_connect");
201       return 1;
202     }
203
204   memset (&act, 0, sizeof (act));
205   act.sa_handler = do_tls_negotiate;
206   sigemptyset (&act.sa_mask);
207   act.sa_flags = SA_RESTART|SA_RESETHAND;
208   sigaction (SIGALRM, &act, NULL);
209
210 #ifdef HAVE_POLL
211   readfds[0].fd = in;
212   readfds[1].fd = tls_fd;
213   readfds[0].events = POLLIN;
214   readfds[1].events = POLLIN;
215   writefds[0].events = POLLOUT;
216 #endif
217
218   while (1)
219     {
220 #ifdef HAVE_POLL
221       if (poll (readfds, 2, -1) == -1 && errno != EINTR)
222 #else
223       FD_ZERO (&readfds);
224       FD_SET (tls_fd, &readfds);
225       FD_SET (in, &readfds);
226       if (select (tls_fd+1, &readfds, NULL, NULL, NULL) == -1
227           && errno != EINTR )
228 #endif
229         {
230           perror ("poll");
231           return 1;
232         }
233 #ifdef HAVE_POLL
234       if (readfds[0].revents & POLLIN)
235 #else
236       if (FD_ISSET (in, &readfds))
237 #endif
238         {
239           nbuffer = read (in, buffer, sizeof buffer -1);
240
241           if (nbuffer == 0)
242             goto finish;
243           for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
244             {
245 #ifdef HAVE_POLL
246               writefds[0].fd = tls_fd;
247               if (poll (writefds, 1, -1) == -1)
248 #else
249               FD_ZERO (&writefds);
250               FD_SET (tls_fd, &writefds);
251               if (select (tls_fd+1, NULL, &writefds, NULL, NULL) == -1)
252 #endif
253                 {
254                   perror ("poll");
255                   return 1;
256                 }
257               wrote = tls_write(tls_fd, retry, nbuffer);
258               if (wrote < 0) goto finish;
259             }
260         }
261 #ifdef HAVE_POLL
262       if (readfds[1].revents & POLLIN)
263 #else
264       if (FD_ISSET (tls_fd, &readfds))
265 #endif
266         {
267 readtop:
268           nbuffer = tls_read(tls_fd, buffer, sizeof buffer -1);
269           if (nbuffer == 0)
270             goto finish;
271           for (retry = buffer; nbuffer > 0; nbuffer -= wrote, retry += wrote)
272             {
273 #ifdef HAVE_POLL
274               writefds[0].fd = out;
275               if (poll (writefds, 1, -1) == -1)
276 #else
277               FD_ZERO (&writefds);
278               FD_SET (out, &writefds);
279               if (select (out+1, NULL, &writefds, NULL, NULL) == -1)
280 #endif
281                 {
282                   perror ("poll");
283                   return 1;
284                 }
285               wrote = write (out, retry, nbuffer);
286               if (wrote < 0) goto finish;
287             }
288           if (tls_pending())
289             goto readtop;
290         }
291     }
292
293  finish:
294   close (in);
295   close (out);
296   
297   return 0;
298 }