Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[cascardo/linux.git] / net / rxrpc / rxkad.c
index 63afa9e..4374e7b 100644 (file)
@@ -80,12 +80,10 @@ static int rxkad_init_connection_security(struct rxrpc_connection *conn)
        case RXRPC_SECURITY_AUTH:
                conn->size_align = 8;
                conn->security_size = sizeof(struct rxkad_level1_hdr);
-               conn->header_size += sizeof(struct rxkad_level1_hdr);
                break;
        case RXRPC_SECURITY_ENCRYPT:
                conn->size_align = 8;
                conn->security_size = sizeof(struct rxkad_level2_hdr);
-               conn->header_size += sizeof(struct rxkad_level2_hdr);
                break;
        default:
                ret = -EKEYREJECTED;
@@ -161,7 +159,7 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
 
        _enter("");
 
-       check = sp->hdr.seq ^ sp->hdr.callNumber;
+       check = sp->hdr.seq ^ call->call_id;
        data_size |= (u32)check << 16;
 
        hdr.data_size = htonl(data_size);
@@ -205,7 +203,7 @@ static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call,
 
        _enter("");
 
-       check = sp->hdr.seq ^ sp->hdr.callNumber;
+       check = sp->hdr.seq ^ call->call_id;
 
        rxkhdr.data_size = htonl(data_size | (u32)check << 16);
        rxkhdr.checksum = 0;
@@ -275,9 +273,9 @@ static int rxkad_secure_packet(struct rxrpc_call *call,
        memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
 
        /* calculate the security checksum */
-       x = call->channel << (32 - RXRPC_CIDSHIFT);
+       x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
        x |= sp->hdr.seq & 0x3fffffff;
-       call->crypto_buf[0] = htonl(sp->hdr.callNumber);
+       call->crypto_buf[0] = htonl(call->call_id);
        call->crypto_buf[1] = htonl(x);
 
        sg_init_one(&sg, call->crypto_buf, 8);
@@ -316,12 +314,11 @@ static int rxkad_secure_packet(struct rxrpc_call *call,
 /*
  * decrypt partial encryption on a packet (level 1 security)
  */
-static int rxkad_verify_packet_auth(const struct rxrpc_call *call,
-                                   struct sk_buff *skb,
-                                   u32 *_abort_code)
+static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
+                                unsigned int offset, unsigned int len,
+                                rxrpc_seq_t seq)
 {
        struct rxkad_level1_hdr sechdr;
-       struct rxrpc_skb_priv *sp;
        SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
        struct rxrpc_crypt iv;
        struct scatterlist sg[16];
@@ -332,15 +329,20 @@ static int rxkad_verify_packet_auth(const struct rxrpc_call *call,
 
        _enter("");
 
-       sp = rxrpc_skb(skb);
+       if (len < 8) {
+               rxrpc_abort_call("V1H", call, seq, RXKADSEALEDINCON, EPROTO);
+               goto protocol_error;
+       }
 
-       /* we want to decrypt the skbuff in-place */
+       /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
+        * directly into the target buffer.
+        */
        nsg = skb_cow_data(skb, 0, &trailer);
        if (nsg < 0 || nsg > 16)
                goto nomem;
 
        sg_init_table(sg, nsg);
-       skb_to_sgvec(skb, sg, 0, 8);
+       skb_to_sgvec(skb, sg, offset, 8);
 
        /* start the decryption afresh */
        memset(&iv, 0, sizeof(iv));
@@ -351,35 +353,35 @@ static int rxkad_verify_packet_auth(const struct rxrpc_call *call,
        crypto_skcipher_decrypt(req);
        skcipher_request_zero(req);
 
-       /* remove the decrypted packet length */
-       if (skb_copy_bits(skb, 0, &sechdr, sizeof(sechdr)) < 0)
-               goto datalen_error;
-       if (!skb_pull(skb, sizeof(sechdr)))
-               BUG();
+       /* Extract the decrypted packet length */
+       if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
+               rxrpc_abort_call("XV1", call, seq, RXKADDATALEN, EPROTO);
+               goto protocol_error;
+       }
+       offset += sizeof(sechdr);
+       len -= sizeof(sechdr);
 
        buf = ntohl(sechdr.data_size);
        data_size = buf & 0xffff;
 
        check = buf >> 16;
-       check ^= sp->hdr.seq ^ sp->hdr.callNumber;
+       check ^= seq ^ call->call_id;
        check &= 0xffff;
        if (check != 0) {
-               *_abort_code = RXKADSEALEDINCON;
+               rxrpc_abort_call("V1C", call, seq, RXKADSEALEDINCON, EPROTO);
                goto protocol_error;
        }
 
-       /* shorten the packet to remove the padding */
-       if (data_size > skb->len)
-               goto datalen_error;
-       else if (data_size < skb->len)
-               skb->len = data_size;
+       if (data_size > len) {
+               rxrpc_abort_call("V1L", call, seq, RXKADDATALEN, EPROTO);
+               goto protocol_error;
+       }
 
        _leave(" = 0 [dlen=%x]", data_size);
        return 0;
 
-datalen_error:
-       *_abort_code = RXKADDATALEN;
 protocol_error:
+       rxrpc_send_abort_packet(call);
        _leave(" = -EPROTO");
        return -EPROTO;
 
@@ -391,13 +393,12 @@ nomem:
 /*
  * wholly decrypt a packet (level 2 security)
  */
-static int rxkad_verify_packet_encrypt(const struct rxrpc_call *call,
-                                      struct sk_buff *skb,
-                                      u32 *_abort_code)
+static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
+                                unsigned int offset, unsigned int len,
+                                rxrpc_seq_t seq)
 {
        const struct rxrpc_key_token *token;
        struct rxkad_level2_hdr sechdr;
-       struct rxrpc_skb_priv *sp;
        SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
        struct rxrpc_crypt iv;
        struct scatterlist _sg[4], *sg;
@@ -408,9 +409,14 @@ static int rxkad_verify_packet_encrypt(const struct rxrpc_call *call,
 
        _enter(",{%d}", skb->len);
 
-       sp = rxrpc_skb(skb);
+       if (len < 8) {
+               rxrpc_abort_call("V2H", call, seq, RXKADSEALEDINCON, EPROTO);
+               goto protocol_error;
+       }
 
-       /* we want to decrypt the skbuff in-place */
+       /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
+        * directly into the target buffer.
+        */
        nsg = skb_cow_data(skb, 0, &trailer);
        if (nsg < 0)
                goto nomem;
@@ -423,7 +429,7 @@ static int rxkad_verify_packet_encrypt(const struct rxrpc_call *call,
        }
 
        sg_init_table(sg, nsg);
-       skb_to_sgvec(skb, sg, 0, skb->len);
+       skb_to_sgvec(skb, sg, offset, len);
 
        /* decrypt from the session key */
        token = call->conn->params.key->payload.data[0];
@@ -431,41 +437,41 @@ static int rxkad_verify_packet_encrypt(const struct rxrpc_call *call,
 
        skcipher_request_set_tfm(req, call->conn->cipher);
        skcipher_request_set_callback(req, 0, NULL, NULL);
-       skcipher_request_set_crypt(req, sg, sg, skb->len, iv.x);
+       skcipher_request_set_crypt(req, sg, sg, len, iv.x);
        crypto_skcipher_decrypt(req);
        skcipher_request_zero(req);
        if (sg != _sg)
                kfree(sg);
 
-       /* remove the decrypted packet length */
-       if (skb_copy_bits(skb, 0, &sechdr, sizeof(sechdr)) < 0)
-               goto datalen_error;
-       if (!skb_pull(skb, sizeof(sechdr)))
-               BUG();
+       /* Extract the decrypted packet length */
+       if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
+               rxrpc_abort_call("XV2", call, seq, RXKADDATALEN, EPROTO);
+               goto protocol_error;
+       }
+       offset += sizeof(sechdr);
+       len -= sizeof(sechdr);
 
        buf = ntohl(sechdr.data_size);
        data_size = buf & 0xffff;
 
        check = buf >> 16;
-       check ^= sp->hdr.seq ^ sp->hdr.callNumber;
+       check ^= seq ^ call->call_id;
        check &= 0xffff;
        if (check != 0) {
-               *_abort_code = RXKADSEALEDINCON;
+               rxrpc_abort_call("V2C", call, seq, RXKADSEALEDINCON, EPROTO);
                goto protocol_error;
        }
 
-       /* shorten the packet to remove the padding */
-       if (data_size > skb->len)
-               goto datalen_error;
-       else if (data_size < skb->len)
-               skb->len = data_size;
+       if (data_size > len) {
+               rxrpc_abort_call("V2L", call, seq, RXKADDATALEN, EPROTO);
+               goto protocol_error;
+       }
 
        _leave(" = 0 [dlen=%x]", data_size);
        return 0;
 
-datalen_error:
-       *_abort_code = RXKADDATALEN;
 protocol_error:
+       rxrpc_send_abort_packet(call);
        _leave(" = -EPROTO");
        return -EPROTO;
 
@@ -475,40 +481,31 @@ nomem:
 }
 
 /*
- * verify the security on a received packet
+ * Verify the security on a received packet or subpacket (if part of a
+ * jumbo packet).
  */
-static int rxkad_verify_packet(struct rxrpc_call *call,
-                              struct sk_buff *skb,
-                              u32 *_abort_code)
+static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
+                              unsigned int offset, unsigned int len,
+                              rxrpc_seq_t seq, u16 expected_cksum)
 {
        SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
-       struct rxrpc_skb_priv *sp;
        struct rxrpc_crypt iv;
        struct scatterlist sg;
        u16 cksum;
        u32 x, y;
-       int ret;
-
-       sp = rxrpc_skb(skb);
 
        _enter("{%d{%x}},{#%u}",
-              call->debug_id, key_serial(call->conn->params.key), sp->hdr.seq);
+              call->debug_id, key_serial(call->conn->params.key), seq);
 
        if (!call->conn->cipher)
                return 0;
 
-       if (sp->hdr.securityIndex != RXRPC_SECURITY_RXKAD) {
-               *_abort_code = RXKADINCONSISTENCY;
-               _leave(" = -EPROTO [not rxkad]");
-               return -EPROTO;
-       }
-
        /* continue encrypting from where we left off */
        memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
 
        /* validate the security checksum */
-       x = call->channel << (32 - RXRPC_CIDSHIFT);
-       x |= sp->hdr.seq & 0x3fffffff;
+       x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
+       x |= seq & 0x3fffffff;
        call->crypto_buf[0] = htonl(call->call_id);
        call->crypto_buf[1] = htonl(x);
 
@@ -524,29 +521,69 @@ static int rxkad_verify_packet(struct rxrpc_call *call,
        if (cksum == 0)
                cksum = 1; /* zero checksums are not permitted */
 
-       if (sp->hdr.cksum != cksum) {
-               *_abort_code = RXKADSEALEDINCON;
+       if (cksum != expected_cksum) {
+               rxrpc_abort_call("VCK", call, seq, RXKADSEALEDINCON, EPROTO);
+               rxrpc_send_abort_packet(call);
                _leave(" = -EPROTO [csum failed]");
                return -EPROTO;
        }
 
        switch (call->conn->params.security_level) {
        case RXRPC_SECURITY_PLAIN:
-               ret = 0;
-               break;
+               return 0;
        case RXRPC_SECURITY_AUTH:
-               ret = rxkad_verify_packet_auth(call, skb, _abort_code);
-               break;
+               return rxkad_verify_packet_1(call, skb, offset, len, seq);
        case RXRPC_SECURITY_ENCRYPT:
-               ret = rxkad_verify_packet_encrypt(call, skb, _abort_code);
-               break;
+               return rxkad_verify_packet_2(call, skb, offset, len, seq);
        default:
-               ret = -ENOANO;
-               break;
+               return -ENOANO;
        }
+}
 
-       _leave(" = %d", ret);
-       return ret;
+/*
+ * Locate the data contained in a packet that was partially encrypted.
+ */
+static void rxkad_locate_data_1(struct rxrpc_call *call, struct sk_buff *skb,
+                               unsigned int *_offset, unsigned int *_len)
+{
+       struct rxkad_level1_hdr sechdr;
+
+       if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
+               BUG();
+       *_offset += sizeof(sechdr);
+       *_len = ntohl(sechdr.data_size) & 0xffff;
+}
+
+/*
+ * Locate the data contained in a packet that was completely encrypted.
+ */
+static void rxkad_locate_data_2(struct rxrpc_call *call, struct sk_buff *skb,
+                               unsigned int *_offset, unsigned int *_len)
+{
+       struct rxkad_level2_hdr sechdr;
+
+       if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
+               BUG();
+       *_offset += sizeof(sechdr);
+       *_len = ntohl(sechdr.data_size) & 0xffff;
+}
+
+/*
+ * Locate the data contained in an already decrypted packet.
+ */
+static void rxkad_locate_data(struct rxrpc_call *call, struct sk_buff *skb,
+                             unsigned int *_offset, unsigned int *_len)
+{
+       switch (call->conn->params.security_level) {
+       case RXRPC_SECURITY_AUTH:
+               rxkad_locate_data_1(call, skb, _offset, _len);
+               return;
+       case RXRPC_SECURITY_ENCRYPT:
+               rxkad_locate_data_2(call, skb, _offset, _len);
+               return;
+       default:
+               return;
+       }
 }
 
 /*
@@ -716,7 +753,7 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
        struct rxkad_challenge challenge;
        struct rxkad_response resp
                __attribute__((aligned(8))); /* must be aligned for crypto */
-       struct rxrpc_skb_priv *sp;
+       struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
        u32 version, nonce, min_level, abort_code;
        int ret;
 
@@ -734,8 +771,8 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
        }
 
        abort_code = RXKADPACKETSHORT;
-       sp = rxrpc_skb(skb);
-       if (skb_copy_bits(skb, 0, &challenge, sizeof(challenge)) < 0)
+       if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
+                         &challenge, sizeof(challenge)) < 0)
                goto protocol_error;
 
        version = ntohl(challenge.version);
@@ -981,7 +1018,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
 {
        struct rxkad_response response
                __attribute__((aligned(8))); /* must be aligned for crypto */
-       struct rxrpc_skb_priv *sp;
+       struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
        struct rxrpc_crypt session_key;
        time_t expiry;
        void *ticket;
@@ -992,7 +1029,8 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
        _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
 
        abort_code = RXKADPACKETSHORT;
-       if (skb_copy_bits(skb, 0, &response, sizeof(response)) < 0)
+       if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
+                         &response, sizeof(response)) < 0)
                goto protocol_error;
        if (!pskb_pull(skb, sizeof(response)))
                BUG();
@@ -1000,7 +1038,6 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
        version = ntohl(response.version);
        ticket_len = ntohl(response.ticket_len);
        kvno = ntohl(response.kvno);
-       sp = rxrpc_skb(skb);
        _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }",
               sp->hdr.serial, version, kvno, ticket_len);
 
@@ -1022,7 +1059,8 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
                return -ENOMEM;
 
        abort_code = RXKADPACKETSHORT;
-       if (skb_copy_bits(skb, 0, ticket, ticket_len) < 0)
+       if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
+                         ticket, ticket_len) < 0)
                goto protocol_error_free;
 
        ret = rxkad_decrypt_ticket(conn, ticket, ticket_len, &session_key,
@@ -1147,6 +1185,7 @@ const struct rxrpc_security rxkad = {
        .prime_packet_security          = rxkad_prime_packet_security,
        .secure_packet                  = rxkad_secure_packet,
        .verify_packet                  = rxkad_verify_packet,
+       .locate_data                    = rxkad_locate_data,
        .issue_challenge                = rxkad_issue_challenge,
        .respond_to_challenge           = rxkad_respond_to_challenge,
        .verify_response                = rxkad_verify_response,