#include <linux/mlx5/vport.h>
#include "mlx5_core.h"
+/* Mutex to hold while enabling or disabling RoCE */
+static DEFINE_MUTEX(mlx5_roce_en_lock);
+
static int _mlx5_query_vport_state(struct mlx5_core_dev *mdev, u8 opmod,
u16 vport, u32 *out, int outlen)
{
int mlx5_nic_vport_enable_roce(struct mlx5_core_dev *mdev)
{
- if (atomic_inc_return(&mdev->roce.roce_en) != 1)
- return 0;
- return mlx5_nic_vport_update_roce_state(mdev, MLX5_VPORT_ROCE_ENABLED);
+ int err = 0;
+
+ mutex_lock(&mlx5_roce_en_lock);
+ if (!mdev->roce.roce_en)
+ err = mlx5_nic_vport_update_roce_state(mdev, MLX5_VPORT_ROCE_ENABLED);
+
+ if (!err)
+ mdev->roce.roce_en++;
+ mutex_unlock(&mlx5_roce_en_lock);
+
+ return err;
}
EXPORT_SYMBOL_GPL(mlx5_nic_vport_enable_roce);
int mlx5_nic_vport_disable_roce(struct mlx5_core_dev *mdev)
{
- if (atomic_dec_return(&mdev->roce.roce_en) != 0)
- return 0;
- return mlx5_nic_vport_update_roce_state(mdev, MLX5_VPORT_ROCE_DISABLED);
+ int err = 0;
+
+ mutex_lock(&mlx5_roce_en_lock);
+ if (mdev->roce.roce_en) {
+ mdev->roce.roce_en--;
+ if (mdev->roce.roce_en == 0)
+ err = mlx5_nic_vport_update_roce_state(mdev, MLX5_VPORT_ROCE_DISABLED);
+
+ if (err)
+ mdev->roce.roce_en++;
+ }
+ mutex_unlock(&mlx5_roce_en_lock);
+ return err;
}
EXPORT_SYMBOL_GPL(mlx5_nic_vport_disable_roce);