crypto: rsa - Generate fixed-length output
[cascardo/linux.git] / lib / mpi / mpicoder.c
index 823cf5f..7150e5c 100644 (file)
@@ -237,16 +237,13 @@ EXPORT_SYMBOL_GPL(mpi_get_buffer);
  * @a:         a multi precision integer
  * @sgl:       scatterlist to write to. Needs to be at least
  *             mpi_get_size(a) long.
- * @nbytes:    in/out param - it has the be set to the maximum number of
- *             bytes that can be written to sgl. This has to be at least
- *             the size of the integer a. On return it receives the actual
- *             length of the data written on success or the data that would
- *             be written if buffer was too small.
+ * @nbytes:    the number of bytes to write.  Leading bytes will be
+ *             filled with zero.
  * @sign:      if not NULL, it will be set to the sign of a.
  *
  * Return:     0 on success or error code in case of error
  */
-int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes,
+int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes,
                     int *sign)
 {
        u8 *p, *p2;
@@ -258,43 +255,44 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes,
 #error please implement for this limb size.
 #endif
        unsigned int n = mpi_get_size(a);
-       int i, x, y = 0, lzeros, buf_len;
-
-       if (!nbytes)
-               return -EINVAL;
+       int i, x, buf_len;
 
        if (sign)
                *sign = a->sign;
 
-       lzeros = count_lzeros(a);
-
-       if (*nbytes < n - lzeros) {
-               *nbytes = n - lzeros;
+       if (nbytes < n)
                return -EOVERFLOW;
-       }
 
-       *nbytes = n - lzeros;
        buf_len = sgl->length;
        p2 = sg_virt(sgl);
 
-       for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB,
-                       lzeros %= BYTES_PER_MPI_LIMB;
-               i >= 0; i--) {
+       while (nbytes > n) {
+               if (!buf_len) {
+                       sgl = sg_next(sgl);
+                       if (!sgl)
+                               return -EINVAL;
+                       buf_len = sgl->length;
+                       p2 = sg_virt(sgl);
+               }
+
+               i = min_t(unsigned, nbytes - n, buf_len);
+               memset(p2, 0, i);
+               p2 += i;
+               buf_len -= i;
+               nbytes -= i;
+       }
+
+       for (i = a->nlimbs - 1; i >= 0; i--) {
 #if BYTES_PER_MPI_LIMB == 4
-               alimb = cpu_to_be32(a->d[i]);
+               alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0;
 #elif BYTES_PER_MPI_LIMB == 8
-               alimb = cpu_to_be64(a->d[i]);
+               alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0;
 #else
 #error please implement for this limb size.
 #endif
-               if (lzeros) {
-                       y = lzeros;
-                       lzeros = 0;
-               }
-
-               p = (u8 *)&alimb + y;
+               p = (u8 *)&alimb;
 
-               for (x = 0; x < sizeof(alimb) - y; x++) {
+               for (x = 0; x < sizeof(alimb); x++) {
                        if (!buf_len) {
                                sgl = sg_next(sgl);
                                if (!sgl)
@@ -305,7 +303,6 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes,
                        *p2++ = *p++;
                        buf_len--;
                }
-               y = 0;
        }
        return 0;
 }