* Add specific error codes for every command on the wire protocol, so that
[citadel.git] / citadel / serv_crypto.c
1 /* $Id$ */
2
#include <errno.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include "sysdep.h"

#ifdef HAVE_OPENSSL
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#endif

#if TIME_WITH_SYS_TIME
# include <sys/time.h>
# include <time.h>
#else
# if HAVE_SYS_TIME_H
#  include <sys/time.h>
# else
#  include <time.h>
# endif
#endif

#ifdef HAVE_PTHREAD_H
#include <pthread.h>
#endif

#ifdef HAVE_SYS_SELECT_H
#include <sys/select.h>
#endif

#include <stdio.h>
#include "server.h"
#include "serv_crypto.h"
#include "sysdep_decls.h"
#include "serv_extensions.h"
38
39
40 #ifdef HAVE_OPENSSL
41 SSL_CTX *ssl_ctx;               /* SSL context */
42 pthread_mutex_t **SSLCritters;  /* Things needing locking */
43
44 static unsigned long id_callback(void)
45 {
46         return (unsigned long) pthread_self();
47 }
48
49  /*
50   * Set up the cert things on the server side. We do need both the
51   * private key (in key_file) and the cert (in cert_file).
52   * Both files may be identical.
53   *
54   * This function is taken from OpenSSL apps/s_cb.c
55   */
56
57 static int set_cert_stuff(SSL_CTX * ctx,
58                           const char *cert_file, const char *key_file)
59 {
60         if (cert_file != NULL) {
61                 if (SSL_CTX_use_certificate_file(ctx, cert_file,
62                                                  SSL_FILETYPE_PEM) <= 0) {
63                         lprintf(3, "unable to get certificate from '%s'",
64                                 cert_file);
65                         return (0);
66                 }
67                 if (key_file == NULL)
68                         key_file = cert_file;
69                 if (SSL_CTX_use_PrivateKey_file(ctx, key_file,
70                                                 SSL_FILETYPE_PEM) <= 0) {
71                         lprintf(3, "unable to get private key from '%s'",
72                                 key_file);
73                         return (0);
74                 }
75                 /* Now we know that a key and cert have been set against
76                  * the SSL context */
77                 if (!SSL_CTX_check_private_key(ctx)) {
78                         lprintf(3,
79                                 "Private key does not match the certificate public key");
80                         return (0);
81                 }
82         }
83         return (1);
84 }
85
86
87 void init_ssl(void)
88 {
89         SSL_METHOD *ssl_method;
90         DH *dh;
91
92         if (!access("/var/run/egd-pool", F_OK))
93                 RAND_egd("/var/run/egd-pool");
94
95         if (!RAND_status()) {
96                 lprintf(2,
97                         "PRNG not adequately seeded, won't do SSL/TLS\n");
98                 return;
99         }
100         SSLCritters =
101             mallok(CRYPTO_num_locks() * sizeof(pthread_mutex_t *));
102         if (!SSLCritters) {
103                 lprintf(1, "citserver: can't allocate memory!!\n");
104                 /* Nothing's been initialized, just die */
105                 exit(1);
106         } else {
107                 int a;
108
109                 for (a = 0; a < CRYPTO_num_locks(); a++) {
110                         SSLCritters[a] = mallok(sizeof(pthread_mutex_t));
111                         if (!SSLCritters[a]) {
112                                 lprintf(1,
113                                         "citserver: can't allocate memory!!\n");
114                                 /* Nothing's been initialized, just die */
115                                 exit(1);
116                         }
117                         pthread_mutex_init(SSLCritters[a], NULL);
118                 }
119         }
120
121         /*
122          * Initialize SSL transport layer
123          */
124         SSL_library_init();
125         SSL_load_error_strings();
126         ssl_method = SSLv23_server_method();
127         if (!(ssl_ctx = SSL_CTX_new(ssl_method))) {
128                 lprintf(2, "SSL_CTX_new failed: %s\n",
129                         ERR_reason_error_string(ERR_get_error()));
130                 return;
131         }
132         if (!(SSL_CTX_set_cipher_list(ssl_ctx, CIT_CIPHERS))) {
133                 lprintf(2, "SSL: No ciphers available\n");
134                 SSL_CTX_free(ssl_ctx);
135                 ssl_ctx = NULL;
136                 return;
137         }
138 #if 0
139 #if SSLEAY_VERSION_NUMBER >= 0x00906000L
140         SSL_CTX_set_mode(ssl_ctx, SSL_CTX_get_mode(ssl_ctx) |
141                          SSL_MODE_AUTO_RETRY);
142 #endif
143 #endif
144         CRYPTO_set_locking_callback(ssl_lock);
145         CRYPTO_set_id_callback(id_callback);
146
147         /* Load DH parameters into the context */
148         dh = DH_new();
149         if (!dh) {
150                 lprintf(2, "init_ssl() can't allocate a DH object: %s\n",
151                         ERR_reason_error_string(ERR_get_error()));
152                 SSL_CTX_free(ssl_ctx);
153                 ssl_ctx = NULL;
154                 return;
155         }
156         if (!(BN_hex2bn(&(dh->p), DH_P))) {
157                 lprintf(2, "init_ssl() can't assign DH_P: %s\n",
158                         ERR_reason_error_string(ERR_get_error()));
159                 SSL_CTX_free(ssl_ctx);
160                 ssl_ctx = NULL;
161                 return;
162         }
163         if (!(BN_hex2bn(&(dh->g), DH_G))) {
164                 lprintf(2, "init_ssl() can't assign DH_G: %s\n",
165                         ERR_reason_error_string(ERR_get_error()));
166                 SSL_CTX_free(ssl_ctx);
167                 ssl_ctx = NULL;
168                 return;
169         }
170         dh->length = DH_L;
171         SSL_CTX_set_tmp_dh(ssl_ctx, dh);
172         DH_free(dh);
173
174         /* Get our certificates in order */
175         if (set_cert_stuff(ssl_ctx,
176                            "/etc/ssh/mail01.jemcaterers.net.cer",
177                            "/etc/ssh/ssh_host_rsa_key") != 1) {
178
179                 lprintf(3, "SSL ERROR: cert is bad!\n");
180
181         }
182
183         /* Finally let the server know we're here */
184         CtdlRegisterProtoHook(cmd_stls, "STLS", "Start SSL/TLS session");
185         CtdlRegisterProtoHook(cmd_gtls, "GTLS",
186                               "Get SSL/TLS session status");
187         CtdlRegisterSessionHook(endtls, EVT_STOP);
188 }
189
190
191 /*
192  * client_write_ssl() Send binary data to the client encrypted.
193  */
194 void client_write_ssl(char *buf, int nbytes)
195 {
196         int retval;
197         int nremain;
198         char junk[1];
199
200         nremain = nbytes;
201
202         while (nremain > 0) {
203                 if (SSL_want_write(CC->ssl)) {
204                         if ((SSL_read(CC->ssl, junk, 0)) < 1) {
205                                 lprintf(9, "SSL_read in client_write:\n");
206                                 ERR_print_errors_fp(stderr);
207                         }
208                 }
209                 retval =
210                     SSL_write(CC->ssl, &buf[nbytes - nremain], nremain);
211                 if (retval < 1) {
212                         long errval;
213
214                         errval = SSL_get_error(CC->ssl, retval);
215                         if (errval == SSL_ERROR_WANT_READ ||
216                             errval == SSL_ERROR_WANT_WRITE) {
217                                 sleep(1);
218                                 continue;
219                         }
220                         lprintf(9, "SSL_write got error %ld\n", errval);
221                         endtls();
222                         client_write(&buf[nbytes - nremain], nremain);
223                         return;
224                 }
225                 nremain -= retval;
226         }
227 }
228
229
230 /*
231  * client_read_ssl() - read data from the encrypted layer.
232  */
233 int client_read_ssl(char *buf, int bytes, int timeout)
234 {
235         int len, rlen;
236         fd_set rfds;
237         struct timeval tv;
238         int retval;
239         int s;
240         char junk[1];
241
242         len = 0;
243         while (len < bytes) {
244                 FD_ZERO(&rfds);
245                 s = BIO_get_fd(CC->ssl->rbio, NULL);
246                 FD_SET(s, &rfds);
247                 tv.tv_sec = timeout;
248                 tv.tv_usec = 0;
249
250                 retval = select(s + 1, &rfds, NULL, NULL, &tv);
251
252                 if (FD_ISSET(s, &rfds) == 0) {
253                         return (0);
254                 }
255
256                 if (SSL_want_read(CC->ssl)) {
257                         if ((SSL_write(CC->ssl, junk, 0)) < 1) {
258                                 lprintf(9, "SSL_write in client_read:\n");
259                                 ERR_print_errors_fp(stderr);
260                         }
261                 }
262                 rlen = SSL_read(CC->ssl, &buf[len], bytes - len);
263                 if (rlen < 1) {
264                         long errval;
265
266                         errval = SSL_get_error(CC->ssl, rlen);
267                         if (errval == SSL_ERROR_WANT_READ ||
268                             errval == SSL_ERROR_WANT_WRITE) {
269                                 sleep(1);
270                                 continue;
271                         }
272                         lprintf(9, "SSL_read got error %ld\n", errval);
273                         endtls();
274                         return (client_read_to
275                                 (&buf[len], bytes - len, timeout));
276                 }
277                 len += rlen;
278         }
279         return (1);
280 }
281
282
283 /*
284  * cmd_stls() starts SSL/TLS encryption for the current session
285  */
286 void cmd_stls(char *params)
287 {
288         int retval, bits, alg_bits;
289
290         if (!ssl_ctx) {
291                 cprintf("%d No SSL_CTX available\n", ERROR + CMD_NOT_SUPPORTED);
292                 return;
293         }
294         if (!(CC->ssl = SSL_new(ssl_ctx))) {
295                 lprintf(2, "SSL_new failed: %s\n",
296                                 ERR_reason_error_string(ERR_peek_error()));
297                 cprintf("%d SSL_new: %s\n", ERROR + INTERNAL_ERROR,
298                                 ERR_reason_error_string(ERR_get_error()));
299                 return;
300         }
301         if (!(SSL_set_fd(CC->ssl, CC->client_socket))) {
302                 lprintf(2, "SSL_set_fd failed: %s\n",
303                         ERR_reason_error_string(ERR_peek_error()));
304                 SSL_free(CC->ssl);
305                 CC->ssl = NULL;
306                 cprintf("%d SSL_set_fd: %s\n", ERROR + INTERNAL_ERROR,
307                                 ERR_reason_error_string(ERR_get_error()));
308                 return;
309         }
310         cprintf("%d \n", CIT_OK);
311         retval = SSL_accept(CC->ssl);
312         if (retval < 1) {
313                 /*
314                  * Can't notify the client of an error here; they will
315                  * discover the problem at the SSL layer and should
316                  * revert to unencrypted communications.
317                  */
318                 long errval;
319
320                 errval = SSL_get_error(CC->ssl, retval);
321                 lprintf(2, "SSL_accept failed: %s\n",
322                         ERR_reason_error_string(ERR_get_error()));
323                 SSL_free(CC->ssl);
324                 CC->ssl = NULL;
325                 return;
326         }
327         BIO_set_close(CC->ssl->rbio, BIO_NOCLOSE);
328         bits =
329             SSL_CIPHER_get_bits(SSL_get_current_cipher(CC->ssl),
330                                 &alg_bits);
331         lprintf(3, "SSL/TLS using %s on %s (%d of %d bits)\n",
332                 SSL_CIPHER_get_name(SSL_get_current_cipher(CC->ssl)),
333                 SSL_CIPHER_get_version(SSL_get_current_cipher(CC->ssl)),
334                 bits, alg_bits);
335         CC->redirect_ssl = 1;
336 }
337
338
339 /*
340  * cmd_gtls() returns status info about the TLS connection
341  */
342 void cmd_gtls(char *params)
343 {
344         int bits, alg_bits;
345
346         if (!CC->ssl || !CC->redirect_ssl) {
347                 cprintf("%d Session is not encrypted.\n", ERROR);
348                 return;
349         }
350         bits =
351             SSL_CIPHER_get_bits(SSL_get_current_cipher(CC->ssl),
352                                 &alg_bits);
353         cprintf("%d %s|%s|%d|%d\n", CIT_OK,
354                 SSL_CIPHER_get_version(SSL_get_current_cipher(CC->ssl)),
355                 SSL_CIPHER_get_name(SSL_get_current_cipher(CC->ssl)),
356                 alg_bits, bits);
357 }
358
359
360 /*
361  * endtls() shuts down the TLS connection
362  *
363  * WARNING:  This may make your session vulnerable to a known plaintext
364  * attack in the current implmentation.
365  */
366 void endtls(void)
367 {
368         lprintf(7, "Ending SSL/TLS\n");
369
370         if (!CC->ssl) {
371                 CC->redirect_ssl = 0;
372                 return;
373         }
374
375         SSL_shutdown(CC->ssl);
376         SSL_free(CC->ssl);
377         CC->ssl = NULL;
378         CC->redirect_ssl = 0;
379 }
380
381
382 /*
383  * ssl_lock() callback for OpenSSL mutex locks
384  */
385 void ssl_lock(int mode, int n, const char *file, int line)
386 {
387         if (mode & CRYPTO_LOCK)
388                 pthread_mutex_lock(SSLCritters[n]);
389         else
390                 pthread_mutex_unlock(SSLCritters[n]);
391 }
392 #endif                          /* HAVE_OPENSSL */