Merge commit '2c563880ea' into work.xattr
[cascardo/linux.git] / fs / afs / rxrpc.c
index 4832de8..14d04c8 100644 (file)
@@ -150,10 +150,9 @@ void afs_close_socket(void)
 }
 
 /*
- * note that the data in a socket buffer is now delivered and that the buffer
- * should be freed
+ * Note that the data in a socket buffer is now consumed.
  */
-static void afs_data_delivered(struct sk_buff *skb)
+void afs_data_consumed(struct afs_call *call, struct sk_buff *skb)
 {
        if (!skb) {
                _debug("DLVR NULL [%d]", atomic_read(&afs_outstanding_skbs));
@@ -161,9 +160,7 @@ static void afs_data_delivered(struct sk_buff *skb)
        } else {
                _debug("DLVR %p{%u} [%d]",
                       skb, skb->mark, atomic_read(&afs_outstanding_skbs));
-               if (atomic_dec_return(&afs_outstanding_skbs) == -1)
-                       BUG();
-               rxrpc_kernel_data_delivered(skb);
+               rxrpc_kernel_data_consumed(call->rxcall, skb);
        }
 }
 
@@ -489,9 +486,15 @@ static void afs_deliver_to_call(struct afs_call *call)
                        last = rxrpc_kernel_is_data_last(skb);
                        ret = call->type->deliver(call, skb, last);
                        switch (ret) {
+                       case -EAGAIN:
+                               if (last) {
+                                       _debug("short data");
+                                       goto unmarshal_error;
+                               }
+                               break;
                        case 0:
-                               if (last &&
-                                   call->state == AFS_CALL_AWAIT_REPLY)
+                               ASSERT(last);
+                               if (call->state == AFS_CALL_AWAIT_REPLY)
                                        call->state = AFS_CALL_COMPLETE;
                                break;
                        case -ENOTCONN:
@@ -501,6 +504,7 @@ static void afs_deliver_to_call(struct afs_call *call)
                                abort_code = RX_INVALID_OPERATION;
                                goto do_abort;
                        default:
+                       unmarshal_error:
                                abort_code = RXGEN_CC_UNMARSHAL;
                                if (call->state != AFS_CALL_AWAIT_REPLY)
                                        abort_code = RXGEN_SS_UNMARSHAL;
@@ -511,9 +515,7 @@ static void afs_deliver_to_call(struct afs_call *call)
                                call->state = AFS_CALL_ERROR;
                                break;
                        }
-                       afs_data_delivered(skb);
-                       skb = NULL;
-                       continue;
+                       break;
                case RXRPC_SKB_MARK_FINAL_ACK:
                        _debug("Rcv ACK");
                        call->state = AFS_CALL_COMPLETE;
@@ -685,15 +687,35 @@ static void afs_process_async_call(struct afs_call *call)
 }
 
 /*
- * empty a socket buffer into a flat reply buffer
+ * Empty a socket buffer into a flat reply buffer.
  */
-void afs_transfer_reply(struct afs_call *call, struct sk_buff *skb)
+int afs_transfer_reply(struct afs_call *call, struct sk_buff *skb, bool last)
 {
        size_t len = skb->len;
 
-       if (skb_copy_bits(skb, 0, call->buffer + call->reply_size, len) < 0)
-               BUG();
-       call->reply_size += len;
+       if (len > call->reply_max - call->reply_size) {
+               _leave(" = -EBADMSG [%zu > %u]",
+                      len, call->reply_max - call->reply_size);
+               return -EBADMSG;
+       }
+
+       if (len > 0) {
+               if (skb_copy_bits(skb, 0, call->buffer + call->reply_size,
+                                 len) < 0)
+                       BUG();
+               call->reply_size += len;
+       }
+
+       afs_data_consumed(call, skb);
+       if (!last)
+               return -EAGAIN;
+
+       if (call->reply_size != call->reply_max) {
+               _leave(" = -EBADMSG [%u != %u]",
+                      call->reply_size, call->reply_max);
+               return -EBADMSG;
+       }
+       return 0;
 }
 
 /*
@@ -745,7 +767,8 @@ static void afs_collect_incoming_call(struct work_struct *work)
 }
 
 /*
- * grab the operation ID from an incoming cache manager call
+ * Grab the operation ID from an incoming cache manager call.  The socket
+ * buffer is discarded on error or if we don't yet have sufficient data.
  */
 static int afs_deliver_cm_op_id(struct afs_call *call, struct sk_buff *skb,
                                bool last)
@@ -766,12 +789,9 @@ static int afs_deliver_cm_op_id(struct afs_call *call, struct sk_buff *skb,
        call->offset += len;
 
        if (call->offset < 4) {
-               if (last) {
-                       _leave(" = -EBADMSG [op ID short]");
-                       return -EBADMSG;
-               }
-               _leave(" = 0 [incomplete]");
-               return 0;
+               afs_data_consumed(call, skb);
+               _leave(" = -EAGAIN");
+               return -EAGAIN;
        }
 
        call->state = AFS_CALL_AWAIT_REQUEST;
@@ -855,7 +875,7 @@ void afs_send_simple_reply(struct afs_call *call, const void *buf, size_t len)
 }
 
 /*
- * extract a piece of data from the received data socket buffers
+ * Extract a piece of data from the received data socket buffers.
  */
 int afs_extract_data(struct afs_call *call, struct sk_buff *skb,
                     bool last, void *buf, size_t count)
@@ -873,10 +893,7 @@ int afs_extract_data(struct afs_call *call, struct sk_buff *skb,
        call->offset += len;
 
        if (call->offset < count) {
-               if (last) {
-                       _leave(" = -EBADMSG [%d < %zu]", call->offset, count);
-                       return -EBADMSG;
-               }
+               afs_data_consumed(call, skb);
                _leave(" = -EAGAIN");
                return -EAGAIN;
        }