Commit 81ddb41f authored by Jason Gunthorpe's avatar Jason Gunthorpe
Browse files

RDMA/cm: Allow ib_send_cm_rej() to be done under lock

The first thing ib_send_cm_rej() does is obtain the lock, so use the usual
unlocked wrapper, locked actor pattern here.

This avoids a sketchy lock/unlock sequence (which could allow state to
change) during cm_destroy_id().

While here simplify some of the logic in the implementation.

Link: https://lore.kernel.org/r/20200310092545.251365-14-leon@kernel.org


Signed-off-by: default avatarLeon Romanovsky <leonro@mellanox.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
parent 87cabf3e
Loading
Loading
Loading
Loading
+52 −40
Original line number Diff line number Diff line
@@ -87,6 +87,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
			       const void *private_data, u8 private_data_len);
static int cm_send_drep_locked(struct cm_id_private *cm_id_priv,
			       void *private_data, u8 private_data_len);
static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
			      enum ib_cm_rej_reason reason, void *ari,
			      u8 ari_length, const void *private_data,
			      u8 private_data_len);

static struct ib_client cm_client = {
	.name   = "cm",
@@ -1060,11 +1064,11 @@ retest:
	case IB_CM_REQ_SENT:
	case IB_CM_MRA_REQ_RCVD:
		ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);
		spin_unlock_irq(&cm_id_priv->lock);
		ib_send_cm_rej(cm_id, IB_CM_REJ_TIMEOUT,
		cm_send_rej_locked(cm_id_priv, IB_CM_REJ_TIMEOUT,
				   &cm_id_priv->id.device->node_guid,
			       sizeof cm_id_priv->id.device->node_guid,
				   sizeof(cm_id_priv->id.device->node_guid),
				   NULL, 0);
		spin_unlock_irq(&cm_id_priv->lock);
		break;
	case IB_CM_REQ_RCVD:
		if (err == -ENOMEM) {
@@ -1072,9 +1076,10 @@ retest:
			cm_reset_to_idle(cm_id_priv);
			spin_unlock_irq(&cm_id_priv->lock);
		} else {
			cm_send_rej_locked(cm_id_priv,
					   IB_CM_REJ_CONSUMER_DEFINED, NULL, 0,
					   NULL, 0);
			spin_unlock_irq(&cm_id_priv->lock);
			ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
				       NULL, 0, NULL, 0);
		}
		break;
	case IB_CM_REP_SENT:
@@ -1084,9 +1089,9 @@ retest:
	case IB_CM_MRA_REQ_SENT:
	case IB_CM_REP_RCVD:
	case IB_CM_MRA_REP_SENT:
		cm_send_rej_locked(cm_id_priv, IB_CM_REJ_CONSUMER_DEFINED, NULL,
				   0, NULL, 0);
		spin_unlock_irq(&cm_id_priv->lock);
		ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
			       NULL, 0, NULL, 0);
		break;
	case IB_CM_ESTABLISHED:
		if (cm_id_priv->qp_type == IB_QPT_XRC_TGT) {
@@ -2899,65 +2904,72 @@ out:
	return -EINVAL;
}

int ib_send_cm_rej(struct ib_cm_id *cm_id,
		   enum ib_cm_rej_reason reason,
		   void *ari,
		   u8 ari_length,
		   const void *private_data,
static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
			      enum ib_cm_rej_reason reason, void *ari,
			      u8 ari_length, const void *private_data,
			      u8 private_data_len)
{
	struct cm_id_private *cm_id_priv;
	struct ib_mad_send_buf *msg;
	unsigned long flags;
	int ret;

	lockdep_assert_held(&cm_id_priv->lock);

	if ((private_data && private_data_len > IB_CM_REJ_PRIVATE_DATA_SIZE) ||
	    (ari && ari_length > IB_CM_REJ_ARI_LENGTH))
		return -EINVAL;

	cm_id_priv = container_of(cm_id, struct cm_id_private, id);

	spin_lock_irqsave(&cm_id_priv->lock, flags);
	switch (cm_id->state) {
	switch (cm_id_priv->id.state) {
	case IB_CM_REQ_SENT:
	case IB_CM_MRA_REQ_RCVD:
	case IB_CM_REQ_RCVD:
	case IB_CM_MRA_REQ_SENT:
	case IB_CM_REP_RCVD:
	case IB_CM_MRA_REP_SENT:
		ret = cm_alloc_msg(cm_id_priv, &msg);
		if (!ret)
			cm_format_rej((struct cm_rej_msg *) msg->mad,
				      cm_id_priv, reason, ari, ari_length,
				      private_data, private_data_len);

		cm_reset_to_idle(cm_id_priv);
		ret = cm_alloc_msg(cm_id_priv, &msg);
		if (ret)
			return ret;
		cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
			      ari, ari_length, private_data, private_data_len);
		break;
	case IB_CM_REP_SENT:
	case IB_CM_MRA_REP_RCVD:
		ret = cm_alloc_msg(cm_id_priv, &msg);
		if (!ret)
			cm_format_rej((struct cm_rej_msg *) msg->mad,
				      cm_id_priv, reason, ari, ari_length,
				      private_data, private_data_len);

		cm_enter_timewait(cm_id_priv);
		ret = cm_alloc_msg(cm_id_priv, &msg);
		if (ret)
			return ret;
		cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
			      ari, ari_length, private_data, private_data_len);
		break;
	default:
		pr_debug("%s: local_id %d, cm_id->state: %d\n", __func__,
			 be32_to_cpu(cm_id_priv->id.local_id), cm_id->state);
		ret = -EINVAL;
		goto out;
			 be32_to_cpu(cm_id_priv->id.local_id),
			 cm_id_priv->id.state);
		return -EINVAL;
	}

	if (ret)
		goto out;

	ret = ib_post_send_mad(msg, NULL);
	if (ret)
	if (ret) {
		cm_free_msg(msg);
		return ret;
	}

out:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
	return 0;
}

int ib_send_cm_rej(struct ib_cm_id *cm_id, enum ib_cm_rej_reason reason,
		   void *ari, u8 ari_length, const void *private_data,
		   u8 private_data_len)
{
	struct cm_id_private *cm_id_priv =
		container_of(cm_id, struct cm_id_private, id);
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&cm_id_priv->lock, flags);
	ret = cm_send_rej_locked(cm_id_priv, reason, ari, ari_length,
				 private_data, private_data_len);
	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
	return ret;
}
EXPORT_SYMBOL(ib_send_cm_rej);