Merge tag 'for-5.4-rc4-tag' of git://git.kernel.org/pub/scm/linux/kernel/git/kdave/linux
diff --git a/drivers/iommu/amd_iommu.c b/drivers/iommu/amd_iommu.c
index 61de81965c44ed95b52ef857b4f31a4cb8519a28..dd555078258c437c375902538ecc2f9040efd7cc 100644
--- a/drivers/iommu/amd_iommu.c
+++ b/drivers/iommu/amd_iommu.c
@@ -70,7 +70,6 @@
  */
 #define AMD_IOMMU_PGSIZES      ((~0xFFFUL) & ~(2ULL << 38))
 
-static DEFINE_SPINLOCK(amd_iommu_devtable_lock);
 static DEFINE_SPINLOCK(pd_bitmap_lock);
 
 /* List of all available dev_data structures */
@@ -202,6 +201,7 @@ static struct iommu_dev_data *alloc_dev_data(u16 devid)
        if (!dev_data)
                return NULL;
 
+       spin_lock_init(&dev_data->lock);
        dev_data->devid = devid;
        ratelimit_default_init(&dev_data->rs);
 
@@ -436,7 +436,7 @@ static int iommu_init_device(struct device *dev)
         * invalid address), we ignore the capability for the device so
         * it'll be forced to go into translation mode.
         */
-       if ((iommu_pass_through || !amd_iommu_force_isolation) &&
+       if ((iommu_default_passthrough() || !amd_iommu_force_isolation) &&
            dev_is_pci(dev) && pci_iommuv2_capable(to_pci_dev(dev))) {
                struct amd_iommu *iommu;
 
@@ -501,6 +501,29 @@ static void iommu_uninit_device(struct device *dev)
         */
 }
 
+/*
+ * Helper function to get the first pte of a large mapping
+ */
+static u64 *first_pte_l7(u64 *pte, unsigned long *page_size,
+                        unsigned long *count)
+{
+       unsigned long pte_mask, pg_size, cnt;
+       u64 *fpte;
+
+       pg_size  = PTE_PAGE_SIZE(*pte);
+       cnt      = PAGE_SIZE_PTE_COUNT(pg_size);
+       pte_mask = ~((cnt << 3) - 1);
+       fpte     = (u64 *)(((unsigned long)pte) & pte_mask);
+
+       if (page_size)
+               *page_size = pg_size;
+
+       if (count)
+               *count = cnt;
+
+       return fpte;
+}
+
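
A note on the arithmetic in first_pte_l7(): a level-7 ("large") mapping is stored as PAGE_SIZE_PTE_COUNT(pg_size) identical 8-byte PTEs, and that group is naturally aligned inside the page table, so masking any member's address with ~((cnt << 3) - 1) yields the first copy. A standalone worked example with assumed, illustrative numbers (the real macro values are not reproduced here):

    #include <stdio.h>

    int main(void)
    {
            /* Assume a 32KiB level-7 mapping: 8 replicated 8-byte PTEs. */
            unsigned long cnt      = 8;                    /* PAGE_SIZE_PTE_COUNT() */
            unsigned long pte_mask = ~((cnt << 3) - 1UL);  /* == ~0x3f              */
            unsigned long pte      = 0xffff888001234568UL; /* &table[i], any copy   */
            unsigned long fpte     = pte & pte_mask;       /* -> ...540, first copy */

            printf("first pte of the group: %#lx\n", fpte);
            return 0;
    }
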
 /****************************************************************************
  *
  * Interrupt handling functions
@@ -560,7 +583,8 @@ static void iommu_print_event(struct amd_iommu *iommu, void *__evt)
 retry:
        type    = (event[1] >> EVENT_TYPE_SHIFT)  & EVENT_TYPE_MASK;
        devid   = (event[0] >> EVENT_DEVID_SHIFT) & EVENT_DEVID_MASK;
-       pasid   = PPR_PASID(*(u64 *)&event[0]);
+       pasid   = (event[0] & EVENT_DOMID_MASK_HI) |
+                 (event[1] & EVENT_DOMID_MASK_LO);
        flags   = (event[1] >> EVENT_FLAGS_SHIFT) & EVENT_FLAGS_MASK;
        address = (u64)(((u64)event[3]) << 32) | event[2];
 
@@ -593,7 +617,7 @@ static void iommu_print_event(struct amd_iommu *iommu, void *__evt)
                        address, flags);
                break;
        case EVENT_TYPE_PAGE_TAB_ERR:
-               dev_err(dev, "Event logged [PAGE_TAB_HARDWARE_ERROR device=%02x:%02x.%x domain=0x%04x address=0x%llx flags=0x%04x]\n",
+               dev_err(dev, "Event logged [PAGE_TAB_HARDWARE_ERROR device=%02x:%02x.%x pasid=0x%04x address=0x%llx flags=0x%04x]\n",
                        PCI_BUS_NUM(devid), PCI_SLOT(devid), PCI_FUNC(devid),
                        pasid, address, flags);
                break;
@@ -1311,8 +1335,12 @@ static void domain_flush_np_cache(struct protection_domain *domain,
                dma_addr_t iova, size_t size)
 {
        if (unlikely(amd_iommu_np_cache)) {
+               unsigned long flags;
+
+               spin_lock_irqsave(&domain->lock, flags);
                domain_flush_pages(domain, iova, size);
                domain_flush_complete(domain);
+               spin_unlock_irqrestore(&domain->lock, flags);
        }
 }
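
This hunk introduces a pattern that recurs throughout the patch: the flush/wait pair is now serialized by the domain lock, so a flush cannot interleave with a concurrent page-table update on the same protection domain. Sketched in isolation (identifiers as used in this driver):

    unsigned long flags;

    spin_lock_irqsave(&domain->lock, flags);
    domain_flush_pages(domain, iova, size); /* or domain_flush_tlb{,_pde}() */
    domain_flush_complete(domain);          /* wait until the IOMMU acks   */
    spin_unlock_irqrestore(&domain->lock, flags);

The same shape appears below in dma_ops_domain_flush_tlb(), __map_single(), __unmap_single() and amd_iommu_flush_iotlb_all().
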
 
@@ -1425,7 +1453,7 @@ static void free_pagetable(struct protection_domain *domain)
        BUG_ON(domain->mode < PAGE_MODE_NONE ||
               domain->mode > PAGE_MODE_6_LEVEL);
 
-       free_sub_pt(root, domain->mode, freelist);
+       freelist = free_sub_pt(root, domain->mode, freelist);
 
        free_page_list(freelist);
 }
@@ -1435,16 +1463,18 @@ static void free_pagetable(struct protection_domain *domain)
  * another level increases the size of the address space by 9 bits to a size up
  * to 64 bits.
  */
-static void increase_address_space(struct protection_domain *domain,
+static bool increase_address_space(struct protection_domain *domain,
+                                  unsigned long address,
                                   gfp_t gfp)
 {
        unsigned long flags;
+       bool ret = false;
        u64 *pte;
 
        spin_lock_irqsave(&domain->lock, flags);
 
-       if (WARN_ON_ONCE(domain->mode == PAGE_MODE_6_LEVEL))
-               /* address space already 64 bit large */
+       if (address <= PM_LEVEL_SIZE(domain->mode) ||
+           WARN_ON_ONCE(domain->mode == PAGE_MODE_6_LEVEL))
                goto out;
 
        pte = (void *)get_zeroed_page(gfp);
@@ -1455,19 +1485,21 @@ static void increase_address_space(struct protection_domain *domain,
                                        iommu_virt_to_phys(domain->pt_root));
        domain->pt_root  = pte;
        domain->mode    += 1;
-       domain->updated  = true;
+
+       ret = true;
 
 out:
        spin_unlock_irqrestore(&domain->lock, flags);
 
-       return;
+       return ret;
 }
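
increase_address_space() now re-tests address <= PM_LEVEL_SIZE(domain->mode) after taking the lock because its caller performs the same test without the lock; between the two, another CPU may already have added a level. A condensed sketch of the resulting check/lock/re-check pattern:

    while (address > PM_LEVEL_SIZE(domain->mode)) {   /* unlocked fast path */
            spin_lock_irqsave(&domain->lock, flags);
            if (address <= PM_LEVEL_SIZE(domain->mode)) {
                    /* lost the race; another CPU already grew the table */
                    spin_unlock_irqrestore(&domain->lock, flags);
                    break;
            }
            /* ... allocate a new root, domain->mode += 1 ... */
            spin_unlock_irqrestore(&domain->lock, flags);
    }

Returning bool, instead of setting the removed domain->updated flag, lets alloc_pte() below accumulate the result into its caller-provided *updated.
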
 
 static u64 *alloc_pte(struct protection_domain *domain,
                      unsigned long address,
                      unsigned long page_size,
                      u64 **pte_page,
-                     gfp_t gfp)
+                     gfp_t gfp,
+                     bool *updated)
 {
        int level, end_lvl;
        u64 *pte, *page;
@@ -1475,7 +1507,7 @@ static u64 *alloc_pte(struct protection_domain *domain,
        BUG_ON(!is_power_of_2(page_size));
 
        while (address > PM_LEVEL_SIZE(domain->mode))
-               increase_address_space(domain, gfp);
+               *updated = increase_address_space(domain, address, gfp) || *updated;
 
        level   = domain->mode - 1;
        pte     = &domain->pt_root[PM_LEVEL_INDEX(level, address)];
@@ -1489,9 +1521,32 @@ static u64 *alloc_pte(struct protection_domain *domain,
                __pte     = *pte;
                pte_level = PM_PTE_LEVEL(__pte);
 
-               if (!IOMMU_PTE_PRESENT(__pte) ||
+               /*
+                * If we replace a series of large PTEs, we need
+                * to tear down all of them.
+                */
+               if (IOMMU_PTE_PRESENT(__pte) &&
                    pte_level == PAGE_MODE_7_LEVEL) {
+                       unsigned long count, i;
+                       u64 *lpte;
+
+                       lpte = first_pte_l7(pte, NULL, &count);
+
+                       /*
+                        * Unmap the replicated PTEs that still match the
+                        * original large mapping
+                        */
+                       for (i = 0; i < count; ++i)
+                               cmpxchg64(&lpte[i], __pte, 0ULL);
+
+                       *updated = true;
+                       continue;
+               }
+
+               if (!IOMMU_PTE_PRESENT(__pte) ||
+                   pte_level == PAGE_MODE_NONE) {
                        page = (u64 *)get_zeroed_page(gfp);
+
                        if (!page)
                                return NULL;
 
@@ -1500,8 +1555,8 @@ static u64 *alloc_pte(struct protection_domain *domain,
                        /* pte could have been changed somewhere. */
                        if (cmpxchg64(pte, __pte, __npte) != __pte)
                                free_page((unsigned long)page);
-                       else if (pte_level == PAGE_MODE_7_LEVEL)
-                               domain->updated = true;
+                       else if (IOMMU_PTE_PRESENT(__pte))
+                               *updated = true;
 
                        continue;
                }
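
To make the teardown above concrete: because the replicated copies are bit-for-bit identical, cmpxchg64() can safely clear only the entries that still hold the original value, leaving anything a racing mapper has already rewritten intact; the continue then re-walks the table and observes the new state. An illustration with an assumed 32KiB mapping:

    /*
     * before: lpte[0..7] all hold __pte (one 32KiB level-7 mapping)
     *
     *     for (i = 0; i < count; ++i)        // count == 8 here
     *             cmpxchg64(&lpte[i], __pte, 0ULL);
     *
     * after:  every copy that still matched __pte is 0; a copy already
     *         replaced by a concurrent mapper is left untouched.
     */
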
@@ -1566,17 +1621,12 @@ static u64 *fetch_pte(struct protection_domain *domain,
                *page_size = PTE_LEVEL_PAGE_SIZE(level);
        }
 
-       if (PM_PTE_LEVEL(*pte) == 0x07) {
-               unsigned long pte_mask;
-
-               /*
-                * If we have a series of large PTEs, make
-                * sure to return a pointer to the first one.
-                */
-               *page_size = pte_mask = PTE_PAGE_SIZE(*pte);
-               pte_mask   = ~((PAGE_SIZE_PTE_COUNT(pte_mask) << 3) - 1);
-               pte        = (u64 *)(((unsigned long)pte) & pte_mask);
-       }
+       /*
+        * If we have a series of large PTEs, make
+        * sure to return a pointer to the first one.
+        */
+       if (PM_PTE_LEVEL(*pte) == PAGE_MODE_7_LEVEL)
+               pte = first_pte_l7(pte, page_size, NULL);
 
        return pte;
 }
@@ -1615,26 +1665,29 @@ static int iommu_map_page(struct protection_domain *dom,
                          gfp_t gfp)
 {
        struct page *freelist = NULL;
+       bool updated = false;
        u64 __pte, *pte;
-       int i, count;
+       int ret, i, count;
 
        BUG_ON(!IS_ALIGNED(bus_addr, page_size));
        BUG_ON(!IS_ALIGNED(phys_addr, page_size));
 
+       ret = -EINVAL;
        if (!(prot & IOMMU_PROT_MASK))
-               return -EINVAL;
+               goto out;
 
        count = PAGE_SIZE_PTE_COUNT(page_size);
-       pte   = alloc_pte(dom, bus_addr, page_size, NULL, gfp);
+       pte   = alloc_pte(dom, bus_addr, page_size, NULL, gfp, &updated);
 
+       ret = -ENOMEM;
        if (!pte)
-               return -ENOMEM;
+               goto out;
 
        for (i = 0; i < count; ++i)
                freelist = free_clear_pte(&pte[i], pte[i], freelist);
 
        if (freelist != NULL)
-               dom->updated = true;
+               updated = true;
 
        if (count > 1) {
                __pte = PAGE_SIZE_PTE(__sme_set(phys_addr), page_size);
@@ -1650,12 +1703,21 @@ static int iommu_map_page(struct protection_domain *dom,
        for (i = 0; i < count; ++i)
                pte[i] = __pte;
 
-       update_domain(dom);
+       ret = 0;
+
+out:
+       if (updated) {
+               unsigned long flags;
+
+               spin_lock_irqsave(&dom->lock, flags);
+               update_domain(dom);
+               spin_unlock_irqrestore(&dom->lock, flags);
+       }
 
        /* Everything flushed out, free pages now */
        free_page_list(freelist);
 
-       return 0;
+       return ret;
 }
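
Two things change shape here: the dirty state moves from the racy domain->updated field (deleted in the update_domain() hunk further down) onto the stack, and the device-table/TLB update runs only when something actually changed and only under the domain lock. A condensed sketch of the new flow:

    bool updated = false;

    pte = alloc_pte(dom, bus_addr, page_size, NULL, gfp, &updated);
    /* ... clear old PTEs and install the new ones; may set 'updated' ... */

    if (updated) {
            spin_lock_irqsave(&dom->lock, flags);
            update_domain(dom);     /* flush device table and TLB */
            spin_unlock_irqrestore(&dom->lock, flags);
    }
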
 
 static unsigned long iommu_unmap_page(struct protection_domain *dom,
@@ -1806,8 +1868,12 @@ static void free_gcr3_table(struct protection_domain *domain)
 
 static void dma_ops_domain_flush_tlb(struct dma_ops_domain *dom)
 {
+       unsigned long flags;
+
+       spin_lock_irqsave(&dom->domain.lock, flags);
        domain_flush_tlb(&dom->domain);
        domain_flush_complete(&dom->domain);
+       spin_unlock_irqrestore(&dom->domain.lock, flags);
 }
 
 static void iova_domain_flush_tlb(struct iova_domain *iovad)
@@ -2022,36 +2088,6 @@ static void do_detach(struct iommu_dev_data *dev_data)
        domain->dev_cnt                 -= 1;
 }
 
-/*
- * If a device is not yet associated with a domain, this function makes the
- * device visible in the domain
- */
-static int __attach_device(struct iommu_dev_data *dev_data,
-                          struct protection_domain *domain)
-{
-       int ret;
-
-       /* lock domain */
-       spin_lock(&domain->lock);
-
-       ret = -EBUSY;
-       if (dev_data->domain != NULL)
-               goto out_unlock;
-
-       /* Attach alias group root */
-       do_attach(dev_data, domain);
-
-       ret = 0;
-
-out_unlock:
-
-       /* ready */
-       spin_unlock(&domain->lock);
-
-       return ret;
-}
-
-
 static void pdev_iommuv2_disable(struct pci_dev *pdev)
 {
        pci_disable_ats(pdev);
@@ -2133,19 +2169,28 @@ static int attach_device(struct device *dev,
        unsigned long flags;
        int ret;
 
+       spin_lock_irqsave(&domain->lock, flags);
+
        dev_data = get_dev_data(dev);
 
+       spin_lock(&dev_data->lock);
+
+       ret = -EBUSY;
+       if (dev_data->domain != NULL)
+               goto out;
+
        if (!dev_is_pci(dev))
                goto skip_ats_check;
 
        pdev = to_pci_dev(dev);
        if (domain->flags & PD_IOMMUV2_MASK) {
+               ret = -EINVAL;
                if (!dev_data->passthrough)
-                       return -EINVAL;
+                       goto out;
 
                if (dev_data->iommu_v2) {
                        if (pdev_iommuv2_enable(pdev) != 0)
-                               return -EINVAL;
+                               goto out;
 
                        dev_data->ats.enabled = true;
                        dev_data->ats.qdep    = pci_ats_queue_depth(pdev);
@@ -2158,9 +2203,9 @@ static int attach_device(struct device *dev,
        }
 
 skip_ats_check:
-       spin_lock_irqsave(&amd_iommu_devtable_lock, flags);
-       ret = __attach_device(dev_data, domain);
-       spin_unlock_irqrestore(&amd_iommu_devtable_lock, flags);
+       ret = 0;
+
+       do_attach(dev_data, domain);
 
        /*
         * We might boot into a crash-kernel here. The crashed kernel
@@ -2169,23 +2214,14 @@ static int attach_device(struct device *dev,
         */
        domain_flush_tlb_pde(domain);
 
-       return ret;
-}
-
-/*
- * Removes a device from a protection domain (unlocked)
- */
-static void __detach_device(struct iommu_dev_data *dev_data)
-{
-       struct protection_domain *domain;
-
-       domain = dev_data->domain;
+       domain_flush_complete(domain);
 
-       spin_lock(&domain->lock);
+out:
+       spin_unlock(&dev_data->lock);
 
-       do_detach(dev_data);
+       spin_unlock_irqrestore(&domain->lock, flags);
 
-       spin_unlock(&domain->lock);
+       return ret;
 }
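
With __attach_device() and __detach_device() folded into their callers, attach_device() and detach_device() now share a single, consistent lock ordering, which is what lets amd_iommu_devtable_lock go away without introducing ABBA deadlocks between the two paths:

    /*
     * Lock ordering shared by attach_device() and detach_device():
     *
     *     spin_lock_irqsave(&domain->lock, flags);   // 1st: the domain
     *     spin_lock(&dev_data->lock);                // 2nd: the device
     *     ... do_attach() / do_detach() ...
     *     spin_unlock(&dev_data->lock);
     *     spin_unlock_irqrestore(&domain->lock, flags);
     */

dev_data->lock is the per-device spinlock initialized in the alloc_dev_data() hunk near the top of this patch.
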
 
 /*
@@ -2200,6 +2236,10 @@ static void detach_device(struct device *dev)
        dev_data = get_dev_data(dev);
        domain   = dev_data->domain;
 
+       spin_lock_irqsave(&domain->lock, flags);
+
+       spin_lock(&dev_data->lock);
+
        /*
         * First check if the device is still attached. It might already
         * be detached from its domain because the generic
@@ -2207,15 +2247,12 @@ static void detach_device(struct device *dev)
         * our alias handling.
         */
        if (WARN_ON(!dev_data->domain))
-               return;
+               goto out;
 
-       /* lock device table */
-       spin_lock_irqsave(&amd_iommu_devtable_lock, flags);
-       __detach_device(dev_data);
-       spin_unlock_irqrestore(&amd_iommu_devtable_lock, flags);
+       do_detach(dev_data);
 
        if (!dev_is_pci(dev))
-               return;
+               goto out;
 
        if (domain->flags & PD_IOMMUV2_MASK && dev_data->iommu_v2)
                pdev_iommuv2_disable(to_pci_dev(dev));
@@ -2223,6 +2260,11 @@ static void detach_device(struct device *dev)
                pci_disable_ats(to_pci_dev(dev));
 
        dev_data->ats.enabled = false;
+
+out:
+       spin_unlock(&dev_data->lock);
+
+       spin_unlock_irqrestore(&domain->lock, flags);
 }
 
 static int amd_iommu_add_device(struct device *dev)
@@ -2256,7 +2298,7 @@ static int amd_iommu_add_device(struct device *dev)
 
        BUG_ON(!dev_data);
 
-       if (iommu_pass_through || dev_data->iommu_v2)
+       if (dev_data->iommu_v2)
                iommu_request_dm_for_dev(dev);
 
        /* Domains are initialized for this device - have a look what we ended up with */
@@ -2354,15 +2396,10 @@ static void update_device_table(struct protection_domain *domain)
 
 static void update_domain(struct protection_domain *domain)
 {
-       if (!domain->updated)
-               return;
-
        update_device_table(domain);
 
        domain_flush_devices(domain);
        domain_flush_tlb_pde(domain);
-
-       domain->updated = false;
 }
 
 static int dir2prot(enum dma_data_direction direction)
@@ -2392,6 +2429,7 @@ static dma_addr_t __map_single(struct device *dev,
 {
        dma_addr_t offset = paddr & ~PAGE_MASK;
        dma_addr_t address, start, ret;
+       unsigned long flags;
        unsigned int pages;
        int prot = 0;
        int i;
@@ -2429,8 +2467,10 @@ static dma_addr_t __map_single(struct device *dev,
                iommu_unmap_page(&dma_dom->domain, start, PAGE_SIZE);
        }
 
+       spin_lock_irqsave(&dma_dom->domain.lock, flags);
        domain_flush_tlb(&dma_dom->domain);
        domain_flush_complete(&dma_dom->domain);
+       spin_unlock_irqrestore(&dma_dom->domain.lock, flags);
 
        dma_ops_free_iova(dma_dom, address, pages);
 
@@ -2459,8 +2499,12 @@ static void __unmap_single(struct dma_ops_domain *dma_dom,
        }
 
        if (amd_iommu_unmap_flush) {
+               unsigned long flags;
+
+               spin_lock_irqsave(&dma_dom->domain.lock, flags);
                domain_flush_tlb(&dma_dom->domain);
                domain_flush_complete(&dma_dom->domain);
+               spin_unlock_irqrestore(&dma_dom->domain.lock, flags);
                dma_ops_free_iova(dma_dom, dma_addr, pages);
        } else {
                pages = __roundup_pow_of_two(pages);
@@ -2577,7 +2621,9 @@ static int map_sg(struct device *dev, struct scatterlist *sglist,
 
                        bus_addr  = address + s->dma_address + (j << PAGE_SHIFT);
                        phys_addr = (sg_phys(s) & PAGE_MASK) + (j << PAGE_SHIFT);
-                       ret = iommu_map_page(domain, bus_addr, phys_addr, PAGE_SIZE, prot, GFP_ATOMIC);
+                       ret = iommu_map_page(domain, bus_addr, phys_addr,
+                                            PAGE_SIZE, prot,
+                                            GFP_ATOMIC | __GFP_NOWARN);
                        if (ret)
                                goto out_unmap;
 
@@ -2752,6 +2798,8 @@ static const struct dma_map_ops amd_iommu_dma_ops = {
        .map_sg         = map_sg,
        .unmap_sg       = unmap_sg,
        .dma_supported  = amd_iommu_dma_supported,
+       .mmap           = dma_common_mmap,
+       .get_sgtable    = dma_common_get_sgtable,
 };
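
dma_common_mmap() and dma_common_get_sgtable() are the generic helpers from kernel/dma/mapping.c; wiring them into amd_iommu_dma_ops makes the corresponding DMA-API entry points work for devices behind the IOMMU. A hypothetical driver-side use, assuming a buffer obtained from dma_alloc_coherent():

    /* Hypothetical: expose a coherent DMA buffer to user space. */
    void *cpu_addr;
    dma_addr_t handle;

    cpu_addr = dma_alloc_coherent(dev, size, &handle, GFP_KERNEL);
    if (!cpu_addr)
            return -ENOMEM;

    /* Dispatches to amd_iommu_dma_ops.mmap, i.e. dma_common_mmap(). */
    return dma_mmap_coherent(dev, vma, cpu_addr, handle, size);
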
 
 static int init_reserved_iova_ranges(void)
@@ -2835,7 +2883,7 @@ int __init amd_iommu_init_api(void)
 
 int __init amd_iommu_init_dma_ops(void)
 {
-       swiotlb        = (iommu_pass_through || sme_me_mask) ? 1 : 0;
+       swiotlb        = (iommu_default_passthrough() || sme_me_mask) ? 1 : 0;
        iommu_detected = 1;
 
        if (amd_iommu_unmap_flush)
@@ -2862,16 +2910,16 @@ static void cleanup_domain(struct protection_domain *domain)
        struct iommu_dev_data *entry;
        unsigned long flags;
 
-       spin_lock_irqsave(&amd_iommu_devtable_lock, flags);
+       spin_lock_irqsave(&domain->lock, flags);
 
        while (!list_empty(&domain->dev_list)) {
                entry = list_first_entry(&domain->dev_list,
                                         struct iommu_dev_data, list);
                BUG_ON(!entry->domain);
-               __detach_device(entry);
+               do_detach(entry);
        }
 
-       spin_unlock_irqrestore(&amd_iommu_devtable_lock, flags);
+       spin_unlock_irqrestore(&domain->lock, flags);
 }
 
 static void protection_domain_free(struct protection_domain *domain)
@@ -3085,7 +3133,8 @@ static int amd_iommu_map(struct iommu_domain *dom, unsigned long iova,
 }
 
 static size_t amd_iommu_unmap(struct iommu_domain *dom, unsigned long iova,
-                          size_t page_size)
+                             size_t page_size,
+                             struct iommu_iotlb_gather *gather)
 {
        struct protection_domain *domain = to_pdomain(dom);
        size_t unmap_size;
@@ -3221,14 +3270,18 @@ static bool amd_iommu_is_attach_deferred(struct iommu_domain *domain,
 static void amd_iommu_flush_iotlb_all(struct iommu_domain *domain)
 {
        struct protection_domain *dom = to_pdomain(domain);
+       unsigned long flags;
 
+       spin_lock_irqsave(&dom->lock, flags);
        domain_flush_tlb_pde(dom);
        domain_flush_complete(dom);
+       spin_unlock_irqrestore(&dom->lock, flags);
 }
 
-static void amd_iommu_iotlb_range_add(struct iommu_domain *domain,
-                                     unsigned long iova, size_t size)
+static void amd_iommu_iotlb_sync(struct iommu_domain *domain,
+                                struct iommu_iotlb_gather *gather)
 {
+       amd_iommu_flush_iotlb_all(domain);
 }
 
 const struct iommu_ops amd_iommu_ops = {
@@ -3249,8 +3302,7 @@ const struct iommu_ops amd_iommu_ops = {
        .is_attach_deferred = amd_iommu_is_attach_deferred,
        .pgsize_bitmap  = AMD_IOMMU_PGSIZES,
        .flush_iotlb_all = amd_iommu_flush_iotlb_all,
-       .iotlb_range_add = amd_iommu_iotlb_range_add,
-       .iotlb_sync = amd_iommu_flush_iotlb_all,
+       .iotlb_sync = amd_iommu_iotlb_sync,
 };
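
The unmap()/iotlb_sync() pair follows the iommu_iotlb_gather interface: callers may batch several unmaps and pay for a single flush at sync time. Roughly how the IOMMU core drives these callbacks (a sketch of the 5.4-era API):

    struct iommu_iotlb_gather gather;

    iommu_iotlb_gather_init(&gather);
    ops->unmap(domain, iova, size, &gather);   /* possibly many times     */
    ops->iotlb_sync(domain, &gather);          /* one flush for the batch */

This driver's iotlb_sync ignores the gathered range and simply performs a full flush via amd_iommu_flush_iotlb_all().
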
 
 /*****************************************************************************
@@ -3285,7 +3337,6 @@ void amd_iommu_domain_direct_map(struct iommu_domain *dom)
 
        /* Update data structure */
        domain->mode    = PAGE_MODE_NONE;
-       domain->updated = true;
 
        /* Make changes visible to IOMMUs */
        update_domain(domain);
@@ -3331,7 +3382,6 @@ int amd_iommu_domain_enable_v2(struct iommu_domain *dom, int pasids)
 
        domain->glx      = levels;
        domain->flags   |= PD_IOMMUV2_MASK;
-       domain->updated  = true;
 
        update_domain(domain);
 
@@ -4343,13 +4393,62 @@ static const struct irq_domain_ops amd_ir_domain_ops = {
        .deactivate = irq_remapping_deactivate,
 };
 
+int amd_iommu_activate_guest_mode(void *data)
+{
+       struct amd_ir_data *ir_data = (struct amd_ir_data *)data;
+       struct irte_ga *entry = (struct irte_ga *) ir_data->entry;
+
+       if (!AMD_IOMMU_GUEST_IR_VAPIC(amd_iommu_guest_ir) ||
+           !entry || entry->lo.fields_vapic.guest_mode)
+               return 0;
+
+       entry->lo.val = 0;
+       entry->hi.val = 0;
+
+       entry->lo.fields_vapic.guest_mode  = 1;
+       entry->lo.fields_vapic.ga_log_intr = 1;
+       entry->hi.fields.ga_root_ptr       = ir_data->ga_root_ptr;
+       entry->hi.fields.vector            = ir_data->ga_vector;
+       entry->lo.fields_vapic.ga_tag      = ir_data->ga_tag;
+
+       return modify_irte_ga(ir_data->irq_2_irte.devid,
+                             ir_data->irq_2_irte.index, entry, NULL);
+}
+EXPORT_SYMBOL(amd_iommu_activate_guest_mode);
+
+int amd_iommu_deactivate_guest_mode(void *data)
+{
+       struct amd_ir_data *ir_data = (struct amd_ir_data *)data;
+       struct irte_ga *entry = (struct irte_ga *) ir_data->entry;
+       struct irq_cfg *cfg = ir_data->cfg;
+
+       if (!AMD_IOMMU_GUEST_IR_VAPIC(amd_iommu_guest_ir) ||
+           !entry || !entry->lo.fields_vapic.guest_mode)
+               return 0;
+
+       entry->lo.val = 0;
+       entry->hi.val = 0;
+
+       entry->lo.fields_remap.dm          = apic->irq_dest_mode;
+       entry->lo.fields_remap.int_type    = apic->irq_delivery_mode;
+       entry->hi.fields.vector            = cfg->vector;
+       entry->lo.fields_remap.destination =
+                               APICID_TO_IRTE_DEST_LO(cfg->dest_apicid);
+       entry->hi.fields.destination =
+                               APICID_TO_IRTE_DEST_HI(cfg->dest_apicid);
+
+       return modify_irte_ga(ir_data->irq_2_irte.devid,
+                             ir_data->irq_2_irte.index, entry, NULL);
+}
+EXPORT_SYMBOL(amd_iommu_deactivate_guest_mode);
+
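
The two exports let an external user (KVM's AVIC code is the intended one) switch an IRTE between ordinary remapped delivery and guest ("posted") mode without open-coding the IRTE layout. The expected sequence, mirroring what amd_ir_set_vcpu_affinity() does below:

    /* Program the guest-vAPIC fields, then flip the IRTE to guest mode. */
    ir_data->ga_root_ptr = (pi_data->base >> 12);   /* vAPIC backing page */
    ir_data->ga_vector   = vcpu_pi_info->vector;
    ir_data->ga_tag      = pi_data->ga_tag;
    ret = amd_iommu_activate_guest_mode(ir_data);

    /* ... and later, to fall back to legacy remapped delivery: */
    ret = amd_iommu_deactivate_guest_mode(ir_data);
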
 static int amd_ir_set_vcpu_affinity(struct irq_data *data, void *vcpu_info)
 {
+       int ret;
        struct amd_iommu *iommu;
        struct amd_iommu_pi_data *pi_data = vcpu_info;
        struct vcpu_data *vcpu_pi_info = pi_data->vcpu_data;
        struct amd_ir_data *ir_data = data->chip_data;
-       struct irte_ga *irte = (struct irte_ga *) ir_data->entry;
        struct irq_2_irte *irte_info = &ir_data->irq_2_irte;
        struct iommu_dev_data *dev_data = search_dev_data(irte_info->devid);
 
@@ -4360,6 +4459,7 @@ static int amd_ir_set_vcpu_affinity(struct irq_data *data, void *vcpu_info)
        if (!dev_data || !dev_data->use_vapic)
                return 0;
 
+       ir_data->cfg = irqd_cfg(data);
        pi_data->ir_data = ir_data;
 
        /* Note:
@@ -4378,37 +4478,24 @@ static int amd_ir_set_vcpu_affinity(struct irq_data *data, void *vcpu_info)
 
        pi_data->prev_ga_tag = ir_data->cached_ga_tag;
        if (pi_data->is_guest_mode) {
-               /* Setting */
-               irte->hi.fields.ga_root_ptr = (pi_data->base >> 12);
-               irte->hi.fields.vector = vcpu_pi_info->vector;
-               irte->lo.fields_vapic.ga_log_intr = 1;
-               irte->lo.fields_vapic.guest_mode = 1;
-               irte->lo.fields_vapic.ga_tag = pi_data->ga_tag;
-
-               ir_data->cached_ga_tag = pi_data->ga_tag;
+               ir_data->ga_root_ptr = (pi_data->base >> 12);
+               ir_data->ga_vector = vcpu_pi_info->vector;
+               ir_data->ga_tag = pi_data->ga_tag;
+               ret = amd_iommu_activate_guest_mode(ir_data);
+               if (!ret)
+                       ir_data->cached_ga_tag = pi_data->ga_tag;
        } else {
-               /* Un-Setting */
-               struct irq_cfg *cfg = irqd_cfg(data);
-
-               irte->hi.val = 0;
-               irte->lo.val = 0;
-               irte->hi.fields.vector = cfg->vector;
-               irte->lo.fields_remap.guest_mode = 0;
-               irte->lo.fields_remap.destination =
-                               APICID_TO_IRTE_DEST_LO(cfg->dest_apicid);
-               irte->hi.fields.destination =
-                               APICID_TO_IRTE_DEST_HI(cfg->dest_apicid);
-               irte->lo.fields_remap.int_type = apic->irq_delivery_mode;
-               irte->lo.fields_remap.dm = apic->irq_dest_mode;
+               ret = amd_iommu_deactivate_guest_mode(ir_data);
 
                /*
                 * This communicates the ga_tag back to the caller
                 * so that it can do all the necessary clean up.
                 */
-               ir_data->cached_ga_tag = 0;
+               if (!ret)
+                       ir_data->cached_ga_tag = 0;
        }
 
-       return modify_irte_ga(irte_info->devid, irte_info->index, irte, ir_data);
+       return ret;
 }