diff options
Diffstat (limited to 'drivers/infiniband/sw/rxe/rxe_verbs.c')
-rw-r--r-- | drivers/infiniband/sw/rxe/rxe_verbs.c | 101 |
1 files changed, 50 insertions, 51 deletions
diff --git a/drivers/infiniband/sw/rxe/rxe_verbs.c b/drivers/infiniband/sw/rxe/rxe_verbs.c index f4bab2cd0ec2..2cb52fd48cf1 100644 --- a/drivers/infiniband/sw/rxe/rxe_verbs.c +++ b/drivers/infiniband/sw/rxe/rxe_verbs.c @@ -77,40 +77,6 @@ out: return rc; } -static int rxe_query_gid(struct ib_device *device, - u8 port_num, int index, union ib_gid *gid) -{ - int ret; - - if (index > RXE_PORT_GID_TBL_LEN) - return -EINVAL; - - ret = ib_get_cached_gid(device, port_num, index, gid, NULL); - if (ret == -EAGAIN) { - memcpy(gid, &zgid, sizeof(*gid)); - return 0; - } - - return ret; -} - -static int rxe_add_gid(struct ib_device *device, u8 port_num, unsigned int - index, const union ib_gid *gid, - const struct ib_gid_attr *attr, void **context) -{ - if (index >= RXE_PORT_GID_TBL_LEN) - return -EINVAL; - return 0; -} - -static int rxe_del_gid(struct ib_device *device, u8 port_num, unsigned int - index, void **context) -{ - if (index >= RXE_PORT_GID_TBL_LEN) - return -EINVAL; - return 0; -} - static struct net_device *rxe_get_netdev(struct ib_device *device, u8 port_num) { @@ -273,9 +239,7 @@ static int rxe_init_av(struct rxe_dev *rxe, struct rdma_ah_attr *attr, rxe_av_from_attr(rdma_ah_get_port_num(attr), av, attr); rxe_av_fill_ip_info(av, attr, &sgid_attr, &sgid); - - if (sgid_attr.ndev) - dev_put(sgid_attr.ndev); + dev_put(sgid_attr.ndev); return 0; } @@ -407,6 +371,13 @@ static struct ib_srq *rxe_create_srq(struct ib_pd *ibpd, struct rxe_pd *pd = to_rpd(ibpd); struct rxe_srq *srq; struct ib_ucontext *context = udata ? ibpd->uobject->context : NULL; + struct rxe_create_srq_resp __user *uresp = NULL; + + if (udata) { + if (udata->outlen < sizeof(*uresp)) + return ERR_PTR(-EINVAL); + uresp = udata->outbuf; + } err = rxe_srq_chk_attr(rxe, NULL, &init->attr, IB_SRQ_INIT_MASK); if (err) @@ -422,7 +393,7 @@ static struct ib_srq *rxe_create_srq(struct ib_pd *ibpd, rxe_add_ref(pd); srq->pd = pd; - err = rxe_srq_from_init(rxe, srq, init, context, udata); + err = rxe_srq_from_init(rxe, srq, init, context, uresp); if (err) goto err2; @@ -443,12 +414,22 @@ static int rxe_modify_srq(struct ib_srq *ibsrq, struct ib_srq_attr *attr, int err; struct rxe_srq *srq = to_rsrq(ibsrq); struct rxe_dev *rxe = to_rdev(ibsrq->device); + struct rxe_modify_srq_cmd ucmd = {}; + + if (udata) { + if (udata->inlen < sizeof(ucmd)) + return -EINVAL; + + err = ib_copy_from_udata(&ucmd, udata, sizeof(ucmd)); + if (err) + return err; + } err = rxe_srq_chk_attr(rxe, srq, attr, mask); if (err) goto err1; - err = rxe_srq_from_attr(rxe, srq, attr, mask, udata); + err = rxe_srq_from_attr(rxe, srq, attr, mask, &ucmd); if (err) goto err1; @@ -517,6 +498,13 @@ static struct ib_qp *rxe_create_qp(struct ib_pd *ibpd, struct rxe_dev *rxe = to_rdev(ibpd->device); struct rxe_pd *pd = to_rpd(ibpd); struct rxe_qp *qp; + struct rxe_create_qp_resp __user *uresp = NULL; + + if (udata) { + if (udata->outlen < sizeof(*uresp)) + return ERR_PTR(-EINVAL); + uresp = udata->outbuf; + } err = rxe_qp_chk_init(rxe, init); if (err) @@ -538,7 +526,7 @@ static struct ib_qp *rxe_create_qp(struct ib_pd *ibpd, rxe_add_index(qp); - err = rxe_qp_from_init(rxe, qp, pd, init, udata, ibpd); + err = rxe_qp_from_init(rxe, qp, pd, init, uresp, ibpd); if (err) goto err3; @@ -711,9 +699,8 @@ static int init_send_wqe(struct rxe_qp *qp, struct ib_send_wr *ibwr, memcpy(wqe->dma.sge, ibwr->sg_list, num_sge * sizeof(struct ib_sge)); - wqe->iova = (mask & WR_ATOMIC_MASK) ? - atomic_wr(ibwr)->remote_addr : - rdma_wr(ibwr)->remote_addr; + wqe->iova = mask & WR_ATOMIC_MASK ? atomic_wr(ibwr)->remote_addr : + mask & WR_READ_OR_WRITE_MASK ? rdma_wr(ibwr)->remote_addr : 0; wqe->mask = mask; wqe->dma.length = length; wqe->dma.resid = length; @@ -889,11 +876,18 @@ static struct ib_cq *rxe_create_cq(struct ib_device *dev, int err; struct rxe_dev *rxe = to_rdev(dev); struct rxe_cq *cq; + struct rxe_create_cq_resp __user *uresp = NULL; + + if (udata) { + if (udata->outlen < sizeof(*uresp)) + return ERR_PTR(-EINVAL); + uresp = udata->outbuf; + } if (attr->flags) return ERR_PTR(-EINVAL); - err = rxe_cq_chk_attr(rxe, NULL, attr->cqe, attr->comp_vector, udata); + err = rxe_cq_chk_attr(rxe, NULL, attr->cqe, attr->comp_vector); if (err) goto err1; @@ -904,7 +898,7 @@ static struct ib_cq *rxe_create_cq(struct ib_device *dev, } err = rxe_cq_from_init(rxe, cq, attr->cqe, attr->comp_vector, - context, udata); + context, uresp); if (err) goto err2; @@ -931,12 +925,19 @@ static int rxe_resize_cq(struct ib_cq *ibcq, int cqe, struct ib_udata *udata) int err; struct rxe_cq *cq = to_rcq(ibcq); struct rxe_dev *rxe = to_rdev(ibcq->device); + struct rxe_resize_cq_resp __user *uresp = NULL; + + if (udata) { + if (udata->outlen < sizeof(*uresp)) + return -EINVAL; + uresp = udata->outbuf; + } - err = rxe_cq_chk_attr(rxe, cq, cqe, 0, udata); + err = rxe_cq_chk_attr(rxe, cq, cqe, 0); if (err) goto err1; - err = rxe_cq_resize_queue(cq, cqe, udata); + err = rxe_cq_resize_queue(cq, cqe, uresp); if (err) goto err1; @@ -1207,7 +1208,7 @@ int rxe_register_device(struct rxe_dev *rxe) rxe->ndev->dev_addr); dev->dev.dma_ops = &dma_virt_ops; dma_coerce_mask_and_coherent(&dev->dev, - dma_get_required_mask(dev->dev.parent)); + dma_get_required_mask(&dev->dev)); dev->uverbs_abi_ver = RXE_UVERBS_ABI_VERSION; dev->uverbs_cmd_mask = BIT_ULL(IB_USER_VERBS_CMD_GET_CONTEXT) @@ -1248,10 +1249,7 @@ int rxe_register_device(struct rxe_dev *rxe) dev->query_port = rxe_query_port; dev->modify_port = rxe_modify_port; dev->get_link_layer = rxe_get_link_layer; - dev->query_gid = rxe_query_gid; dev->get_netdev = rxe_get_netdev; - dev->add_gid = rxe_add_gid; - dev->del_gid = rxe_del_gid; dev->query_pkey = rxe_query_pkey; dev->alloc_ucontext = rxe_alloc_ucontext; dev->dealloc_ucontext = rxe_dealloc_ucontext; @@ -1298,6 +1296,7 @@ int rxe_register_device(struct rxe_dev *rxe) } rxe->tfm = tfm; + dev->driver_id = RDMA_DRIVER_RXE; err = ib_register_device(dev, NULL); if (err) { pr_warn("%s failed with error %d\n", __func__, err); |