extern void powernv_set_nmmu_ptcr(unsigned long ptcr);
extern struct npu_context *pnv_npu2_init_context(struct pci_dev *gpdev,
unsigned long flags,
- struct npu_context *(*cb)(struct npu_context *, void *),
+ void (*cb)(struct npu_context *, void *),
void *priv);
extern void pnv_npu2_destroy_context(struct npu_context *context,
struct pci_dev *gpdev);
bool nmmu_flush;
/* Callback to stop translation requests on a given GPU */
- struct npu_context *(*release_cb)(struct npu_context *, void *);
+ void (*release_cb)(struct npu_context *context, void *priv);
/*
* Private pointer passed to the above callback for usage by
*/
struct npu_context *pnv_npu2_init_context(struct pci_dev *gpdev,
unsigned long flags,
- struct npu_context *(*cb)(struct npu_context *, void *),
+ void (*cb)(struct npu_context *, void *),
void *priv)
{
int rc;
*/
spin_lock(&npu_context_lock);
npu_context = mm->context.npu_context;
- if (npu_context)
+ if (npu_context) {
+ if (npu_context->release_cb != cb ||
+ npu_context->priv != priv) {
+ spin_unlock(&npu_context_lock);
+ opal_npu_destroy_context(nphb->opal_id, mm->context.id,
+ PCI_DEVID(gpdev->bus->number,
+ gpdev->devfn));
+ return ERR_PTR(-EINVAL);
+ }
+
WARN_ON(!kref_get_unless_zero(&npu_context->kref));
+ }
spin_unlock(&npu_context_lock);
if (!npu_context) {