On Thu, Dec 13, 2018 at 10:38:09AM +0800, jiangyiwen wrote:
> Hi Michael,
> 
> On 2018/12/12 23:31, Michael S. Tsirkin wrote:
> > On Wed, Dec 12, 2018 at 05:31:39PM +0800, jiangyiwen wrote:
> >> The guest receives mergeable rx buffers: it can merge
> >> scattered rx buffers into one big buffer and then copy
> >> it to user space.
> >>
> >> In addition, it uses a kvec array instead of buf in struct
> >> virtio_vsock_pkt, to keep tx and rx consistent. The only
> >> difference is that tx is still implemented with a single
> >> segment of contiguous physical memory.
> >>
> >> Signed-off-by: Yiwen Jiang <jiangyi...@huawei.com>
> >> ---
> >>  drivers/vhost/vsock.c                   |  31 +++++++---
> >>  include/linux/virtio_vsock.h            |   6 +-
> >>  net/vmw_vsock/virtio_transport.c        | 105 
> >> ++++++++++++++++++++++++++++----
> >>  net/vmw_vsock/virtio_transport_common.c |  59 ++++++++++++++----
> >>  4 files changed, 166 insertions(+), 35 deletions(-)
> > 
> > 
> > This was supposed to be a guest patch, why is vhost changed here?
> > 
> 
> With mergeable rx buffers, big packets need to be scattered across several
> buffers, so I added a kvec array to struct virtio_vsock_pkt. To keep tx
> and rx consistent, I also replaced the buf field with that kvec array.
> Because vhost uses pkt->buf, this patch has to change vhost as well.

You'd want to split these patches imho.

> >> diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
> >> index dc52b0f..c7ab0dd 100644
> >> --- a/drivers/vhost/vsock.c
> >> +++ b/drivers/vhost/vsock.c
> >> @@ -179,6 +179,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
> >>            size_t nbytes;
> >>            size_t len;
> >>            s16 headcount;
> >> +          size_t remain_len;
> >> +          int i;
> >>
> >>            spin_lock_bh(&vsock->send_pkt_list_lock);
> >>            if (list_empty(&vsock->send_pkt_list)) {
> >> @@ -221,11 +223,19 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
> >>                    break;
> >>            }
> >>
> >> -          nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
> >> -          if (nbytes != pkt->len) {
> >> -                  virtio_transport_free_pkt(pkt);
> >> -                  vq_err(vq, "Faulted on copying pkt buf\n");
> >> -                  break;
> >> +          remain_len = pkt->len;
> >> +          for (i = 0; i < pkt->nr_vecs; i++) {
> >> +                  int tmp_len;
> >> +
> >> +                  tmp_len = min(remain_len, pkt->vec[i].iov_len);
> >> +                  nbytes = copy_to_iter(pkt->vec[i].iov_base, tmp_len, 
> >> &iov_iter);
> >> +                  if (nbytes != tmp_len) {
> >> +                          virtio_transport_free_pkt(pkt);
> >> +                          vq_err(vq, "Faulted on copying pkt buf\n");
> >> +                          break;
> >> +                  }
> >> +
> >> +                  remain_len -= tmp_len;
> >>            }
> >>
> >>            vhost_add_used_n(vq, vq->heads, headcount);
> >> @@ -341,6 +351,7 @@ static void vhost_transport_send_pkt_work(struct 
> >> vhost_work *work)
> >>    struct iov_iter iov_iter;
> >>    size_t nbytes;
> >>    size_t len;
> >> +  void *buf;
> >>
> >>    if (in != 0) {
> >>            vq_err(vq, "Expected 0 input buffers, got %u\n", in);
> >> @@ -375,13 +386,17 @@ static void vhost_transport_send_pkt_work(struct 
> >> vhost_work *work)
> >>            return NULL;
> >>    }
> >>
> >> -  pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
> >> -  if (!pkt->buf) {
> >> +  buf = kmalloc(pkt->len, GFP_KERNEL);
> >> +  if (!buf) {
> >>            kfree(pkt);
> >>            return NULL;
> >>    }
> >>
> >> -  nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
> >> +  pkt->vec[0].iov_base = buf;
> >> +  pkt->vec[0].iov_len = pkt->len;
> >> +  pkt->nr_vecs = 1;
> >> +
> >> +  nbytes = copy_from_iter(buf, pkt->len, &iov_iter);
> >>    if (nbytes != pkt->len) {
> >>            vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
> >>                   pkt->len, nbytes);
> >> diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
> >> index da9e1fe..734eeed 100644
> >> --- a/include/linux/virtio_vsock.h
> >> +++ b/include/linux/virtio_vsock.h
> >> @@ -13,6 +13,8 @@
> >>  #define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE  (1024 * 4)
> >>  #define VIRTIO_VSOCK_MAX_BUF_SIZE         0xFFFFFFFFUL
> >>  #define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE             (1024 * 64)
> >> +/* virtio_vsock_pkt + max_pkt_len(default MAX_PKT_BUF_SIZE) */
> >> +#define VIRTIO_VSOCK_MAX_VEC_NUM ((VIRTIO_VSOCK_MAX_PKT_BUF_SIZE / 
> >> PAGE_SIZE) + 1)
> >>
> >>  /* Virtio-vsock feature */
> >>  #define VIRTIO_VSOCK_F_MRG_RXBUF 0 /* Host can merge receive buffers. */
> >> @@ -55,10 +57,12 @@ struct virtio_vsock_pkt {
> >>    struct list_head list;
> >>    /* socket refcnt not held, only use for cancellation */
> >>    struct vsock_sock *vsk;
> >> -  void *buf;
> >> +  struct kvec vec[VIRTIO_VSOCK_MAX_VEC_NUM];
> >> +  int nr_vecs;
> >>    u32 len;
> >>    u32 off;
> >>    bool reply;
> >> +  bool mergeable;
> >>  };
> >>
> >>  struct virtio_vsock_pkt_info {
> >> diff --git a/net/vmw_vsock/virtio_transport.c 
> >> b/net/vmw_vsock/virtio_transport.c
> >> index c4a465c..148b58a 100644
> >> --- a/net/vmw_vsock/virtio_transport.c
> >> +++ b/net/vmw_vsock/virtio_transport.c
> >> @@ -155,8 +155,10 @@ static int virtio_transport_send_pkt_loopback(struct 
> >> virtio_vsock *vsock,
> >>
> >>            sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
> >>            sgs[out_sg++] = &hdr;
> >> -          if (pkt->buf) {
> >> -                  sg_init_one(&buf, pkt->buf, pkt->len);
> >> +          if (pkt->len) {
> >> +                  /* Currently only support a segment of memory in tx */
> >> +                  BUG_ON(pkt->vec[0].iov_len != pkt->len);
> >> +                  sg_init_one(&buf, pkt->vec[0].iov_base, 
> >> pkt->vec[0].iov_len);
> >>                    sgs[out_sg++] = &buf;
> >>            }
> >>
> >> @@ -304,23 +306,28 @@ static int fill_old_rx_buff(struct virtqueue *vq)
> >>    struct virtio_vsock_pkt *pkt;
> >>    struct scatterlist hdr, buf, *sgs[2];
> >>    int ret;
> >> +  void *pkt_buf;
> >>
> >>    pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
> >>    if (!pkt)
> >>            return -ENOMEM;
> >>
> >> -  pkt->buf = kmalloc(buf_len, GFP_KERNEL);
> >> -  if (!pkt->buf) {
> >> +  pkt_buf = kmalloc(buf_len, GFP_KERNEL);
> >> +  if (!pkt_buf) {
> >>            virtio_transport_free_pkt(pkt);
> >>            return -ENOMEM;
> >>    }
> >>
> >> +  pkt->vec[0].iov_base = pkt_buf;
> >> +  pkt->vec[0].iov_len = buf_len;
> >> +  pkt->nr_vecs = 1;
> >> +
> >>    pkt->len = buf_len;
> >>
> >>    sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
> >>    sgs[0] = &hdr;
> >>
> >> -  sg_init_one(&buf, pkt->buf, buf_len);
> >> +  sg_init_one(&buf, pkt->vec[0].iov_base, buf_len);
> >>    sgs[1] = &buf;
> >>    ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
> >>    if (ret)
> >> @@ -388,11 +395,78 @@ static bool virtio_transport_more_replies(struct 
> >> virtio_vsock *vsock)
> >>    return val < virtqueue_get_vring_size(vq);
> >>  }
> >>
> >> +static struct virtio_vsock_pkt *receive_mergeable(struct virtqueue *vq,
> >> +          struct virtio_vsock *vsock, unsigned int *total_len)
> >> +{
> >> +  struct virtio_vsock_pkt *pkt;
> >> +  u16 num_buf;
> >> +  void *buf;
> >> +  unsigned int len;
> >> +  size_t vsock_hlen = sizeof(struct virtio_vsock_pkt);
> >> +
> >> +  buf = virtqueue_get_buf(vq, &len);
> >> +  if (!buf)
> >> +          return NULL;
> >> +
> >> +  *total_len = len;
> >> +  vsock->rx_buf_nr--;
> >> +
> >> +  if (unlikely(len < vsock_hlen)) {
> >> +          put_page(virt_to_head_page(buf));
> >> +          return NULL;
> >> +  }
> >> +
> >> +  pkt = buf;
> >> +  num_buf = le16_to_cpu(pkt->mrg_rxbuf_hdr.num_buffers);
> >> +  if (!num_buf || num_buf > VIRTIO_VSOCK_MAX_VEC_NUM) {
> >> +          put_page(virt_to_head_page(buf));
> >> +          return NULL;
> >> +  }
> > 
> > So everything just stops going, and host and user don't even
> > know what the reason is. And not only that - the next
> > packet will be corrupted because we skipped the first one.
> > 
> > 
> 
> As I understand it, this case will not be encountered unless the code has
> a *BUG*, e.g. the host sends malformed packets (shorter/longer than
> expected). In that case, I think we should ignore/drop these packets.

If there's a specific packet length expected, e.g.  in this case vector
size, it needs to be negotiated between host and guest in some way.

> > 
> >> +
> >> +  /* Initialize pkt residual structure */
> >> +  memset(&pkt->work, 0, vsock_hlen - sizeof(struct virtio_vsock_hdr) -
> >> +                  sizeof(struct virtio_vsock_mrg_rxbuf_hdr));
> >> +
> >> +  pkt->mergeable = true;
> >> +  pkt->len = le32_to_cpu(pkt->hdr.len);
> >> +  if (!pkt->len)
> >> +          return pkt;
> >> +
> >> +  len -= vsock_hlen;
> >> +  if (len) {
> >> +          pkt->vec[pkt->nr_vecs].iov_base = buf + vsock_hlen;
> >> +          pkt->vec[pkt->nr_vecs].iov_len = len;
> >> +          /* Shared page with pkt, so get page in advance */
> >> +          get_page(virt_to_head_page(buf));
> >> +          pkt->nr_vecs++;
> >> +  }
> >> +
> >> +  while (--num_buf) {
> >> +          buf = virtqueue_get_buf(vq, &len);
> >> +          if (!buf)
> >> +                  goto err;
> >> +
> >> +          *total_len += len;
> >> +          vsock->rx_buf_nr--;
> >> +
> >> +          pkt->vec[pkt->nr_vecs].iov_base = buf;
> >> +          pkt->vec[pkt->nr_vecs].iov_len = len;
> >> +          pkt->nr_vecs++;
> >> +  }
> >> +
> >> +  return pkt;
> >> +err:
> >> +  virtio_transport_free_pkt(pkt);
> >> +  return NULL;
> >> +}
> >> +
> >>  static void virtio_transport_rx_work(struct work_struct *work)
> >>  {
> >>    struct virtio_vsock *vsock =
> >>            container_of(work, struct virtio_vsock, rx_work);
> >>    struct virtqueue *vq;
> >> +  size_t vsock_hlen = vsock->mergeable ? sizeof(struct virtio_vsock_pkt) :
> >> +                  sizeof(struct virtio_vsock_hdr);
> >>
> >>    vq = vsock->vqs[VSOCK_VQ_RX];
> >>
> >> @@ -412,21 +486,26 @@ static void virtio_transport_rx_work(struct 
> >> work_struct *work)
> >>                            goto out;
> >>                    }
> >>
> >> -                  pkt = virtqueue_get_buf(vq, &len);
> >> -                  if (!pkt) {
> >> -                          break;
> >> -                  }
> >> +                  if (likely(vsock->mergeable)) {
> >> +                          pkt = receive_mergeable(vq, vsock, &len);
> >> +                          if (!pkt)
> >> +                                  break;
> >> +                  } else {
> >> +                          pkt = virtqueue_get_buf(vq, &len);
> >> +                          if (!pkt)
> >> +                                  break;
> >>
> > 
> > So looking at it, this seems to be the main source of the gain.
> > But why does this require host/guest changes?
> > 
> > 
> > The way I see it:
> >     - get a buffer and create an skb
> >     - get the next one, check header matches, if yes
> >       tack it on the skb as a fragment. If not then
> >       don't, deliver previous one and queue the new one.
> > 
> > 
> 
> I explained the reason for the vhost change above. I would like to use
> a kvec instead of buf, since buf can only describe a single contiguous
> region of physical memory.
> 
> Thanks,
> Yiwen.


I got the reason but I am not yet convinced it's a good one.
You don't necessarily need the host to skip headers
in all but the first chunk. The guest can do this just as well.



> > 
> >> -                  vsock->rx_buf_nr--;
> >> +                          vsock->rx_buf_nr--;
> >> +                  }
> >>
> >>                    /* Drop short/long packets */
> >> -                  if (unlikely(len < sizeof(pkt->hdr) ||
> >> -                               len > sizeof(pkt->hdr) + pkt->len)) {
> >> +                  if (unlikely(len < vsock_hlen ||
> >> +                               len > vsock_hlen + pkt->len)) {
> >>                            virtio_transport_free_pkt(pkt);
> >>                            continue;
> >>                    }
> >>
> >> -                  pkt->len = len - sizeof(pkt->hdr);
> >> +                  pkt->len = len - vsock_hlen;
> >>                    virtio_transport_deliver_tap_pkt(pkt);
> >>                    virtio_transport_recv_pkt(pkt);
> >>            }
> >> diff --git a/net/vmw_vsock/virtio_transport_common.c 
> >> b/net/vmw_vsock/virtio_transport_common.c
> >> index 3ae3a33..123a8b6 100644
> >> --- a/net/vmw_vsock/virtio_transport_common.c
> >> +++ b/net/vmw_vsock/virtio_transport_common.c
> >> @@ -44,6 +44,7 @@ static const struct virtio_transport 
> >> *virtio_transport_get_ops(void)
> >>  {
> >>    struct virtio_vsock_pkt *pkt;
> >>    int err;
> >> +  void *buf = NULL;
> >>
> >>    pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
> >>    if (!pkt)
> >> @@ -62,12 +63,16 @@ static const struct virtio_transport 
> >> *virtio_transport_get_ops(void)
> >>    pkt->vsk                = info->vsk;
> >>
> >>    if (info->msg && len > 0) {
> >> -          pkt->buf = kmalloc(len, GFP_KERNEL);
> >> -          if (!pkt->buf)
> >> +          buf = kmalloc(len, GFP_KERNEL);
> >> +          if (!buf)
> >>                    goto out_pkt;
> >> -          err = memcpy_from_msg(pkt->buf, info->msg, len);
> >> +          err = memcpy_from_msg(buf, info->msg, len);
> >>            if (err)
> >>                    goto out;
> >> +
> >> +          pkt->vec[0].iov_base = buf;
> >> +          pkt->vec[0].iov_len = len;
> >> +          pkt->nr_vecs = 1;
> >>    }
> >>
> >>    trace_virtio_transport_alloc_pkt(src_cid, src_port,
> >> @@ -80,7 +85,7 @@ static const struct virtio_transport 
> >> *virtio_transport_get_ops(void)
> >>    return pkt;
> >>
> >>  out:
> >> -  kfree(pkt->buf);
> >> +  kfree(buf);
> >>  out_pkt:
> >>    kfree(pkt);
> >>    return NULL;
> >> @@ -92,6 +97,7 @@ static struct sk_buff *virtio_transport_build_skb(void 
> >> *opaque)
> >>    struct virtio_vsock_pkt *pkt = opaque;
> >>    struct af_vsockmon_hdr *hdr;
> >>    struct sk_buff *skb;
> >> +  int i;
> >>
> >>    skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
> >>                    GFP_ATOMIC);
> >> @@ -134,7 +140,8 @@ static struct sk_buff *virtio_transport_build_skb(void 
> >> *opaque)
> >>    skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
> >>
> >>    if (pkt->len) {
> >> -          skb_put_data(skb, pkt->buf, pkt->len);
> >> +          for (i = 0; i < pkt->nr_vecs; i++)
> >> +                  skb_put_data(skb, pkt->vec[i].iov_base, 
> >> pkt->vec[i].iov_len);
> >>    }
> >>
> >>    return skb;
> >> @@ -260,6 +267,9 @@ static int virtio_transport_send_credit_update(struct 
> >> vsock_sock *vsk,
> >>
> >>    spin_lock_bh(&vvs->rx_lock);
> >>    while (total < len && !list_empty(&vvs->rx_queue)) {
> >> +          size_t copy_bytes, last_vec_total = 0, vec_off;
> >> +          int i;
> >> +
> >>            pkt = list_first_entry(&vvs->rx_queue,
> >>                                   struct virtio_vsock_pkt, list);
> >>
> >> @@ -272,14 +282,28 @@ static int 
> >> virtio_transport_send_credit_update(struct vsock_sock *vsk,
> >>             */
> >>            spin_unlock_bh(&vvs->rx_lock);
> >>
> >> -          err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
> >> -          if (err)
> >> -                  goto out;
> >> +          for (i = 0; i < pkt->nr_vecs; i++) {
> >> +                  if (pkt->off > last_vec_total + pkt->vec[i].iov_len) {
> >> +                          last_vec_total += pkt->vec[i].iov_len;
> >> +                          continue;
> >> +                  }
> >> +
> >> +                  vec_off = pkt->off - last_vec_total;
> >> +                  copy_bytes = min(pkt->vec[i].iov_len - vec_off, bytes);
> >> +                  err = memcpy_to_msg(msg, pkt->vec[i].iov_base + vec_off,
> >> +                                  copy_bytes);
> >> +                  if (err)
> >> +                          goto out;
> >> +
> >> +                  bytes -= copy_bytes;
> >> +                  pkt->off += copy_bytes;
> >> +                  total += copy_bytes;
> >> +                  last_vec_total += pkt->vec[i].iov_len;
> >> +                  if (!bytes)
> >> +                          break;
> >> +          }
> >>
> >>            spin_lock_bh(&vvs->rx_lock);
> >> -
> >> -          total += bytes;
> >> -          pkt->off += bytes;
> >>            if (pkt->off == pkt->len) {
> >>                    virtio_transport_dec_rx_pkt(vvs, pkt);
> >>                    list_del(&pkt->list);
> >> @@ -1050,8 +1074,17 @@ void virtio_transport_recv_pkt(struct 
> >> virtio_vsock_pkt *pkt)
> >>
> >>  void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
> >>  {
> >> -  kfree(pkt->buf);
> >> -  kfree(pkt);
> >> +  int i;
> >> +
> >> +  if (pkt->mergeable) {
> >> +          for (i = 0; i < pkt->nr_vecs; i++)
> >> +                  put_page(virt_to_head_page(pkt->vec[i].iov_base));
> >> +          put_page(virt_to_head_page((void *)pkt));
> >> +  } else {
> >> +          for (i = 0; i < pkt->nr_vecs; i++)
> >> +                  kfree(pkt->vec[i].iov_base);
> >> +          kfree(pkt);
> >> +  }
> >>  }
> >>  EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
> >>
> >> -- 
> >> 1.8.3.1
> >>
> > 
> > .
> > 
> 

Reply via email to