lib/mpi: Fix SG miter leak
[cascardo/linux.git] / lib / mpi / mpicoder.c
1 /* mpicoder.c  -  Coder for the external representation of MPIs
2  * Copyright (C) 1998, 1999 Free Software Foundation, Inc.
3  *
4  * This file is part of GnuPG.
5  *
6  * GnuPG is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * GnuPG is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA
19  */
20
21 #include <linux/bitops.h>
22 #include <linux/count_zeros.h>
23 #include <linux/byteorder/generic.h>
24 #include <linux/scatterlist.h>
25 #include <linux/string.h>
26 #include "mpi-internal.h"
27
28 #define MAX_EXTERN_MPI_BITS 16384
29
30 /**
31  * mpi_read_raw_data - Read a raw byte stream as a positive integer
32  * @xbuffer: The data to read
33  * @nbytes: The amount of data to read
34  */
35 MPI mpi_read_raw_data(const void *xbuffer, size_t nbytes)
36 {
37         const uint8_t *buffer = xbuffer;
38         int i, j;
39         unsigned nbits, nlimbs;
40         mpi_limb_t a;
41         MPI val = NULL;
42
43         while (nbytes > 0 && buffer[0] == 0) {
44                 buffer++;
45                 nbytes--;
46         }
47
48         nbits = nbytes * 8;
49         if (nbits > MAX_EXTERN_MPI_BITS) {
50                 pr_info("MPI: mpi too large (%u bits)\n", nbits);
51                 return NULL;
52         }
53         if (nbytes > 0)
54                 nbits -= count_leading_zeros(buffer[0]) - (BITS_PER_LONG - 8);
55
56         nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
57         val = mpi_alloc(nlimbs);
58         if (!val)
59                 return NULL;
60         val->nbits = nbits;
61         val->sign = 0;
62         val->nlimbs = nlimbs;
63
64         if (nbytes > 0) {
65                 i = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
66                 i %= BYTES_PER_MPI_LIMB;
67                 for (j = nlimbs; j > 0; j--) {
68                         a = 0;
69                         for (; i < BYTES_PER_MPI_LIMB; i++) {
70                                 a <<= 8;
71                                 a |= *buffer++;
72                         }
73                         i = 0;
74                         val->d[j - 1] = a;
75                 }
76         }
77         return val;
78 }
79 EXPORT_SYMBOL_GPL(mpi_read_raw_data);
80
81 MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread)
82 {
83         const uint8_t *buffer = xbuffer;
84         unsigned int nbits, nbytes;
85         MPI val;
86
87         if (*ret_nread < 2)
88                 return ERR_PTR(-EINVAL);
89         nbits = buffer[0] << 8 | buffer[1];
90
91         if (nbits > MAX_EXTERN_MPI_BITS) {
92                 pr_info("MPI: mpi too large (%u bits)\n", nbits);
93                 return ERR_PTR(-EINVAL);
94         }
95
96         nbytes = DIV_ROUND_UP(nbits, 8);
97         if (nbytes + 2 > *ret_nread) {
98                 pr_info("MPI: mpi larger than buffer nbytes=%u ret_nread=%u\n",
99                                 nbytes, *ret_nread);
100                 return ERR_PTR(-EINVAL);
101         }
102
103         val = mpi_read_raw_data(buffer + 2, nbytes);
104         if (!val)
105                 return ERR_PTR(-ENOMEM);
106
107         *ret_nread = nbytes + 2;
108         return val;
109 }
110 EXPORT_SYMBOL_GPL(mpi_read_from_buffer);
111
112 static int count_lzeros(MPI a)
113 {
114         mpi_limb_t alimb;
115         int i, lzeros = 0;
116
117         for (i = a->nlimbs - 1; i >= 0; i--) {
118                 alimb = a->d[i];
119                 if (alimb == 0) {
120                         lzeros += sizeof(mpi_limb_t);
121                 } else {
122                         lzeros += count_leading_zeros(alimb) / 8;
123                         break;
124                 }
125         }
126         return lzeros;
127 }
128
129 /**
130  * mpi_read_buffer() - read MPI to a bufer provided by user (msb first)
131  *
132  * @a:          a multi precision integer
133  * @buf:        bufer to which the output will be written to. Needs to be at
134  *              leaset mpi_get_size(a) long.
135  * @buf_len:    size of the buf.
136  * @nbytes:     receives the actual length of the data written on success and
137  *              the data to-be-written on -EOVERFLOW in case buf_len was too
138  *              small.
139  * @sign:       if not NULL, it will be set to the sign of a.
140  *
141  * Return:      0 on success or error code in case of error
142  */
143 int mpi_read_buffer(MPI a, uint8_t *buf, unsigned buf_len, unsigned *nbytes,
144                     int *sign)
145 {
146         uint8_t *p;
147 #if BYTES_PER_MPI_LIMB == 4
148         __be32 alimb;
149 #elif BYTES_PER_MPI_LIMB == 8
150         __be64 alimb;
151 #else
152 #error please implement for this limb size.
153 #endif
154         unsigned int n = mpi_get_size(a);
155         int i, lzeros;
156
157         if (!buf || !nbytes)
158                 return -EINVAL;
159
160         if (sign)
161                 *sign = a->sign;
162
163         lzeros = count_lzeros(a);
164
165         if (buf_len < n - lzeros) {
166                 *nbytes = n - lzeros;
167                 return -EOVERFLOW;
168         }
169
170         p = buf;
171         *nbytes = n - lzeros;
172
173         for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB,
174                         lzeros %= BYTES_PER_MPI_LIMB;
175                 i >= 0; i--) {
176 #if BYTES_PER_MPI_LIMB == 4
177                 alimb = cpu_to_be32(a->d[i]);
178 #elif BYTES_PER_MPI_LIMB == 8
179                 alimb = cpu_to_be64(a->d[i]);
180 #else
181 #error please implement for this limb size.
182 #endif
183                 memcpy(p, (u8 *)&alimb + lzeros, BYTES_PER_MPI_LIMB - lzeros);
184                 p += BYTES_PER_MPI_LIMB - lzeros;
185                 lzeros = 0;
186         }
187         return 0;
188 }
189 EXPORT_SYMBOL_GPL(mpi_read_buffer);
190
191 /*
192  * mpi_get_buffer() - Returns an allocated buffer with the MPI (msb first).
193  * Caller must free the return string.
194  * This function does return a 0 byte buffer with nbytes set to zero if the
195  * value of A is zero.
196  *
197  * @a:          a multi precision integer.
198  * @nbytes:     receives the length of this buffer.
199  * @sign:       if not NULL, it will be set to the sign of the a.
200  *
201  * Return:      Pointer to MPI buffer or NULL on error
202  */
203 void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign)
204 {
205         uint8_t *buf;
206         unsigned int n;
207         int ret;
208
209         if (!nbytes)
210                 return NULL;
211
212         n = mpi_get_size(a);
213
214         if (!n)
215                 n++;
216
217         buf = kmalloc(n, GFP_KERNEL);
218
219         if (!buf)
220                 return NULL;
221
222         ret = mpi_read_buffer(a, buf, n, nbytes, sign);
223
224         if (ret) {
225                 kfree(buf);
226                 return NULL;
227         }
228         return buf;
229 }
230 EXPORT_SYMBOL_GPL(mpi_get_buffer);
231
232 /**
233  * mpi_write_to_sgl() - Funnction exports MPI to an sgl (msb first)
234  *
235  * This function works in the same way as the mpi_read_buffer, but it
236  * takes an sgl instead of u8 * buf.
237  *
238  * @a:          a multi precision integer
239  * @sgl:        scatterlist to write to. Needs to be at least
240  *              mpi_get_size(a) long.
241  * @nbytes:     the number of bytes to write.  Leading bytes will be
242  *              filled with zero.
243  * @sign:       if not NULL, it will be set to the sign of a.
244  *
245  * Return:      0 on success or error code in case of error
246  */
247 int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes,
248                      int *sign)
249 {
250         u8 *p, *p2;
251 #if BYTES_PER_MPI_LIMB == 4
252         __be32 alimb;
253 #elif BYTES_PER_MPI_LIMB == 8
254         __be64 alimb;
255 #else
256 #error please implement for this limb size.
257 #endif
258         unsigned int n = mpi_get_size(a);
259         struct sg_mapping_iter miter;
260         int i, x, buf_len;
261         int nents;
262
263         if (sign)
264                 *sign = a->sign;
265
266         if (nbytes < n)
267                 return -EOVERFLOW;
268
269         nents = sg_nents_for_len(sgl, nbytes);
270         if (nents < 0)
271                 return -EINVAL;
272
273         sg_miter_start(&miter, sgl, nents, SG_MITER_ATOMIC | SG_MITER_TO_SG);
274         sg_miter_next(&miter);
275         buf_len = miter.length;
276         p2 = miter.addr;
277
278         while (nbytes > n) {
279                 i = min_t(unsigned, nbytes - n, buf_len);
280                 memset(p2, 0, i);
281                 p2 += i;
282                 nbytes -= i;
283
284                 buf_len -= i;
285                 if (!buf_len) {
286                         sg_miter_next(&miter);
287                         buf_len = miter.length;
288                         p2 = miter.addr;
289                 }
290         }
291
292         for (i = a->nlimbs - 1; i >= 0; i--) {
293 #if BYTES_PER_MPI_LIMB == 4
294                 alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0;
295 #elif BYTES_PER_MPI_LIMB == 8
296                 alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0;
297 #else
298 #error please implement for this limb size.
299 #endif
300                 p = (u8 *)&alimb;
301
302                 for (x = 0; x < sizeof(alimb); x++) {
303                         *p2++ = *p++;
304                         if (!--buf_len) {
305                                 sg_miter_next(&miter);
306                                 buf_len = miter.length;
307                                 p2 = miter.addr;
308                         }
309                 }
310         }
311
312         sg_miter_stop(&miter);
313         return 0;
314 }
315 EXPORT_SYMBOL_GPL(mpi_write_to_sgl);
316
317 /*
318  * mpi_read_raw_from_sgl() - Function allocates an MPI and populates it with
319  *                           data from the sgl
320  *
321  * This function works in the same way as the mpi_read_raw_data, but it
322  * takes an sgl instead of void * buffer. i.e. it allocates
323  * a new MPI and reads the content of the sgl to the MPI.
324  *
325  * @sgl:        scatterlist to read from
326  * @nbytes:     number of bytes to read
327  *
328  * Return:      Pointer to a new MPI or NULL on error
329  */
330 MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
331 {
332         struct sg_mapping_iter miter;
333         unsigned int nbits, nlimbs;
334         int x, j, z, lzeros, ents;
335         unsigned int len;
336         const u8 *buff;
337         mpi_limb_t a;
338         MPI val = NULL;
339
340         ents = sg_nents_for_len(sgl, nbytes);
341         if (ents < 0)
342                 return NULL;
343
344         sg_miter_start(&miter, sgl, ents, SG_MITER_ATOMIC | SG_MITER_FROM_SG);
345
346         lzeros = 0;
347         len = 0;
348         while (nbytes > 0) {
349                 while (len && !*buff) {
350                         lzeros++;
351                         len--;
352                         buff++;
353                 }
354
355                 if (len && *buff)
356                         break;
357
358                 sg_miter_next(&miter);
359                 buff = miter.addr;
360                 len = miter.length;
361
362                 nbytes -= lzeros;
363                 lzeros = 0;
364         }
365
366         miter.consumed = lzeros;
367         sg_miter_stop(&miter);
368
369         nbytes -= lzeros;
370         nbits = nbytes * 8;
371         if (nbits > MAX_EXTERN_MPI_BITS) {
372                 pr_info("MPI: mpi too large (%u bits)\n", nbits);
373                 return NULL;
374         }
375
376         if (nbytes > 0)
377                 nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8);
378
379         nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
380         val = mpi_alloc(nlimbs);
381         if (!val)
382                 return NULL;
383
384         val->nbits = nbits;
385         val->sign = 0;
386         val->nlimbs = nlimbs;
387
388         if (nbytes == 0)
389                 return val;
390
391         j = nlimbs - 1;
392         a = 0;
393         z = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
394         z %= BYTES_PER_MPI_LIMB;
395
396         while (sg_miter_next(&miter)) {
397                 buff = miter.addr;
398                 len = miter.length;
399
400                 for (x = 0; x < len; x++) {
401                         a <<= 8;
402                         a |= *buff++;
403                         if (((z + x + 1) % BYTES_PER_MPI_LIMB) == 0) {
404                                 val->d[j--] = a;
405                                 a = 0;
406                         }
407                 }
408                 z += x;
409         }
410
411         return val;
412 }
413 EXPORT_SYMBOL_GPL(mpi_read_raw_from_sgl);