Check return code for write when sending messages to server.
[cascardo/libreceita.git] / rnetclient.c
1 /*
2  *  Copyright (C) 2012-2013  Thadeu Lima de Souza Cascardo <cascardo@minaslivre.org>
3  *
4  *  This program is free software; you can redistribute it and/or modify
5  *  it under the terms of the GNU General Public License as published by
6  *  the Free Software Foundation; either version 3 of the License, or
7  *  (at your option) any later version.
8  *
9  *  This program is distributed in the hope that it will be useful,
10  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *  GNU General Public License for more details.
13  *
14  *  You should have received a copy of the GNU General Public License along
15  *  with this program; if not, write to the Free Software Foundation, Inc.,
16  *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17  */
18
19 #include <string.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <errno.h>
23 #include <unistd.h>
24 #include <sys/socket.h>
25 #include <netinet/in.h>
26 #include <arpa/inet.h>
27 #include <netdb.h>
28 #include <gnutls/gnutls.h>
29 #include <zlib.h>
30 #include "decfile.h"
31 #include "rnet_message.h"
32 #include "rnet_encode.h"
33
/* Decode the 16-bit big-endian length stored in the two bytes at BUF. */
static size_t chars2len (unsigned char buf[2]) {
	size_t high = buf[0];
	size_t low = buf[1];

	return high * 256 + low;
}
37
38 static void * get_creds(char *certfile)
39 {
40         static gnutls_certificate_credentials_t cred;
41         gnutls_certificate_allocate_credentials(&cred);
42         gnutls_certificate_set_x509_trust_file(cred, certfile,
43                                         GNUTLS_X509_FMT_PEM);
44         return cred;
45 }
46
47 static void session_new(gnutls_session_t *session)
48 {
49         static void *cred;
50         cred = get_creds("cert.pem");
51         gnutls_init(session, GNUTLS_CLIENT);
52         gnutls_set_default_priority(*session);
53         gnutls_credentials_set(*session, GNUTLS_CRD_CERTIFICATE, cred);
54 }
55
56 static int deflateRecord(char *buffer, size_t len, char **out, size_t *olen, int header)
57 {
58         z_stream zstrm;
59         int r;
60         zstrm.zalloc = Z_NULL;
61         zstrm.zfree = Z_NULL;
62         zstrm.opaque = Z_NULL;
63         if ((r = deflateInit(&zstrm, Z_DEFAULT_COMPRESSION)) != Z_OK)
64                 return -1;
65         *out = malloc(len * 2 + 36);
66         if (!out) {
67                 deflateEnd(&zstrm);
68                 return -1;
69         }
70         zstrm.next_in = buffer;
71         zstrm.avail_in = len;
72         zstrm.next_out = *out + 6;
73         zstrm.avail_out = len * 2 + 30;
74         while ((r = deflate(&zstrm, Z_FINISH)) != Z_STREAM_END &&
75                 zstrm.avail_out > 0);
76         if ((r = deflate(&zstrm, Z_FINISH)) != Z_STREAM_END) {
77                 deflateEnd(&zstrm);
78                 free(*out);
79                 return -1;
80         }
81         *olen = zstrm.total_out + 6;
82         (*out)[0] = 0x1;
83         (*out)[1] = (zstrm.total_out >> 8);
84         (*out)[2] = (zstrm.total_out & 0xff);
85         (*out)[3] = (len >> 8);
86         (*out)[4] = (len & 0xff);
87         (*out)[5] = header ? 0x01 : 0x0;
88         deflateEnd(&zstrm);
89         return 0;
90 }
91
92 static int inflateRecord(char *buffer, size_t len, char **out, size_t *olen)
93 {
94         z_stream zstrm;
95         int r;
96         zstrm.zalloc = Z_NULL;
97         zstrm.zfree = Z_NULL;
98         zstrm.opaque = Z_NULL;
99         if ((r = inflateInit(&zstrm)) != Z_OK)
100                 return -1;
101         *olen = chars2len(buffer+3);
102         *out = malloc(*olen);
103         if (!out) {
104                 inflateEnd(&zstrm);
105                 return -1;
106         }
107         zstrm.next_in = buffer + 6;
108         zstrm.avail_in = len - 6;
109         zstrm.next_out = *out;
110         zstrm.avail_out = *olen;
111         while ((r = inflate(&zstrm, Z_FINISH)) != Z_STREAM_END &&
112                 zstrm.avail_out > 0);
113         if ((r = inflate(&zstrm, Z_FINISH)) != Z_STREAM_END) {
114                 inflateEnd(&zstrm);
115                 free(*out);
116                 return -1;
117         }
118         inflateEnd(&zstrm);
119         return 0;
120 }
121
#define RNET_ADDRESS "receitanet.receita.fazenda.gov.br"

/*
 * Resolve RNET_ADDRESS, port 3456, and open a TCP connection to the first
 * address that accepts it.
 *
 * *C receives the connected socket, or -1 on failure.  The return value is
 * one of:
 *   - 0 on success;
 *   - the getaddrinfo() error code if resolution fails;
 *   - EAI_SYSTEM if no address could be connected, with errno describing
 *     the last socket()/connect() failure.
 */
static int connect_rnet(int *c)
{
	struct addrinfo *addresses;
	struct addrinfo *addr;
	struct addrinfo hint;
	int fd = -1;
	int saved_errno = 0;
	int r;

	*c = -1;
	memset(&hint, 0, sizeof(hint));
	hint.ai_family = AF_UNSPEC;
	hint.ai_socktype = SOCK_STREAM;
	hint.ai_protocol = IPPROTO_TCP;
	hint.ai_flags = AI_ADDRCONFIG;
	r = getaddrinfo(RNET_ADDRESS, "3456", &hint, &addresses);
	if (r)
		return r;

	for (addr = addresses; addr != NULL; addr = addr->ai_next) {
		fd = socket(addr->ai_family, addr->ai_socktype,
			    addr->ai_protocol);
		if (fd < 0) {
			/* Nothing to close; closing -1 would clobber errno. */
			saved_errno = errno;
			continue;
		}
		if (connect(fd, addr->ai_addr, addr->ai_addrlen) == 0)
			break;
		saved_errno = errno;
		close(fd);
		fd = -1;
	}
	freeaddrinfo(addresses);

	*c = fd;
	if (fd == -1) {
		/* Let the caller report the real cause via strerror(errno). */
		if (saved_errno)
			errno = saved_errno;
		return EAI_SYSTEM;
	}
	return 0;
}
158
/* Write all LEN bytes of BUF to FD, retrying short writes and EINTR.
   Returns 0 on success, -1 on error. */
static int fd_write_all(int fd, const char *buf, size_t len)
{
	while (len > 0) {
		ssize_t n = write(fd, buf, len);
		if (n < 0 && errno == EINTR)
			continue;
		if (n <= 0)
			return -1;
		buf += n;
		len -= (size_t) n;
	}
	return 0;
}

/* Read exactly LEN bytes from FD into BUF, retrying short reads and EINTR.
   Returns 0 on success, -1 on error or EOF before LEN bytes arrive. */
static int fd_read_all(int fd, char *buf, size_t len)
{
	while (len > 0) {
		ssize_t n = read(fd, buf, len);
		if (n < 0 && errno == EINTR)
			continue;
		if (n <= 0)
			return -1;
		buf += n;
		len -= (size_t) n;
	}
	return 0;
}

/*
 * Plain-text exchange performed on socket C before TLS starts.
 *
 * It sends a 0x01 byte followed by fourteen ASCII '0's.  It then expects a
 * single 'E' byte followed by fourteen more bytes, which are read and
 * discarded.  Returns 0 on success, -1 on I/O failure or an unexpected
 * reply byte.
 */
static int handshake(int c)
{
	char buffer[16];

	buffer[0] = 1;
	if (fd_write_all(c, buffer, 1) < 0)
		return -1;
	if (fd_write_all(c, "00000000000000", 14) < 0)
		return -1;
	/* Fail if the byte could not be read OR it is not 'E'. */
	if (fd_read_all(c, buffer, 1) < 0 || buffer[0] != 'E')
		return -1;
	if (fd_read_all(c, buffer, 14) < 0)
		return -1;
	return 0;
}
178
/* Print the command-line synopsis to stderr and terminate with status 1. */
static void usage(void)
{
	fputs("rnetclient [filename]\n", stderr);
	exit(1);
}
184
185 static int rnet_send(gnutls_session_t session, char *buffer, size_t len, int header)
186 {
187         int r = 0;
188         /* Large files have to be uploaded as multiple
189            separately-deflated chunks, because the compressed and
190            uncompressed lengths in each record are encoded in unsigned
191            16-bit integers each.
192
193            The header can't be split into multiple chunks, and it
194            should never have to, since it won't ever get even close to
195            64KiB.
196
197            The uploaded file may be larger: to upload such large
198            files, it suffices to send multiple records till the entire
199            file is transferred, without waiting for a response.  Since
200            we've alread informed the server of the file size in the
201            header, it knows exactly how much data to expect before
202            sending a response.  It will only send an error message
203            before that if it times us out.
204
205            Odds are that any reasonably large size will do, but it
206            can't be too close to 64KiB, otherwise there won't be room
207            for the compressed length should it not compress well,
208            which should never happen for capital-ASCII-only
209            declaration files, but who knows?
210
211            This chunk size worked at the first try, uploading a
212            ~100KiB file, so let's stick with it.  */
213         const int maxc = 64472;
214         if (header && len > maxc)
215                 return -1;
216
217         do {
218                 char *out = NULL;
219                 size_t olen;
220                 size_t clen = len < maxc ? len : maxc;
221                 r = deflateRecord(buffer, clen, &out, &olen, header);
222                 if (!r) {
223                         size_t n = gnutls_record_send(session, out, olen);
224                         if (n != olen)
225                                 r = -1;
226                 }
227                 free(out);
228                 buffer += clen;
229                 len -= clen;
230         } while (len && !r);
231         return r;
232 }
233
234 static int rnet_recv(gnutls_session_t session, struct rnet_message **message)
235 {
236         char *out;
237         size_t olen;
238         int r;
239         char *buffer;
240         size_t len;
241         rnet_message_expand(message, 6);
242         buffer = (*message)->buffer;
243         r = gnutls_record_recv(session, buffer, 6);
244         if (buffer[0] == 0x01) {
245                 len = chars2len(buffer+1);
246                 rnet_message_expand(message, len);
247                 buffer = (*message)->buffer + 6;
248                 r = gnutls_record_recv(session, buffer, len);
249                 inflateRecord(buffer - 6, len + 6, &out, &olen);
250                 rnet_message_del(*message);
251                 *message = NULL;
252                 rnet_message_expand(message, olen);
253                 memcpy((*message)->buffer, out, olen);
254                 (*message)->len = olen;
255                 free(out);
256         } else {
257                 len = chars2len(buffer+1);
258                 rnet_message_expand(message, len - 1);
259                 buffer = (*message)->buffer + 6;
260                 r = gnutls_record_recv(session, buffer, len - 1);
261                 (*message)->len = len + 4;
262                 rnet_message_strip(*message, 4);
263         }
264         return 0;
265 }
266
/*
 * Save the LEN-byte receipt in BUFFER to a new file named "<CPF>.REC.XXXXXX".
 *
 * The file goes in $HOME, falling back to $TMPDIR and then /tmp.  mkstemp
 * fills in the XXXXXX suffix, and the file is created under umask 0177 so
 * only the owner can read it.  The process umask is restored afterwards.
 * Success and failure are reported on stderr; nothing is returned.
 */
static void save_rec_file(char *cpf, char *buffer, int len)
{
	int fd;
	char *filename;
	char *home;
	mode_t mask;
	size_t fnlen;
	ssize_t r;

	if (!cpf) {
		fprintf(stderr, "Could not create receipt file: no CPF in declaration.\n");
		return;
	}

	home = getenv("HOME");
	if (!home) {
		home = getenv("TMPDIR");
		if (!home)
			home = "/tmp";
	}
	/* "/" + ".REC." + "XXXXXX" + NUL = 13 extra bytes. */
	fnlen = strlen(home) + strlen(cpf) + 13;
	filename = malloc(fnlen);
	if (!filename) {
		fprintf(stderr, "Could not create receipt file: %s\n",
						strerror(errno));
		return;
	}
	snprintf(filename, fnlen, "%s/%s.REC.XXXXXX", home, cpf);

	mask = umask(0177);
	fd = mkstemp(filename);
	umask(mask);
	if (fd < 0) {
		fprintf(stderr, "Could not create receipt file: %s\n",
						strerror(errno));
		goto out;
	}

	r = write(fd, buffer, len);
	if (r != len) {
		fprintf(stderr, "Could not write to receipt file%s%s\n",
			r < 0 ? ": " : ".",
			r < 0 ? strerror(errno) : "");
		goto out_close;
	}
	fprintf(stderr, "Wrote the receipt to %s.\n", filename);
out_close:
	close(fd);
out:
	free(filename);
}
305
/*
 * Process the optional fields of a server reply.
 *
 * If rnet_message_parse finds a "texto" field, its text is printed to
 * stderr.  If it finds an "arquivo" field, that field's bytes are handed
 * to save_rec_file() as the receipt for CPF.
 */
static void handle_response_text_and_file(char *cpf, struct rnet_message *message)
{
	char *text;
	char *file;
	int text_len;
	int file_len;

	if (rnet_message_parse(message, "texto", &text, &text_len) == 0)
		fprintf(stderr, "%.*s\n", text_len, text);
	if (rnet_message_parse(message, "arquivo", &file, &file_len) == 0)
		save_rec_file(cpf, file, file_len);
}
315
/*
 * Handle reply code 4 from the server.  Its optional "texto" and "arquivo"
 * fields are processed exactly as for the other informational replies.
 */
static void handle_response_already_found(char *cpf, struct rnet_message *message)
{
	struct rnet_message *reply = message;

	handle_response_text_and_file(cpf, reply);
}
320
/*
 * Report an error reply (code 3).  If rnet_message_parse finds a "texto"
 * field, the server's text is printed first.  A generic failure line is
 * always printed; both go to stderr.
 */
static void handle_response_error(struct rnet_message *message)
{
	char *value;
	int vlen;
	if (!rnet_message_parse(message, "texto", &value, &vlen))
		fprintf(stderr, "%.*s\n", vlen, value);
	fprintf(stderr, "Error transmitting DEC file.\n");
}
329
330 int main(int argc, char **argv)
331 {
332         int c;
333         int r;
334         struct rnet_decfile *decfile;
335         struct rnet_message *message = NULL;
336         gnutls_session_t session;
337         int finish = 0;
338         char *cpf;
339         
340         if (argc < 2) {
341                 usage();
342         }
343
344         decfile = rnet_decfile_open(argv[1]);
345         if (!decfile) {
346                 fprintf(stderr, "could not parse %s: %s\n", argv[1], strerror(errno));
347                 exit(1);
348         }
349
350         cpf = rnet_decfile_get_header_field(decfile, "cpf");
351
352         gnutls_global_init();
353
354         session_new(&session);
355         r = connect_rnet(&c);
356         if (r) {
357                 fprintf(stderr, "error connecting to server: %s\n",
358                         r == EAI_SYSTEM ? strerror(errno) : gai_strerror(r));
359                 exit(1);
360         }
361         gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)(intptr_t) c);
362         r = handshake(c);
363         if (r < 0) {
364                 exit(1);
365         }
366         if ((r = gnutls_handshake(session)) < 0)
367                 fprintf(stderr, "error in handshake: %s\n",
368                                 gnutls_strerror(r));
369
370         rnet_encode(decfile, &message);
371         rnet_send(session, message->buffer, message->len, 1);
372         rnet_message_del(message);
373
374         message = NULL;
375         r = rnet_recv(session, &message);
376         if (r || !message || message->len == 0) {
377                 fprintf(stderr, "error when receiving response\n");
378                 goto out;
379         }
380         switch (message->buffer[0]) {
381         case 1: /* go ahead */
382                 handle_response_text_and_file(cpf, message);
383                 break;
384         case 3: /* error */
385                 handle_response_error(message);
386                 finish = 1;
387                 break;
388         case 4:
389                 handle_response_already_found(cpf, message);
390                 finish = 1;
391                 break;
392         case 2:
393         case 5:
394                 handle_response_text_and_file(cpf, message);
395                 finish = 1;
396                 break;
397         }
398         rnet_message_del(message);
399
400         if (finish)
401                 goto out;
402
403         message = rnet_decfile_get_file(decfile);
404         rnet_send(session, message->buffer, message->len, 0);
405
406         message = NULL;
407         r = rnet_recv(session, &message);
408         if (r || !message || message->len == 0) {
409                 fprintf(stderr, "error when receiving response\n");
410                 goto out;
411         }
412         switch (message->buffer[0]) {
413         case 3: /* error */
414                 handle_response_error(message);
415                 break;
416         case 2:
417         case 4:
418         case 5:
419         case 1:
420                 handle_response_text_and_file(cpf, message);
421                 break;
422         }
423         
424 out:
425         gnutls_bye(session, GNUTLS_SHUT_RDWR);
426         close(c);
427         rnet_decfile_close(decfile);
428         gnutls_global_deinit();
429
430         return 0;
431 }