net/hyperv: Fix the page buffer when an RNDIS message goes beyond page boundary
[cascardo/linux.git] / drivers / net / hyperv / rndis_filter.c
index da181f9..133b7fb 100644 (file)
@@ -321,6 +321,25 @@ static void rndis_filter_receive_data(struct rndis_device *dev,
        data_offset = RNDIS_HEADER_SIZE + rndis_pkt->data_offset;
 
        pkt->total_data_buflen -= data_offset;
+
+       /*
+        * Make sure we got a valid RNDIS message, now total_data_buflen
+        * should be the data packet size plus the trailer padding size
+        */
+       if (pkt->total_data_buflen < rndis_pkt->data_len) {
+               netdev_err(dev->net_dev->ndev, "rndis message buffer "
+                          "overflow detected (got %u, min %u)"
+                          "...dropping this message!\n",
+                          pkt->total_data_buflen, rndis_pkt->data_len);
+               return;
+       }
+
+       /*
+        * Remove the rndis trailer padding from rndis packet message
+        * rndis_pkt->data_len tell us the real data length, we only copy
+        * the data packet to the stack, without the rndis trailer padding
+        */
+       pkt->total_data_buflen = rndis_pkt->data_len;
        pkt->data = (void *)((unsigned long)pkt->data + data_offset);
 
        pkt->is_data_pkt = true;
@@ -778,6 +797,19 @@ int rndis_filter_send(struct hv_device *dev,
                        (unsigned long)rndisMessage & (PAGE_SIZE-1);
        pkt->page_buf[0].len = rndisMessageSize;
 
+       /* Add one page_buf if the rndis msg goes beyond page boundary */
+       if (pkt->page_buf[0].offset + rndisMessageSize > PAGE_SIZE) {
+               int i;
+               for (i = pkt->page_buf_cnt; i > 1; i--)
+                       pkt->page_buf[i] = pkt->page_buf[i-1];
+               pkt->page_buf_cnt++;
+               pkt->page_buf[0].len = PAGE_SIZE - pkt->page_buf[0].offset;
+               pkt->page_buf[1].pfn = virt_to_phys((void *)((ulong)
+                       rndisMessage + pkt->page_buf[0].len)) >> PAGE_SHIFT;
+               pkt->page_buf[1].offset = 0;
+               pkt->page_buf[1].len = rndisMessageSize - pkt->page_buf[0].len;
+       }
+
        /* Save the packet send completion and context */
        filterPacket->completion = pkt->completion.send.send_completion;
        filterPacket->completion_ctx =