diff --git a/lib/vhost/vhost.c b/lib/vhost/vhost.c index 305d5146ae..85b712e0c6 100644 --- a/lib/vhost/vhost.c +++ b/lib/vhost/vhost.c @@ -138,6 +138,10 @@ spdk_vhost_vq_get_desc(struct spdk_vhost_dev *vdev, struct spdk_vhost_virtqueue *desc_table = spdk_vhost_gpa_to_vva(vdev, (*desc)->addr); *desc_table_size = (*desc)->len / sizeof(**desc); *desc = *desc_table; + if (*desc == NULL) { + return -1; + } + return 0; } diff --git a/lib/vhost/vhost_internal.h b/lib/vhost/vhost_internal.h index e2b268b4ad..a126733f9b 100644 --- a/lib/vhost/vhost_internal.h +++ b/lib/vhost/vhost_internal.h @@ -148,7 +148,7 @@ bool spdk_vhost_vq_should_notify(struct spdk_vhost_dev *vdev, struct spdk_vhost_ * table. * \param desc_table_size size of the *desc_table* * \return 0 on success, -1 if given index is invalid. - * If -1 is returned, the params won't be changed. + * If -1 is returned, the content of params is undefined. */ int spdk_vhost_vq_get_desc(struct spdk_vhost_dev *vdev, struct spdk_vhost_virtqueue *vq, uint16_t req_idx, struct vring_desc **desc, struct vring_desc **desc_table, diff --git a/lib/vhost/vhost_scsi.c b/lib/vhost/vhost_scsi.c index 2e298e9d44..c5642ae984 100644 --- a/lib/vhost/vhost_scsi.c +++ b/lib/vhost/vhost_scsi.c @@ -335,6 +335,11 @@ process_ctrl_request(struct spdk_vhost_scsi_task *task) switch (ctrl_req->type) { case VIRTIO_SCSI_T_TMF: task->tmf_resp = spdk_vhost_gpa_to_vva(vdev, desc->addr); + if (spdk_unlikely(desc->len < sizeof(struct virtio_scsi_ctrl_tmf_resp) || task->tmf_resp == NULL)) { + SPDK_ERRLOG("%s: TMF response descriptor at index %d points to invalid guest memory region\n", + vdev->name, task->req_idx); + goto out; + } /* Check if we are processing a valid request */ if (task->scsi_dev == NULL) { @@ -359,6 +364,12 @@ process_ctrl_request(struct spdk_vhost_scsi_task *task) case VIRTIO_SCSI_T_AN_QUERY: case VIRTIO_SCSI_T_AN_SUBSCRIBE: { an_resp = spdk_vhost_gpa_to_vva(vdev, desc->addr); + if (spdk_unlikely(desc->len < sizeof(struct virtio_scsi_ctrl_an_resp) || an_resp == NULL)) { + SPDK_WARNLOG("%s: Asynchronous response descriptor points to invalid guest memory region\n", + vdev->name); + goto out; + } + an_resp->response = VIRTIO_SCSI_S_ABORTED; break; } @@ -393,13 +404,19 @@ task_data_setup(struct spdk_vhost_scsi_task *task, rc = spdk_vhost_vq_get_desc(vdev, task->vq, task->req_idx, &desc, &desc_table, &desc_table_len); /* First descriptor must be readable */ - if (rc != 0 || spdk_unlikely(spdk_vhost_vring_desc_is_wr(desc))) { + if (spdk_unlikely(rc != 0 || spdk_vhost_vring_desc_is_wr(desc) || + desc->len < sizeof(struct virtio_scsi_cmd_req))) { SPDK_WARNLOG("%s: invalid first (request) descriptor at index %"PRIu16".\n", vdev->name, task->req_idx); goto invalid_task; } *req = spdk_vhost_gpa_to_vva(vdev, desc->addr); + if (spdk_unlikely(*req == NULL)) { + SPDK_WARNLOG("%s: Request descriptor at index %d points to invalid guest memory region\n", + vdev->name, task->req_idx); + goto invalid_task; + } /* Each request must have at least 2 descriptors (e.g. request and response) */ spdk_vhost_vring_desc_get_next(&desc, desc_table, desc_table_len); @@ -417,7 +434,11 @@ task_data_setup(struct spdk_vhost_scsi_task *task, * FROM_DEV (READ): [RD_req][WR_resp][WR_buf0]...[WR_bufN] */ task->resp = spdk_vhost_gpa_to_vva(vdev, desc->addr); - + if (spdk_unlikely(desc->len < sizeof(struct virtio_scsi_cmd_resp) || task->resp == NULL)) { + SPDK_WARNLOG("%s: Response descriptor at index %d points to invalid guest memory region\n", + vdev->name, task->req_idx); + goto invalid_task; + } rc = spdk_vhost_vring_desc_get_next(&desc, desc_table, desc_table_len); if (spdk_unlikely(rc != 0)) { SPDK_WARNLOG("%s: invalid descriptor chain at request index %d (descriptor id overflow?).\n", @@ -480,6 +501,11 @@ task_data_setup(struct spdk_vhost_scsi_task *task, } task->resp = spdk_vhost_gpa_to_vva(vdev, desc->addr); + if (spdk_unlikely(desc->len < sizeof(struct virtio_scsi_cmd_resp) || task->resp == NULL)) { + SPDK_WARNLOG("%s: Response descriptor at index %d points to invalid guest memory region\n", + vdev->name, task->req_idx); + goto invalid_task; + } } if (iovcnt == iovcnt_max) {