summary refs log tree commit diff stats
path: root/drivers/vfio/vfio.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vfio/vfio.c')
-rw-r--r-- drivers/vfio/vfio.c 210
1 files changed, 62 insertions, 148 deletions
diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c
index 38779e6fd80c..5e631c359ef2 100644
--- a/drivers/vfio/vfio.c
+++ b/drivers/vfio/vfio.c
@@ -46,7 +46,6 @@ static struct vfio {
struct mutex group_lock;
struct cdev group_cdev;
dev_t group_devt;
- wait_queue_head_t release_q;
} vfio;
struct vfio_iommu_driver {
@@ -90,15 +89,6 @@ struct vfio_group {
struct blocking_notifier_head notifier;
};
-struct vfio_device {
- struct kref kref;
- struct device *dev;
- const struct vfio_device_ops *ops;
- struct vfio_group *group;
- struct list_head group_next;
- void *device_data;
-};
-
#ifdef CONFIG_VFIO_NOIOMMU
static bool noiommu __read_mostly;
module_param_named(enable_unsafe_noiommu_mode,
@@ -109,8 +99,8 @@ MODULE_PARM_DESC(enable_unsafe_noiommu_mode, "Enable UNSAFE, no-IOMMU mode. Thi
/*
* vfio_iommu_group_{get,put} are only intended for VFIO bus driver probe
* and remove functions, any use cases other than acquiring the first
- * reference for the purpose of calling vfio_add_group_dev() or removing
- * that symmetric reference after vfio_del_group_dev() should use the raw
+ * reference for the purpose of calling vfio_register_group_dev() or removing
+ * that symmetric reference after vfio_unregister_group_dev() should use the raw
* iommu_group_{get,put} functions. In particular, vfio_iommu_group_put()
* removes the device from the dummy group and cannot be nested.
*/
@@ -532,67 +522,17 @@ static struct vfio_group *vfio_group_get_from_dev(struct device *dev)
/**
* Device objects - create, release, get, put, search
*/
-static
-struct vfio_device *vfio_group_create_device(struct vfio_group *group,
- struct device *dev,
- const struct vfio_device_ops *ops,
- void *device_data)
-{
- struct vfio_device *device;
-
- device = kzalloc(sizeof(*device), GFP_KERNEL);
- if (!device)
- return ERR_PTR(-ENOMEM);
-
- kref_init(&device->kref);
- device->dev = dev;
- device->group = group;
- device->ops = ops;
- device->device_data = device_data;
- dev_set_drvdata(dev, device);
-
- /* No need to get group_lock, caller has group reference */
- vfio_group_get(group);
-
- mutex_lock(&group->device_lock);
- list_add(&device->group_next, &group->device_list);
- group->dev_counter++;
- mutex_unlock(&group->device_lock);
-
- return device;
-}
-
-static void vfio_device_release(struct kref *kref)
-{
- struct vfio_device *device = container_of(kref,
- struct vfio_device, kref);
- struct vfio_group *group = device->group;
-
- list_del(&device->group_next);
- group->dev_counter--;
- mutex_unlock(&group->device_lock);
-
- dev_set_drvdata(device->dev, NULL);
-
- kfree(device);
-
- /* vfio_del_group_dev may be waiting for this device */
- wake_up(&vfio.release_q);
-}
-
/* Device reference always implies a group reference */
void vfio_device_put(struct vfio_device *device)
{
- struct vfio_group *group = device->group;
- kref_put_mutex(&device->kref, vfio_device_release, &group->device_lock);
- vfio_group_put(group);
+ if (refcount_dec_and_test(&device->refcount))
+ complete(&device->comp);
}
EXPORT_SYMBOL_GPL(vfio_device_put);
-static void vfio_device_get(struct vfio_device *device)
+static bool vfio_device_try_get(struct vfio_device *device)
{
- vfio_group_get(device->group);
- kref_get(&device->kref);
+ return refcount_inc_not_zero(&device->refcount);
}
static struct vfio_device *vfio_group_get_device(struct vfio_group *group,
@@ -602,8 +542,7 @@ static struct vfio_device *vfio_group_get_device(struct vfio_group *group,
mutex_lock(&group->device_lock);
list_for_each_entry(device, &group->device_list, group_next) {
- if (device->dev == dev) {
- vfio_device_get(device);
+ if (device->dev == dev && vfio_device_try_get(device)) {
mutex_unlock(&group->device_lock);
return device;
}
@@ -801,14 +740,22 @@ static int vfio_iommu_group_notifier(struct notifier_block *nb,
/**
* VFIO driver API
*/
-int vfio_add_group_dev(struct device *dev,
- const struct vfio_device_ops *ops, void *device_data)
+void vfio_init_group_dev(struct vfio_device *device, struct device *dev,
+ const struct vfio_device_ops *ops)
+{
+ init_completion(&device->comp);
+ device->dev = dev;
+ device->ops = ops;
+}
+EXPORT_SYMBOL_GPL(vfio_init_group_dev);
+
+int vfio_register_group_dev(struct vfio_device *device)
{
+ struct vfio_device *existing_device;
struct iommu_group *iommu_group;
struct vfio_group *group;
- struct vfio_device *device;
- iommu_group = iommu_group_get(dev);
+ iommu_group = iommu_group_get(device->dev);
if (!iommu_group)
return -EINVAL;
@@ -827,31 +774,29 @@ int vfio_add_group_dev(struct device *dev,
iommu_group_put(iommu_group);
}
- device = vfio_group_get_device(group, dev);
- if (device) {
- dev_WARN(dev, "Device already exists on group %d\n",
+ existing_device = vfio_group_get_device(group, device->dev);
+ if (existing_device) {
+ dev_WARN(device->dev, "Device already exists on group %d\n",
iommu_group_id(iommu_group));
- vfio_device_put(device);
+ vfio_device_put(existing_device);
vfio_group_put(group);
return -EBUSY;
}
- device = vfio_group_create_device(group, dev, ops, device_data);
- if (IS_ERR(device)) {
- vfio_group_put(group);
- return PTR_ERR(device);
- }
+ /* Our reference on group is moved to the device */
+ device->group = group;
- /*
- * Drop all but the vfio_device reference. The vfio_device holds
- * a reference to the vfio_group, which holds a reference to the
- * iommu_group.
- */
- vfio_group_put(group);
+ /* Refcounting can't start until the driver calls register */
+ refcount_set(&device->refcount, 1);
+
+ mutex_lock(&group->device_lock);
+ list_add(&device->group_next, &group->device_list);
+ group->dev_counter++;
+ mutex_unlock(&group->device_lock);
return 0;
}
-EXPORT_SYMBOL_GPL(vfio_add_group_dev);
+EXPORT_SYMBOL_GPL(vfio_register_group_dev);
/**
* Get a reference to the vfio_device for a device. Even if the
@@ -886,7 +831,7 @@ static struct vfio_device *vfio_device_get_from_name(struct vfio_group *group,
int ret;
if (it->ops->match) {
- ret = it->ops->match(it->device_data, buf);
+ ret = it->ops->match(it, buf);
if (ret < 0) {
device = ERR_PTR(ret);
break;
@@ -895,9 +840,8 @@ static struct vfio_device *vfio_device_get_from_name(struct vfio_group *group,
ret = !strcmp(dev_name(it->dev), buf);
}
- if (ret) {
+ if (ret && vfio_device_try_get(it)) {
device = it;
- vfio_device_get(device);
break;
}
}
@@ -907,32 +851,15 @@ static struct vfio_device *vfio_device_get_from_name(struct vfio_group *group,
}
/*
- * Caller must hold a reference to the vfio_device
- */
-void *vfio_device_data(struct vfio_device *device)
-{
- return device->device_data;
-}
-EXPORT_SYMBOL_GPL(vfio_device_data);
-
-/*
* Decrement the device reference count and wait for the device to be
* removed. Open file descriptors for the device... */
-void *vfio_del_group_dev(struct device *dev)
+void vfio_unregister_group_dev(struct vfio_device *device)
{
- DEFINE_WAIT_FUNC(wait, woken_wake_function);
- struct vfio_device *device = dev_get_drvdata(dev);
struct vfio_group *group = device->group;
- void *device_data = device->device_data;
struct vfio_unbound_dev *unbound;
unsigned int i = 0;
bool interrupted = false;
-
- /*
- * The group exists so long as we have a device reference. Get
- * a group reference and use it to scan for the device going away.
- */
- vfio_group_get(group);
+ long rc;
/*
* When the device is removed from the group, the group suddenly
@@ -945,7 +872,7 @@ void *vfio_del_group_dev(struct device *dev)
*/
unbound = kzalloc(sizeof(*unbound), GFP_KERNEL);
if (unbound) {
- unbound->dev = dev;
+ unbound->dev = device->dev;
mutex_lock(&group->unbound_lock);
list_add(&unbound->unbound_next, &group->unbound_list);
mutex_unlock(&group->unbound_lock);
@@ -953,44 +880,33 @@ void *vfio_del_group_dev(struct device *dev)
WARN_ON(!unbound);
vfio_device_put(device);
-
- /*
- * If the device is still present in the group after the above
- * 'put', then it is in use and we need to request it from the
- * bus driver. The driver may in turn need to request the
- * device from the user. We send the request on an arbitrary
- * interval with counter to allow the driver to take escalating
- * measures to release the device if it has the ability to do so.
- */
- add_wait_queue(&vfio.release_q, &wait);
-
- do {
- device = vfio_group_get_device(group, dev);
- if (!device)
- break;
-
+ rc = try_wait_for_completion(&device->comp);
+ while (rc <= 0) {
if (device->ops->request)
- device->ops->request(device_data, i++);
-
- vfio_device_put(device);
+ device->ops->request(device, i++);
if (interrupted) {
- wait_woken(&wait, TASK_UNINTERRUPTIBLE, HZ * 10);
+ rc = wait_for_completion_timeout(&device->comp,
+ HZ * 10);
} else {
- wait_woken(&wait, TASK_INTERRUPTIBLE, HZ * 10);
- if (signal_pending(current)) {
+ rc = wait_for_completion_interruptible_timeout(
+ &device->comp, HZ * 10);
+ if (rc < 0) {
interrupted = true;
- dev_warn(dev,
+ dev_warn(device->dev,
"Device is currently in use, task"
" \"%s\" (%d) "
"blocked until device is released",
current->comm, task_pid_nr(current));
}
}
+ }
- } while (1);
+ mutex_lock(&group->device_lock);
+ list_del(&device->group_next);
+ group->dev_counter--;
+ mutex_unlock(&group->device_lock);
- remove_wait_queue(&vfio.release_q, &wait);
/*
* In order to support multiple devices per group, devices can be
* plucked from the group while other devices in the group are still
@@ -1008,11 +924,10 @@ void *vfio_del_group_dev(struct device *dev)
if (list_empty(&group->device_list))
wait_event(group->container_q, !group->container);
+ /* Matches the get in vfio_register_group_dev() */
vfio_group_put(group);
-
- return device_data;
}
-EXPORT_SYMBOL_GPL(vfio_del_group_dev);
+EXPORT_SYMBOL_GPL(vfio_unregister_group_dev);
/**
* VFIO base fd, /dev/vfio/vfio
@@ -1454,7 +1369,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
if (IS_ERR(device))
return PTR_ERR(device);
- ret = device->ops->open(device->device_data);
+ ret = device->ops->open(device);
if (ret) {
vfio_device_put(device);
return ret;
@@ -1466,7 +1381,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
*/
ret = get_unused_fd_flags(O_CLOEXEC);
if (ret < 0) {
- device->ops->release(device->device_data);
+ device->ops->release(device);
vfio_device_put(device);
return ret;
}
@@ -1476,7 +1391,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
if (IS_ERR(filep)) {
put_unused_fd(ret);
ret = PTR_ERR(filep);
- device->ops->release(device->device_data);
+ device->ops->release(device);
vfio_device_put(device);
return ret;
}
@@ -1633,7 +1548,7 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep)
{
struct vfio_device *device = filep->private_data;
- device->ops->release(device->device_data);
+ device->ops->release(device);
vfio_group_try_dissolve_container(device->group);
@@ -1650,7 +1565,7 @@ static long vfio_device_fops_unl_ioctl(struct file *filep,
if (unlikely(!device->ops->ioctl))
return -EINVAL;
- return device->ops->ioctl(device->device_data, cmd, arg);
+ return device->ops->ioctl(device, cmd, arg);
}
static ssize_t vfio_device_fops_read(struct file *filep, char __user *buf,
@@ -1661,7 +1576,7 @@ static ssize_t vfio_device_fops_read(struct file *filep, char __user *buf,
if (unlikely(!device->ops->read))
return -EINVAL;
- return device->ops->read(device->device_data, buf, count, ppos);
+ return device->ops->read(device, buf, count, ppos);
}
static ssize_t vfio_device_fops_write(struct file *filep,
@@ -1673,7 +1588,7 @@ static ssize_t vfio_device_fops_write(struct file *filep,
if (unlikely(!device->ops->write))
return -EINVAL;
- return device->ops->write(device->device_data, buf, count, ppos);
+ return device->ops->write(device, buf, count, ppos);
}
static int vfio_device_fops_mmap(struct file *filep, struct vm_area_struct *vma)
@@ -1683,7 +1598,7 @@ static int vfio_device_fops_mmap(struct file *filep, struct vm_area_struct *vma)
if (unlikely(!device->ops->mmap))
return -EINVAL;
- return device->ops->mmap(device->device_data, vma);
+ return device->ops->mmap(device, vma);
}
static const struct file_operations vfio_device_fops = {
@@ -2379,7 +2294,6 @@ static int __init vfio_init(void)
mutex_init(&vfio.iommu_drivers_lock);
INIT_LIST_HEAD(&vfio.group_list);
INIT_LIST_HEAD(&vfio.iommu_drivers_list);
- init_waitqueue_head(&vfio.release_q);
ret = misc_register(&vfio_dev);
if (ret) {