cascardo/linux.git: drivers/vhost/vsock.c
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>

#include <net/af_vsock.h>
#include "vhost.h"
#include "vsock.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID    2

static int vhost_transport_socket_init(struct vsock_sock *vsk,
                                       struct vsock_sock *psk);

enum {
        VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};

/* Used to track all the vhost_vsock instances on the system. */
static LIST_HEAD(vhost_vsock_list);
static DEFINE_MUTEX(vhost_vsock_mutex);

struct vhost_vsock_virtqueue {
        struct vhost_virtqueue vq;
};

struct vhost_vsock {
        /* Vhost device */
        struct vhost_dev dev;
        /* Vhost vsock virtqueues */
        struct vhost_vsock_virtqueue vqs[VSOCK_VQ_MAX];
        /* Link to global vhost_vsock_list */
        struct list_head list;
        /* Head for pkt from host to guest */
        struct list_head send_pkt_list;
        /* Work item to send pkt */
        struct vhost_work send_pkt_work;
        /* Wait queue for send pkt */
        wait_queue_head_t queue_wait;
        /* Used for global tx buf limitation */
        u32 total_tx_buf;
        /* Guest context id this vhost_vsock instance handles */
        u32 guest_cid;
};

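/* The host side always uses a fixed context id;
 * VHOST_VSOCK_DEFAULT_HOST_CID matches VMADDR_CID_HOST (2).
 */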
static u32 vhost_transport_get_local_cid(void)
{
        return VHOST_VSOCK_DEFAULT_HOST_CID;
}

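/* Look up the vhost_vsock instance serving the given guest context id.
 * Note that no reference is taken on the returned instance, so the
 * caller must synchronize against vhost_vsock_dev_release() by other
 * means.
 */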
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
        struct vhost_vsock *vsock;

        mutex_lock(&vhost_vsock_mutex);
        list_for_each_entry(vsock, &vhost_vsock_list, list) {
                if (vsock->guest_cid == guest_cid) {
                        mutex_unlock(&vhost_vsock_mutex);
                        return vsock;
                }
        }
        mutex_unlock(&vhost_vsock_mutex);

        return NULL;
}

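/* Drain the host-to-guest send_pkt_list into the guest's RX virtqueue:
 * for each queued packet grab a guest buffer, copy the header and
 * payload into it and mark the buffer used.  Takes the vq mutex itself;
 * called from both the RX kick handler and the send_pkt work item.
 * Waiters on the global tx buffer limit are woken once space is freed.
 */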
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                            struct vhost_virtqueue *vq)
{
        bool added = false;

        mutex_lock(&vq->mutex);
        vhost_disable_notify(&vsock->dev, vq);
        for (;;) {
                struct virtio_vsock_pkt *pkt;
                struct iov_iter iov_iter;
                unsigned out, in;
                struct sock *sk;
                size_t nbytes;
                size_t len;
                int head;

                if (list_empty(&vsock->send_pkt_list)) {
                        vhost_enable_notify(&vsock->dev, vq);
                        break;
                }

                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                pr_debug("%s: head = %d\n", __func__, head);
                if (head < 0)
                        break;

                if (head == vq->num) {
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);

                if (out) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Expected 0 output buffers, got %u\n", out);
                        break;
                }

                len = iov_length(&vq->iov[out], in);
                iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);

                nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
                if (nbytes != sizeof(pkt->hdr)) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt hdr\n");
                        break;
                }

                nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
                if (nbytes != pkt->len) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt buf\n");
                        break;
                }

                vhost_add_used(vq, head, pkt->len); /* TODO should this be sizeof(pkt->hdr) + pkt->len? */
                added = true;

                virtio_transport_dec_tx_pkt(pkt);
                vsock->total_tx_buf -= pkt->len;

                sk = sk_vsock(pkt->trans->vsk);
                /* Release refcnt taken in vhost_transport_send_pkt */
                sock_put(sk);

                virtio_transport_free_pkt(pkt);
        }
        if (added)
                vhost_signal(&vsock->dev, vq);
        mutex_unlock(&vq->mutex);

        if (added)
                wake_up(&vsock->queue_wait);
}

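/* Deferred work that flushes queued host-to-guest packets into the RX
 * virtqueue.
 */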
static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
        struct vhost_virtqueue *vq;
        struct vhost_vsock *vsock;

        vsock = container_of(work, struct vhost_vsock, send_pkt_work);
        vq = &vsock->vqs[VSOCK_VQ_RX].vq;

        vhost_transport_do_send_pkt(vsock, vq);
}

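/* Queue a packet for delivery to the guest.  The payload is capped by
 * the default RX buffer size and by the sender's available credit, so
 * fewer bytes than requested may be accepted.  Sleeps uninterruptibly
 * while the global tx buffer limit is exceeded.
 */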
static int
vhost_transport_send_pkt(struct vsock_sock *vsk,
                         struct virtio_vsock_pkt_info *info)
{
        u32 src_cid, src_port, dst_cid, dst_port;
        struct virtio_transport *trans;
        struct virtio_vsock_pkt *pkt;
        struct vhost_virtqueue *vq;
        struct vhost_vsock *vsock;
        u32 pkt_len = info->pkt_len;
        DEFINE_WAIT(wait);

        src_cid = vhost_transport_get_local_cid();
        src_port = vsk->local_addr.svm_port;
        if (!info->remote_cid) {
                dst_cid = vsk->remote_addr.svm_cid;
                dst_port = vsk->remote_addr.svm_port;
        } else {
                dst_cid = info->remote_cid;
                dst_port = info->remote_port;
        }

        /* Find the vhost_vsock according to guest context id */
        vsock = vhost_vsock_get(dst_cid);
        if (!vsock)
                return -ENODEV;

        trans = vsk->trans;
        vq = &vsock->vqs[VSOCK_VQ_RX].vq;

        /* we can send less than pkt_len bytes */
        if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
                pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

        /* virtio_transport_get_credit might return less than pkt_len credit */
        pkt_len = virtio_transport_get_credit(trans, pkt_len);

        /* Do not send zero length OP_RW pkt */
        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
                return pkt_len;

        /* Respect global tx buf limitation */
        mutex_lock(&vq->mutex);
        while (pkt_len + vsock->total_tx_buf > VIRTIO_VSOCK_MAX_TX_BUF_SIZE) {
                prepare_to_wait_exclusive(&vsock->queue_wait, &wait,
                                          TASK_UNINTERRUPTIBLE);
                mutex_unlock(&vq->mutex);
                schedule();
                mutex_lock(&vq->mutex);
                finish_wait(&vsock->queue_wait, &wait);
        }
        vsock->total_tx_buf += pkt_len;
        mutex_unlock(&vq->mutex);

        pkt = virtio_transport_alloc_pkt(vsk, info, pkt_len,
                                         src_cid, src_port,
                                         dst_cid, dst_port);
        if (!pkt) {
                mutex_lock(&vq->mutex);
                vsock->total_tx_buf -= pkt_len;
                mutex_unlock(&vq->mutex);
                virtio_transport_put_credit(trans, pkt_len);
                return -ENOMEM;
        }

        pr_debug("%s: pkt_len = %u\n", __func__, pkt_len);
        /* Released in vhost_transport_do_send_pkt */
        sock_hold(&trans->vsk->sk);
        virtio_transport_inc_tx_pkt(pkt);

        /* Queue it up in vhost work */
        mutex_lock(&vq->mutex);
        list_add_tail(&pkt->list, &vsock->send_pkt_list);
        vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
        mutex_unlock(&vq->mutex);

        return pkt_len;
}

static struct virtio_transport_pkt_ops vhost_ops = {
        .send_pkt = vhost_transport_send_pkt,
};

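/* Build a virtio_vsock_pkt from the guest buffers of a TX descriptor:
 * copy in the header, derive the payload length from it (masked to 16
 * bits for datagrams), then copy in the payload.  Returns NULL on
 * malformed input or allocation failure.
 */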
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
                      unsigned int out, unsigned int in)
{
        struct virtio_vsock_pkt *pkt;
        struct iov_iter iov_iter;
        size_t nbytes;
        size_t len;

        if (in != 0) {
                vq_err(vq, "Expected 0 input buffers, got %u\n", in);
                return NULL;
        }

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        len = iov_length(vq->iov, out);
        iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

        nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
        if (nbytes != sizeof(pkt->hdr)) {
                vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
                       sizeof(pkt->hdr), nbytes);
                kfree(pkt);
                return NULL;
        }

        if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_DGRAM)
                pkt->len = le32_to_cpu(pkt->hdr.len) & 0xFFFF;
        else if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
                pkt->len = le32_to_cpu(pkt->hdr.len);

        /* No payload */
        if (!pkt->len)
                return pkt;

        /* The pkt is too big */
        if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
                kfree(pkt);
                return NULL;
        }

        pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
        if (!pkt->buf) {
                kfree(pkt);
                return NULL;
        }

        nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
        if (nbytes != pkt->len) {
                vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
                       pkt->len, nbytes);
                virtio_transport_free_pkt(pkt);
                return NULL;
        }

        return pkt;
}

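/* The control virtqueue is not used yet; kicks are only logged. */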
static void vhost_vsock_handle_ctl_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);

        pr_debug("%s vq=%p, vsock=%p\n", __func__, vq, vsock);
}

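/* Guest-to-host path: pop each TX descriptor, rebuild the packet and
 * hand correctly addressed ones to the core virtio transport; anything
 * else is dropped.  The buffer is returned to the guest either way.
 */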
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);
        struct virtio_vsock_pkt *pkt;
        int head;
        unsigned int out, in;
        bool added = false;
        u32 len;

        mutex_lock(&vq->mutex);
        vhost_disable_notify(&vsock->dev, vq);
        for (;;) {
                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0)
                        break;

                if (head == vq->num) {
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                pkt = vhost_vsock_alloc_pkt(vq, out, in);
                if (!pkt) {
                        vq_err(vq, "Faulted on pkt\n");
                        continue;
                }

                len = pkt->len;

                /* Only accept correctly addressed packets */
                if (le32_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
                    le32_to_cpu(pkt->hdr.dst_cid) == vhost_transport_get_local_cid())
                        virtio_transport_recv_pkt(pkt);
                else
                        virtio_transport_free_pkt(pkt);

                vhost_add_used(vq, head, len);
                added = true;
        }
        if (added)
                vhost_signal(&vsock->dev, vq);
        mutex_unlock(&vq->mutex);
}

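/* The guest kicked the RX queue: fresh buffers are available, so retry
 * sending any queued host-to-guest packets.
 */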
static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);

        vhost_transport_do_send_pkt(vsock, vq);
}

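/* Called on open() of the vhost-vsock device node: allocate an instance,
 * wire up the virtqueue kick handlers and add the instance to the global
 * list.
 */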
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
        struct vhost_virtqueue **vqs;
        struct vhost_vsock *vsock;
        int ret;

        vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
        if (!vsock)
                return -ENOMEM;

        pr_debug("%s:vsock=%p\n", __func__, vsock);

        vqs = kmalloc(VSOCK_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
                ret = -ENOMEM;
                goto out;
        }

        vqs[VSOCK_VQ_CTRL] = &vsock->vqs[VSOCK_VQ_CTRL].vq;
        vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX].vq;
        vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX].vq;
        vsock->vqs[VSOCK_VQ_CTRL].vq.handle_kick = vhost_vsock_handle_ctl_kick;
        vsock->vqs[VSOCK_VQ_TX].vq.handle_kick = vhost_vsock_handle_tx_kick;
        vsock->vqs[VSOCK_VQ_RX].vq.handle_kick = vhost_vsock_handle_rx_kick;

        vhost_dev_init(&vsock->dev, vqs, VSOCK_VQ_MAX);

        file->private_data = vsock;
        init_waitqueue_head(&vsock->queue_wait);
        INIT_LIST_HEAD(&vsock->send_pkt_list);
        vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);

        mutex_lock(&vhost_vsock_mutex);
        list_add_tail(&vsock->list, &vhost_vsock_list);
        mutex_unlock(&vhost_vsock_mutex);
        return 0;

out:
        kfree(vsock);
        return ret;
}

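/* Wait until all pending virtqueue work and the send_pkt work item have
 * completed.
 */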
static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
        int i;

        for (i = 0; i < VSOCK_VQ_MAX; i++)
                vhost_poll_flush(&vsock->vqs[i].vq.poll);
        vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

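/* Final close(): unlink the instance from the global list, stop the
 * device, flush outstanding work and free everything.
 */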
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
        struct vhost_vsock *vsock = file->private_data;

        mutex_lock(&vhost_vsock_mutex);
        list_del(&vsock->list);
        mutex_unlock(&vhost_vsock_mutex);

        vhost_dev_stop(&vsock->dev);
        vhost_vsock_flush(vsock);
        vhost_dev_cleanup(&vsock->dev, false);
        kfree(vsock->dev.vqs);
        kfree(vsock);
        return 0;
}

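/* Assign the guest context id for this instance.  CIDs up to
 * VMADDR_CID_HOST are reserved, and a CID already claimed by another
 * instance is refused.  Note the lookup and the assignment do not happen
 * under one critical section, so two concurrent callers could in
 * principle race for the same CID.
 */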
static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u32 guest_cid)
{
        struct vhost_vsock *other;

        /* Refuse reserved CIDs */
        if (guest_cid <= VMADDR_CID_HOST)
                return -EINVAL;

        /* Refuse if CID is already in use */
        other = vhost_vsock_get(guest_cid);
        if (other && other != vsock)
                return -EADDRINUSE;

        mutex_lock(&vhost_vsock_mutex);
        vsock->guest_cid = guest_cid;
        pr_debug("%s: guest_cid=%u\n", __func__, guest_cid);
        mutex_unlock(&vhost_vsock_mutex);

        return 0;
}

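/* Validate and latch the feature bits negotiated by userspace, taking
 * each vq mutex so in-flight handlers observe a consistent value.
 */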
static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
        struct vhost_virtqueue *vq;
        int i;

        if (features & ~VHOST_VSOCK_FEATURES)
                return -EOPNOTSUPP;

        mutex_lock(&vsock->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
            !vhost_log_access_ok(&vsock->dev)) {
                mutex_unlock(&vsock->dev.mutex);
                return -EFAULT;
        }

        for (i = 0; i < VSOCK_VQ_MAX; i++) {
                vq = &vsock->vqs[i].vq;
                mutex_lock(&vq->mutex);
                vq->acked_features = features;
                mutex_unlock(&vq->mutex);
        }
        mutex_unlock(&vsock->dev.mutex);
        return 0;
}

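/* Device ioctls: the vsock-specific VHOST_VSOCK_SET_GUEST_CID plus the
 * generic vhost feature ioctls; everything else is forwarded to the
 * vhost core.
 */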
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
                                  unsigned long arg)
{
        struct vhost_vsock *vsock = f->private_data;
        void __user *argp = (void __user *)arg;
        u64 __user *featurep = argp;
        u32 __user *cidp = argp;
        u32 guest_cid;
        u64 features;
        int r;

        switch (ioctl) {
        case VHOST_VSOCK_SET_GUEST_CID:
                if (get_user(guest_cid, cidp))
                        return -EFAULT;
                return vhost_vsock_set_cid(vsock, guest_cid);
        case VHOST_GET_FEATURES:
                features = VHOST_VSOCK_FEATURES;
                if (copy_to_user(featurep, &features, sizeof(features)))
                        return -EFAULT;
                return 0;
        case VHOST_SET_FEATURES:
                if (copy_from_user(&features, featurep, sizeof(features)))
                        return -EFAULT;
                return vhost_vsock_set_features(vsock, features);
        default:
                mutex_lock(&vsock->dev.mutex);
                r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
                else
                        vhost_vsock_flush(vsock);
                mutex_unlock(&vsock->dev.mutex);
                return r;
        }
}

static const struct file_operations vhost_vsock_fops = {
        .owner          = THIS_MODULE,
        .open           = vhost_vsock_dev_open,
        .release        = vhost_vsock_dev_release,
        .llseek         = noop_llseek,
        .unlocked_ioctl = vhost_vsock_dev_ioctl,
};

static struct miscdevice vhost_vsock_misc = {
        .minor = MISC_DYNAMIC_MINOR,
        .name = "vhost-vsock",
        .fops = &vhost_vsock_fops,
};

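/* Per-socket init: run the common virtio transport init, then point the
 * transport at the vhost-specific packet ops.
 */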
static int
vhost_transport_socket_init(struct vsock_sock *vsk, struct vsock_sock *psk)
{
        struct virtio_transport *trans;
        int ret;

        ret = virtio_transport_do_socket_init(vsk, psk);
        if (ret)
                return ret;

        trans = vsk->trans;
        trans->ops = &vhost_ops;

        return ret;
}

static struct vsock_transport vhost_transport = {
        .get_local_cid            = vhost_transport_get_local_cid,

        .init                     = vhost_transport_socket_init,
        .destruct                 = virtio_transport_destruct,
        .release                  = virtio_transport_release,
        .connect                  = virtio_transport_connect,
        .shutdown                 = virtio_transport_shutdown,

        .dgram_enqueue            = virtio_transport_dgram_enqueue,
        .dgram_dequeue            = virtio_transport_dgram_dequeue,
        .dgram_bind               = virtio_transport_dgram_bind,
        .dgram_allow              = virtio_transport_dgram_allow,

        .stream_enqueue           = virtio_transport_stream_enqueue,
        .stream_dequeue           = virtio_transport_stream_dequeue,
        .stream_has_data          = virtio_transport_stream_has_data,
        .stream_has_space         = virtio_transport_stream_has_space,
        .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
        .stream_is_active         = virtio_transport_stream_is_active,
        .stream_allow             = virtio_transport_stream_allow,

        .notify_poll_in           = virtio_transport_notify_poll_in,
        .notify_poll_out          = virtio_transport_notify_poll_out,
        .notify_recv_init         = virtio_transport_notify_recv_init,
        .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
        .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
        .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
        .notify_send_init         = virtio_transport_notify_send_init,
        .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
        .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
        .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

        .set_buffer_size          = virtio_transport_set_buffer_size,
        .set_min_buffer_size      = virtio_transport_set_min_buffer_size,
        .set_max_buffer_size      = virtio_transport_set_max_buffer_size,
        .get_buffer_size          = virtio_transport_get_buffer_size,
        .get_min_buffer_size      = virtio_transport_get_min_buffer_size,
        .get_max_buffer_size      = virtio_transport_get_max_buffer_size,
};

static int __init vhost_vsock_init(void)
{
        int ret;

        ret = vsock_core_init(&vhost_transport);
        if (ret < 0)
                return ret;
        return misc_register(&vhost_vsock_misc);
}

static void __exit vhost_vsock_exit(void)
{
        misc_deregister(&vhost_vsock_misc);
        vsock_core_exit();
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");