Allow set_driver to fail and handle this case.
[cascardo/rnetproxy.git] / hcconn_ssl.c
1 /*
2 ** Copyright (C) 2006 Thadeu Lima de Souza Cascardo <cascardo@minaslivre.org>
3 ** Copyright (C) 2009 Thadeu Lima de Souza Cascardo <cascardo@holoscopio.com>
4 **  
5 ** This program is free software; you can redistribute it and/or modify
6 ** it under the terms of the GNU General Public License as published by
7 ** the Free Software Foundation; either version 2 of the License, or
8 ** (at your option) any later version.
9 **  
10 ** This program is distributed in the hope that it will be useful,
11 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
12 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 ** GNU General Public License for more details.
14 **  
15 ** You should have received a copy of the GNU General Public License
16 ** along with this program; if not, write to the Free Software
17 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
18 **  
19 */
20
21 #include <gnutls/gnutls.h>
22 #include <glib.h>
23 #include <string.h>
24 #include <errno.h>
25 #include <fcntl.h>
26 #include "hcconn_internal.h"
27
28 struct ssl_data
29 {
30   gnutls_session_t session;
31   GString *buffer;
32   gboolean handshaking;
33   gpointer lowconn;
34 };
35
36 static void
37 ssl_client_session_new (gnutls_session_t *session)
38 {
39   int kx_prio[] = {GNUTLS_KX_RSA, 0};
40   gnutls_certificate_credentials cred;
41   gnutls_certificate_allocate_credentials (&cred);
42   gnutls_init (session, GNUTLS_CLIENT);
43   gnutls_set_default_priority (*session);
44   gnutls_kx_set_priority (*session, kx_prio);
45   gnutls_credentials_set (*session, GNUTLS_CRD_CERTIFICATE, cred);
46 }
47
48 static struct ssl_data *
49 ssl_data_new (int server)
50 {
51   struct ssl_data *ssl;
52   if (server)
53     return NULL;
54   ssl = g_slice_new (struct ssl_data);
55   ssl_client_session_new (&ssl->session);
56   ssl->buffer = g_string_sized_new (4096);
57   ssl->handshaking = FALSE;
58   return ssl;
59 }
60
61 static void
62 ssl_data_destroy (struct ssl_data *ssl)
63 {
64   gnutls_deinit (ssl->session);
65   g_string_free (ssl->buffer, TRUE);
66   g_slice_free (struct ssl_data, ssl);
67 }
68
69 static ssize_t
70 ssl_push (gnutls_transport_ptr_t ptr, const void *buffer, size_t len)
71 {
72   HCConn *conn = ptr;
73   struct ssl_data *ssl = conn->layer;
74   hc_conn_write (ssl->lowconn, (void *) buffer, len);
75   return len;
76 }
77
78 static ssize_t
79 ssl_pull (gnutls_transport_ptr_t ptr, void *buffer, size_t len)
80 {
81   HCConn *conn = ptr;
82   struct ssl_data *ssl = conn->layer;
83   int r;
84   if (ssl->handshaking == TRUE)
85     {
86       r = hc_conn_read (ssl->lowconn, buffer, len);
87       return r;
88     }
89   if (len > ssl->buffer->len)
90     {
91       r = ssl->buffer->len;
92       memcpy (buffer, ssl->buffer->str, r);
93       g_string_truncate (ssl->buffer, 0);
94     }
95   else
96     {
97       r = len;
98       memcpy (buffer, ssl->buffer->str, r);
99       g_string_erase (ssl->buffer, 0, r);
100     }
101   if (r == 0)
102     {
103       gnutls_transport_set_errno (ssl->session, EAGAIN);
104       return -1;
105     }
106   return r;
107 }
108
109 static void
110 ssl_server_handshake (struct ssl_data *ssl)
111 {
112   int error;
113   if ((error = gnutls_handshake (ssl->session)) < 0)
114     {
115       if (gnutls_error_is_fatal (error))
116         g_critical ("Fatal error while doing TLS handshaking: %s\n",
117                     gnutls_strerror (error));
118     }
119   else
120     {
121       ssl->handshaking = FALSE;
122     }
123 }
124
125 static void
126 ssl_server_connect (HCConn *conn)
127 {
128   struct ssl_data *ssl = conn->layer;
129   gnutls_transport_set_ptr (ssl->session, (gnutls_transport_ptr_t) conn);
130   gnutls_transport_set_push_function (ssl->session, ssl_push);
131   gnutls_transport_set_pull_function (ssl->session, ssl_pull);
132   ssl->handshaking = TRUE;
133   ssl_server_handshake (ssl);
134 }
135
136 static void
137 hc_conn_ssl_close (gpointer data)
138 {
139   struct ssl_data *ssl = data;
140   if (ssl != NULL)
141     {
142       gnutls_bye (ssl->session, GNUTLS_SHUT_RDWR);
143       hc_conn_close (ssl->lowconn);
144       ssl_data_destroy (ssl);
145     }
146 }
147
148 static ssize_t
149 hc_conn_ssl_read (gpointer data, gchar *buffer, size_t len)
150 {
151   struct ssl_data *ssl = data;
152   return gnutls_record_recv (ssl->session, buffer, len);
153 }
154
155 static ssize_t
156 hc_conn_ssl_write (gpointer data, gchar *buffer, size_t len)
157 {
158   struct ssl_data *ssl = data;
159   return gnutls_record_send (ssl->session, buffer, len);
160 }
161
162 void
163 hc_conn_ssl_watch (HCConn *conn, HCEvent event, gpointer data)
164 {
165   char buffer[4096];
166   HCConn *ssl_conn = data;
167   struct ssl_data *ssl = ssl_conn->layer;
168   int r;
169   switch (event)
170     {
171     case HC_EVENT_READ:
172       if (ssl->handshaking)
173         {
174           ssl_server_handshake (ssl);
175           return;
176         }
177       while ((r = hc_conn_read (ssl->lowconn, buffer, sizeof (buffer))) > 0)
178         g_string_append_len (ssl->buffer, buffer, r);
179       if (ssl_conn->func && !ssl->handshaking)
180         ssl_conn->func (ssl_conn, event, ssl_conn->data);
181       break;
182     case HC_EVENT_CLOSE:
183       if (ssl_conn->func)
184         ssl_conn->func (ssl_conn, event, ssl_conn->data);
185     }
186 }
187
188 static int
189 hc_conn_set_driver_ssl (HCConn *conn, HCConn *lowconn, int server)
190 {
191   struct ssl_data *ssl;
192   ssl = ssl_data_new (server);
193   if (ssl == NULL)
194     return -1;
195   ssl->lowconn = lowconn;
196   conn->layer = ssl;
197   conn->read = hc_conn_ssl_read;
198   conn->write = hc_conn_ssl_write;
199   conn->close = hc_conn_ssl_close;
200   hc_conn_set_callback (lowconn, hc_conn_ssl_watch, conn);
201   ssl_server_connect (conn);
202   return 0;
203 }
204
205 int
206 hc_conn_set_driver_ssl_client (HCConn *conn, HCConn *lowconn)
207 {
208   return hc_conn_set_driver_ssl (conn, lowconn, 0);
209 }