lib/mpi: Fix SG miter leak
[cascardo/linux.git] / lib / mpi / mpicoder.c
index 747606f..5a0f75a 100644 (file)
@@ -21,6 +21,7 @@
 #include <linux/bitops.h>
 #include <linux/count_zeros.h>
 #include <linux/byteorder/generic.h>
+#include <linux/scatterlist.h>
 #include <linux/string.h>
 #include "mpi-internal.h"
 
@@ -50,9 +51,7 @@ MPI mpi_read_raw_data(const void *xbuffer, size_t nbytes)
                return NULL;
        }
        if (nbytes > 0)
-               nbits -= count_leading_zeros(buffer[0]);
-       else
-               nbits = 0;
+               nbits -= count_leading_zeros(buffer[0]) - (BITS_PER_LONG - 8);
 
        nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
        val = mpi_alloc(nlimbs);
@@ -82,50 +81,30 @@ EXPORT_SYMBOL_GPL(mpi_read_raw_data);
 MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread)
 {
        const uint8_t *buffer = xbuffer;
-       int i, j;
-       unsigned nbits, nbytes, nlimbs, nread = 0;
-       mpi_limb_t a;
-       MPI val = NULL;
+       unsigned int nbits, nbytes;
+       MPI val;
 
        if (*ret_nread < 2)
-               goto leave;
+               return ERR_PTR(-EINVAL);
        nbits = buffer[0] << 8 | buffer[1];
 
        if (nbits > MAX_EXTERN_MPI_BITS) {
                pr_info("MPI: mpi too large (%u bits)\n", nbits);
-               goto leave;
+               return ERR_PTR(-EINVAL);
        }
-       buffer += 2;
-       nread = 2;
 
        nbytes = DIV_ROUND_UP(nbits, 8);
-       nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
-       val = mpi_alloc(nlimbs);
-       if (!val)
-               return NULL;
-       i = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
-       i %= BYTES_PER_MPI_LIMB;
-       val->nbits = nbits;
-       j = val->nlimbs = nlimbs;
-       val->sign = 0;
-       for (; j > 0; j--) {
-               a = 0;
-               for (; i < BYTES_PER_MPI_LIMB; i++) {
-                       if (++nread > *ret_nread) {
-                               printk
-                                   ("MPI: mpi larger than buffer nread=%d ret_nread=%d\n",
-                                    nread, *ret_nread);
-                               goto leave;
-                       }
-                       a <<= 8;
-                       a |= *buffer++;
-               }
-               i = 0;
-               val->d[j - 1] = a;
+       if (nbytes + 2 > *ret_nread) {
+               pr_info("MPI: mpi larger than buffer nbytes=%u ret_nread=%u\n",
+                               nbytes, *ret_nread);
+               return ERR_PTR(-EINVAL);
        }
 
-leave:
-       *ret_nread = nread;
+       val = mpi_read_raw_data(buffer + 2, nbytes);
+       if (!val)
+               return ERR_PTR(-ENOMEM);
+
+       *ret_nread = nbytes + 2;
        return val;
 }
 EXPORT_SYMBOL_GPL(mpi_read_from_buffer);
@@ -250,82 +229,6 @@ void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign)
 }
 EXPORT_SYMBOL_GPL(mpi_get_buffer);
 
-/****************
- * Use BUFFER to update MPI.
- */
-int mpi_set_buffer(MPI a, const void *xbuffer, unsigned nbytes, int sign)
-{
-       const uint8_t *buffer = xbuffer, *p;
-       mpi_limb_t alimb;
-       int nlimbs;
-       int i;
-
-       nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
-       if (RESIZE_IF_NEEDED(a, nlimbs) < 0)
-               return -ENOMEM;
-       a->sign = sign;
-
-       for (i = 0, p = buffer + nbytes - 1; p >= buffer + BYTES_PER_MPI_LIMB;) {
-#if BYTES_PER_MPI_LIMB == 4
-               alimb = (mpi_limb_t) *p--;
-               alimb |= (mpi_limb_t) *p-- << 8;
-               alimb |= (mpi_limb_t) *p-- << 16;
-               alimb |= (mpi_limb_t) *p-- << 24;
-#elif BYTES_PER_MPI_LIMB == 8
-               alimb = (mpi_limb_t) *p--;
-               alimb |= (mpi_limb_t) *p-- << 8;
-               alimb |= (mpi_limb_t) *p-- << 16;
-               alimb |= (mpi_limb_t) *p-- << 24;
-               alimb |= (mpi_limb_t) *p-- << 32;
-               alimb |= (mpi_limb_t) *p-- << 40;
-               alimb |= (mpi_limb_t) *p-- << 48;
-               alimb |= (mpi_limb_t) *p-- << 56;
-#else
-#error please implement for this limb size.
-#endif
-               a->d[i++] = alimb;
-       }
-       if (p >= buffer) {
-#if BYTES_PER_MPI_LIMB == 4
-               alimb = *p--;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 8;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 16;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 24;
-#elif BYTES_PER_MPI_LIMB == 8
-               alimb = (mpi_limb_t) *p--;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 8;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 16;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 24;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 32;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 40;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 48;
-               if (p >= buffer)
-                       alimb |= (mpi_limb_t) *p-- << 56;
-#else
-#error please implement for this limb size.
-#endif
-               a->d[i++] = alimb;
-       }
-       a->nlimbs = i;
-
-       if (i != nlimbs) {
-               pr_emerg("MPI: mpi_set_buffer: Assertion failed (%d != %d)", i,
-                      nlimbs);
-               BUG();
-       }
-       return 0;
-}
-EXPORT_SYMBOL_GPL(mpi_set_buffer);
-
 /**
  * mpi_write_to_sgl() - Funnction exports MPI to an sgl (msb first)
  *
@@ -335,16 +238,13 @@ EXPORT_SYMBOL_GPL(mpi_set_buffer);
  * @a:         a multi precision integer
  * @sgl:       scatterlist to write to. Needs to be at least
  *             mpi_get_size(a) long.
- * @nbytes:    in/out param - it has the be set to the maximum number of
- *             bytes that can be written to sgl. This has to be at least
- *             the size of the integer a. On return it receives the actual
- *             length of the data written on success or the data that would
- *             be written if buffer was too small.
+ * @nbytes:    the number of bytes to write.  Leading bytes will be
+ *             filled with zero.
  * @sign:      if not NULL, it will be set to the sign of a.
  *
  * Return:     0 on success or error code in case of error
  */
-int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes,
+int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes,
                     int *sign)
 {
        u8 *p, *p2;
@@ -356,55 +256,60 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes,
 #error please implement for this limb size.
 #endif
        unsigned int n = mpi_get_size(a);
-       int i, x, y = 0, lzeros, buf_len;
-
-       if (!nbytes)
-               return -EINVAL;
+       struct sg_mapping_iter miter;
+       int i, x, buf_len;
+       int nents;
 
        if (sign)
                *sign = a->sign;
 
-       lzeros = count_lzeros(a);
-
-       if (*nbytes < n - lzeros) {
-               *nbytes = n - lzeros;
+       if (nbytes < n)
                return -EOVERFLOW;
-       }
 
-       *nbytes = n - lzeros;
-       buf_len = sgl->length;
-       p2 = sg_virt(sgl);
+       nents = sg_nents_for_len(sgl, nbytes);
+       if (nents < 0)
+               return -EINVAL;
 
-       for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB,
-                       lzeros %= BYTES_PER_MPI_LIMB;
-               i >= 0; i--) {
+       sg_miter_start(&miter, sgl, nents, SG_MITER_ATOMIC | SG_MITER_TO_SG);
+       sg_miter_next(&miter);
+       buf_len = miter.length;
+       p2 = miter.addr;
+
+       while (nbytes > n) {
+               i = min_t(unsigned, nbytes - n, buf_len);
+               memset(p2, 0, i);
+               p2 += i;
+               nbytes -= i;
+
+               buf_len -= i;
+               if (!buf_len) {
+                       sg_miter_next(&miter);
+                       buf_len = miter.length;
+                       p2 = miter.addr;
+               }
+       }
+
+       for (i = a->nlimbs - 1; i >= 0; i--) {
 #if BYTES_PER_MPI_LIMB == 4
-               alimb = cpu_to_be32(a->d[i]);
+               alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0;
 #elif BYTES_PER_MPI_LIMB == 8
-               alimb = cpu_to_be64(a->d[i]);
+               alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0;
 #else
 #error please implement for this limb size.
 #endif
-               if (lzeros) {
-                       y = lzeros;
-                       lzeros = 0;
-               }
-
-               p = (u8 *)&alimb + y;
+               p = (u8 *)&alimb;
 
-               for (x = 0; x < sizeof(alimb) - y; x++) {
-                       if (!buf_len) {
-                               sgl = sg_next(sgl);
-                               if (!sgl)
-                                       return -EINVAL;
-                               buf_len = sgl->length;
-                               p2 = sg_virt(sgl);
-                       }
+               for (x = 0; x < sizeof(alimb); x++) {
                        *p2++ = *p++;
-                       buf_len--;
+                       if (!--buf_len) {
+                               sg_miter_next(&miter);
+                               buf_len = miter.length;
+                               p2 = miter.addr;
+                       }
                }
-               y = 0;
        }
+
+       sg_miter_stop(&miter);
        return 0;
 }
 EXPORT_SYMBOL_GPL(mpi_write_to_sgl);
@@ -424,19 +329,23 @@ EXPORT_SYMBOL_GPL(mpi_write_to_sgl);
  */
 MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
 {
-       struct scatterlist *sg;
-       int x, i, j, z, lzeros, ents;
+       struct sg_mapping_iter miter;
        unsigned int nbits, nlimbs;
+       int x, j, z, lzeros, ents;
+       unsigned int len;
+       const u8 *buff;
        mpi_limb_t a;
        MPI val = NULL;
 
-       lzeros = 0;
-       ents = sg_nents(sgl);
+       ents = sg_nents_for_len(sgl, nbytes);
+       if (ents < 0)
+               return NULL;
 
-       for_each_sg(sgl, sg, ents, i) {
-               const u8 *buff = sg_virt(sg);
-               int len = sg->length;
+       sg_miter_start(&miter, sgl, ents, SG_MITER_ATOMIC | SG_MITER_FROM_SG);
 
+       lzeros = 0;
+       len = 0;
+       while (nbytes > 0) {
                while (len && !*buff) {
                        lzeros++;
                        len--;
@@ -446,12 +355,17 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
                if (len && *buff)
                        break;
 
-               ents--;
+               sg_miter_next(&miter);
+               buff = miter.addr;
+               len = miter.length;
+
                nbytes -= lzeros;
                lzeros = 0;
        }
 
-       sgl = sg;
+       miter.consumed = lzeros;
+       sg_miter_stop(&miter);
+
        nbytes -= lzeros;
        nbits = nbytes * 8;
        if (nbits > MAX_EXTERN_MPI_BITS) {
@@ -460,8 +374,7 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
        }
 
        if (nbytes > 0)
-               nbits -= count_leading_zeros(*(u8 *)(sg_virt(sgl) + lzeros)) -
-                       (BITS_PER_LONG - 8);
+               nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8);
 
        nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
        val = mpi_alloc(nlimbs);
@@ -480,21 +393,21 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
        z = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
        z %= BYTES_PER_MPI_LIMB;
 
-       for_each_sg(sgl, sg, ents, i) {
-               const u8 *buffer = sg_virt(sg) + lzeros;
-               int len = sg->length - lzeros;
+       while (sg_miter_next(&miter)) {
+               buff = miter.addr;
+               len = miter.length;
 
                for (x = 0; x < len; x++) {
                        a <<= 8;
-                       a |= *buffer++;
+                       a |= *buff++;
                        if (((z + x + 1) % BYTES_PER_MPI_LIMB) == 0) {
                                val->d[j--] = a;
                                a = 0;
                        }
                }
                z += x;
-               lzeros = 0;
        }
+
        return val;
 }
 EXPORT_SYMBOL_GPL(mpi_read_raw_from_sgl);