// SPDX-License-Identifier: GPL-2.0
/*
 * Loongson IOMMU Driver
 *
 * Copyright (C) 2020-2021 Loongson Technology Ltd.
 * Author:	Lv Chen <lvchen@loongson.cn>
 *		Wang Yang <wangyang@loongson.cn>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 */

#include <linux/printk.h>
#include <linux/device.h>
#include <linux/errno.h>
#include <linux/list.h>
#include <linux/spinlock.h>
#include <linux/iommu.h>
#include <linux/sizes.h>
#include <asm/addrspace.h>
#include <asm/mach-la64/mem.h>
#include <linux/delay.h>
#include <linux/slab.h>
#include <linux/pci.h>
#include <linux/acpi.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/io.h>
#include <linux/interrupt.h>
#include <linux/err.h>
#include <linux/pci_regs.h>
#include "loongarch_iommu.h"

/* Maximum number of 1us polls while waiting for the IOTLB flush to finish. */
#define LOOP_TIMEOUT			100000

/*
 * ACPI IVRS/IVHD parsing constants. They are not referenced in the visible
 * part of this file; names and values look like they mirror the AMD IVRS
 * layout (NOTE(review): confirm against the table parser that uses them).
 */
#define IVRS_HEADER_LENGTH		48
#define ACPI_IVHD_TYPE_MAX_SUPPORTED	0x40
/* IVHD device-entry type codes (presumably stored in struct ivhd_entry::type). */
#define IVHD_DEV_ALL                    0x01
#define IVHD_DEV_SELECT                 0x02
#define IVHD_DEV_SELECT_RANGE_START     0x03
#define IVHD_DEV_RANGE_END              0x04
#define IVHD_DEV_ALIAS                  0x42
#define IVHD_DEV_EXT_SELECT             0x46
#define IVHD_DEV_ACPI_HID		0xf0

/* IVHD block types (presumably stored in struct ivhd_header::type). */
#define IVHD_HEAD_TYPE10		0x10
#define IVHD_HEAD_TYPE11		0x11
#define IVHD_HEAD_TYPE40		0x40

/* Largest 16-bit requester id (bus << 8 | devfn). */
#define MAX_BDF_NUM			0xffff

/* One reverse-lookup slot holds a single pointer (a struct loongarch_iommu *). */
#define RLOOKUP_TABLE_ENTRY_SIZE	(sizeof(void *))

/*
 * Structure describing one IOMMU in the ACPI table. Typically followed by one
 * or more struct ivhd_entry records.
 *
 * __packed removes all padding, presumably so the struct can be overlaid on
 * raw firmware-table bytes (NOTE(review): the parser that casts table memory
 * to this type is not in view — confirm).
 */
struct ivhd_header {
	u8 type;	/* presumably one of IVHD_HEAD_TYPE* */
	u8 flags;
	u16 length;	/* presumably block length incl. entries — confirm */
	u16 devid;
	u16 cap_ptr;
	u64 mmio_phys;	/* presumably physical base of the register window */
	u16 pci_seg;	/* presumably the PCI segment (domain) number */
	u16 info;
	u32 efr_attr;

	/* Following only valid on IVHD type 11h and 40h */
	u64 efr_reg; /* Exact copy of MMIO_EXT_FEATURES */
	u64 res;
} __packed;

/*
 * A device entry describing which devices a specific IOMMU translates and
 * which requestor ids they use. __packed: no padding between fields.
 */
struct ivhd_entry {
	u8 type;	/* presumably one of IVHD_DEV_* */
	u16 devid;
	u8 flags;
	u32 ext;
	/* The fields below look like ACPI HID/CID/UID data — presumably only
	 * meaningful for IVHD_DEV_ACPI_HID entries; confirm in the parser.
	 */
	u32 hidh;
	u64 cid;
	u8 uidf;
	u8 uidl;
	u8 uid;
} __packed;

/* Reverse-lookup tables, one struct la_iommu_rlookup_entry per PCI segment. */
LIST_HEAD(la_rlookup_iommu_list);
LIST_HEAD(la_iommu_list);			/* list of all loongarch
						 * IOMMUs in the system
						 */

static u32 rlookup_table_size;			/* size of the rlookup table */
/* NOTE(review): not used in the visible part of the file; presumably the
 * IVHD_HEAD_TYPE* selected while parsing IVRS.
 */
static int la_iommu_target_ivhd_type;
u16	la_iommu_last_bdf;			/* largest PCI device id
						 *  we have to handle
						 */

/* Tentative definition; the initialised la_iommu_ops appears later. */
static struct iommu_ops la_iommu_ops;
/* NOTE(review): not read in the visible part of the file; presumably a
 * global "driver disabled" switch set elsewhere.
 */
int loongarch_iommu_disable;

/*
 * iommu_write_regl - write a 32-bit IOMMU register
 * @iommu: IOMMU whose register window is mapped at iommu->membase
 * @off:   byte offset of the register inside the window
 * @val:   value to store
 *
 * The store goes through a volatile pointer so the compiler must emit exactly
 * one 32-bit access; a plain dereference may legally be merged, torn or moved
 * relative to earlier accesses. The trailing full barrier orders this register
 * write before every later memory access (e.g. the next register read).
 */
static void iommu_write_regl(struct loongarch_iommu *iommu,
		unsigned long off, u32 val)
{
	*(volatile u32 *)(iommu->membase + off) = val;
	/* Complete the register write before any later load or store. */
	mb();
}

/*
 * iommu_read_regl - read a 32-bit IOMMU register
 * @iommu: IOMMU whose register window is mapped at iommu->membase
 * @off:   byte offset of the register inside the window
 *
 * The load goes through a volatile pointer so the compiler emits exactly one
 * 32-bit access. The full barrier afterwards keeps later memory accesses
 * (including the next register access) from being performed before it.
 *
 * Return: the register value.
 */
static u32 iommu_read_regl(struct loongarch_iommu *iommu, unsigned long off)
{
	u32 val;

	val = *(volatile u32 *)(iommu->membase + off);
	/* Complete the register read before any later load or store. */
	mb();
	return val;
}

/*
 * iommu_translate_disable - clear the enable bit of the IOMMU
 * @iommu: IOMMU to operate on (NULL is reported and ignored)
 *
 * Clears bit 31 of LA_IOMMU_EIVDB, then writes LA_IOMMU_CMD back with bits
 * [1:0] cleared (presumably what makes the hardware act on EIVDB — confirm
 * against the hardware manual).
 */
static void iommu_translate_disable(struct loongarch_iommu *iommu)
{
	u32 val;

	if (iommu == NULL) {
		pr_err("%s iommu is NULL", __func__);
		return;
	}

	/* Disable; 1U avoids shifting into the sign bit of a signed int. */
	val = iommu_read_regl(iommu, LA_IOMMU_EIVDB);
	val &= ~(1U << 31);
	iommu_write_regl(iommu, LA_IOMMU_EIVDB, val);

	/* Write cmd */
	val = iommu_read_regl(iommu, LA_IOMMU_CMD);
	val &= 0xfffffffc;
	iommu_write_regl(iommu, LA_IOMMU_CMD, val);
}

/*
 * iommu_translate_enable - set the enable bit of the IOMMU
 * @iommu: IOMMU to operate on (NULL is reported and ignored)
 *
 * Sets bit 31 of LA_IOMMU_EIVDB, then writes LA_IOMMU_CMD back with bits
 * [1:0] cleared (presumably what makes the hardware act on EIVDB — confirm
 * against the hardware manual).
 */
static void iommu_translate_enable(struct loongarch_iommu *iommu)
{
	u32 val;

	if (iommu == NULL) {
		pr_err("%s iommu is NULL", __func__);
		return;
	}

	/* Enable; 1U avoids shifting into the sign bit of a signed int. */
	val = iommu_read_regl(iommu, LA_IOMMU_EIVDB);
	val |= (1U << 31);
	iommu_write_regl(iommu, LA_IOMMU_EIVDB, val);

	/* Write cmd */
	val = iommu_read_regl(iommu, LA_IOMMU_CMD);
	val &= 0xfffffffc;
	iommu_write_regl(iommu, LA_IOMMU_CMD, val);
}

/*
 * la_iommu_capable - iommu_ops->capable callback
 * @cap: capability queried by the IOMMU core
 *
 * Only IOMMU_CAP_CACHE_COHERENCY is advertised; every other capability is
 * reported as unsupported.
 */
static bool la_iommu_capable(enum iommu_cap cap)
{
	return cap == IOMMU_CAP_CACHE_COHERENCY;
}

/*
 * to_dom_info - get the driver-private domain that embeds @dom
 * @dom: the "domain" member of a struct dom_info
 */
static struct dom_info *to_dom_info(struct iommu_domain *dom)
{
	struct dom_info *info = container_of(dom, struct dom_info, domain);

	return info;
}

/*
 * has_dom - check whether any domain is registered on an IOMMU
 * @iommu: IOMMU whose dom_list is inspected
 *
 * Return: 1 if iommu->dom_list is non-empty, 0 otherwise. The list is read
 * under dom_info_lock, but the answer may be stale once the lock is dropped.
 */
static int has_dom(struct loongarch_iommu *iommu)
{
	int present;

	spin_lock(&iommu->dom_info_lock);
	present = !list_empty(&iommu->dom_list);
	spin_unlock(&iommu->dom_info_lock);

	return present;
}

/*
 * update_dev_table - install or remove a device-table entry in the IOMMU
 * @dev_data: per-device data; ->iommu, ->iommu_entry and ->bdf must be set
 * @flag:     non-zero to install an entry, zero to remove dev_data->index
 *
 * Install: take a free slot from iommu->devtable_bitmap, record it in
 * dev_data->index and write LA_IOMMU_EIVDB with the BDF (bits 15:0), domain
 * id (19:16), valid (20), slot index (27:24) and enable (31).
 * Remove: clear EIVDB bits 30:0 (bit 31 is preserved), write the slot index
 * into bits 27:24 and release the slot.
 * Either way LA_IOMMU_CMD is then written back with bits [1:0] cleared.
 *
 * A failed install, or a completed removal, leaves dev_data->index at
 * MAX_ATTACHED_DEV_ID. A later removal then does not touch, or free, a slot
 * that belongs to another device.
 *
 * Return: always 0; failures are only logged.
 *
 * NOTE(review): domain id and slot index are masked to 4 bits; confirm that
 * MAX_DOMAIN_ID and MAX_ATTACHED_DEV_ID never exceed 16. The devtable bitmap
 * is modified without a lock here — confirm callers serialise this.
 */
static int update_dev_table(struct la_iommu_dev_data *dev_data, int flag)
{
	u32 val = 0;
	int index;
	unsigned short bdf;
	struct loongarch_iommu *iommu;
	u16 domain_id;

	if (dev_data == NULL) {
		pr_err("%s dev_data is NULL", __func__);
		return 0;
	}

	if (dev_data->iommu == NULL) {
		pr_err("%s iommu is NULL", __func__);
		return 0;
	}

	if (dev_data->iommu_entry == NULL) {
		pr_err("%s iommu_entry is NULL", __func__);
		return 0;
	}

	iommu = dev_data->iommu;
	domain_id = dev_data->iommu_entry->id;
	bdf = dev_data->bdf;

	if (flag) {
		/* Install: claim a free device-table slot. */
		index = find_first_zero_bit(iommu->devtable_bitmap,
						MAX_ATTACHED_DEV_ID);
		if (index >= MAX_ATTACHED_DEV_ID) {
			pr_err("%s get id from dev table failed\n", __func__);
			/* Mark "no slot" so a later removal is a no-op. */
			dev_data->index = MAX_ATTACHED_DEV_ID;
			return 0;
		}
		__set_bit(index, iommu->devtable_bitmap);
		dev_data->index = index;

		pr_info(
		"%s bdf %x domain_id %d index %x iommu segment %d flag %x\n",
				__func__, bdf, domain_id, index,
				iommu->segment, flag);

		val = bdf & 0xffff;
		val |= ((domain_id & 0xf) << 16);	/* domain id */
		val |= ((index & 0xf) << 24);		/* index */
		val |= (0x1U << 20);			/* valid */
		val |= (0x1U << 31);			/* enable */
		iommu_write_regl(iommu, LA_IOMMU_EIVDB, val);
	} else {
		/* Flush device table */
		index = dev_data->index;
		pr_info(
		"%s bdf %x domain_id %d index %x iommu segment %d flag %x\n",
				__func__, bdf, domain_id, index,
				iommu->segment, flag);

		/* Nothing was installed, or it was already removed. */
		if (index < 0 || index >= MAX_ATTACHED_DEV_ID)
			return 0;

		val = iommu_read_regl(iommu, LA_IOMMU_EIVDB);
		val &= ~(0x7fffffff);
		val |= ((index & 0xf) << 24);		/* index */
		iommu_write_regl(iommu, LA_IOMMU_EIVDB, val);
	}

	/* Issue the command: write CMD back with bits [1:0] cleared. */
	val = iommu_read_regl(iommu, LA_IOMMU_CMD);
	val &= 0xfffffffc;
	iommu_write_regl(iommu, LA_IOMMU_CMD, val);

	if (!flag) {
		__clear_bit(index, iommu->devtable_bitmap);
		dev_data->index = MAX_ATTACHED_DEV_ID;
	}

	return 0;
}

/*
 * flush_iotlb - start a flush of the whole IOTLB
 * @iommu: IOMMU to flush (NULL is reported and ignored)
 *
 * Replaces the low five bits of LA_IOMMU_VBTC with 0x5 (the "flush all"
 * request), then writes LA_IOMMU_CMD back with bits [1:0] cleared. Does not
 * wait for completion; see flush_pgtable_is_busy().
 */
static void flush_iotlb(struct loongarch_iommu *iommu)
{
	u32 vbtc, cmd;

	if (!iommu) {
		pr_err("%s iommu is NULL", __func__);
		return;
	}

	vbtc = (iommu_read_regl(iommu, LA_IOMMU_VBTC) & ~0x1f) | 0x5;
	iommu_write_regl(iommu, LA_IOMMU_VBTC, vbtc);

	cmd = iommu_read_regl(iommu, LA_IOMMU_CMD) & 0xfffffffc;
	iommu_write_regl(iommu, LA_IOMMU_CMD, cmd);
}

/*
 * flush_pgtable_is_busy - poll the busy flag of a pending IOTLB flush
 * @iommu: IOMMU to query (NULL is reported and treated as "not busy")
 *
 * Return: non-zero while IOMMU_PGTABLE_BUSY is set in LA_IOMMU_VBTC.
 */
static int flush_pgtable_is_busy(struct loongarch_iommu *iommu)
{
	if (!iommu) {
		pr_err("%s iommu is NULL", __func__);
		return 0;
	}

	return iommu_read_regl(iommu, LA_IOMMU_VBTC) & IOMMU_PGTABLE_BUSY;
}

/*
 * loongarch_iommu_flush_iotlb_all - flush the whole IOTLB and wait for it
 * @iommu: IOMMU to flush (NULL is reported and treated as success)
 *
 * Starts a full flush, then polls the busy flag once per microsecond for at
 * most LOOP_TIMEOUT retries. When the flush completes, translation is
 * (re-)enabled.
 *
 * Return: 0 on success, -EIO if the busy flag never cleared.
 */
static int loongarch_iommu_flush_iotlb_all(struct loongarch_iommu *iommu)
{
	u32 tries;

	if (!iommu) {
		pr_err("%s iommu is NULL", __func__);
		return 0;
	}

	flush_iotlb(iommu);

	for (tries = 0; flush_pgtable_is_busy(iommu); tries++) {
		if (tries == LOOP_TIMEOUT) {
			pr_err("LA-IOMMU: iotlb flush busy\n");
			return -EIO;
		}
		udelay(1);
	}

	iommu_translate_enable(iommu);
	return 0;
}

/*
 * priv_flush_iotlb_pde - flush the whole IOTLB of @iommu
 * @iommu: IOMMU to flush (NULL is reported and ignored)
 *
 * Any timeout is only logged by loongarch_iommu_flush_iotlb_all().
 */
static void priv_flush_iotlb_pde(struct loongarch_iommu *iommu)
{
	if (iommu)
		loongarch_iommu_flush_iotlb_all(iommu);
	else
		pr_err("%s iommu is NULL", __func__);
}

/*
 * do_attach - link a device into an iommu_info and program the device table
 * @info:     per-IOMMU state of the target domain
 * @dev_data: device to attach; nothing happens while its count is 0
 *
 * The device is added to info->dev_list under info->devlock. A device-table
 * entry is then installed and the IOTLB of the device's IOMMU is flushed.
 */
static void do_attach(struct iommu_info *info, struct la_iommu_dev_data *dev_data)
{
	if (dev_data->count == 0)
		return;

	dev_data->iommu_entry = info;

	spin_lock(&info->devlock);
	list_add(&dev_data->list, &info->dev_list);
	info->dev_cnt++;
	spin_unlock(&info->devlock);

	update_dev_table(dev_data, 1);

	/* dev_cnt was just incremented, so normally this flush always runs. */
	if (info->dev_cnt > 0)
		priv_flush_iotlb_pde(dev_data->iommu);
}

/*
 * do_detach - unlink a device from its iommu_info and clear its table entry
 * @dev_data: device to detach; nothing happens while its count is non-zero
 *
 * Both callers in this file hold dev_data->iommu_entry->devlock around this
 * call. On success dev_data->iommu_entry is reset to NULL.
 */
static void do_detach(struct la_iommu_dev_data *dev_data)
{
	struct iommu_info *entry;

	if (!dev_data) {
		pr_err("%s dev_data is NULL", __func__);
		return;
	}

	/* Other attach references are still outstanding. */
	if (dev_data->count)
		return;

	entry = dev_data->iommu_entry;
	if (!entry) {
		pr_err("%s iommu_entry is NULL", __func__);
		return;
	}

	list_del(&dev_data->list);
	entry->dev_cnt--;

	update_dev_table(dev_data, 0);
	dev_data->iommu_entry = NULL;
}

static void cleanup_iommu_entry(struct iommu_info *iommu_entry)
{
	struct la_iommu_dev_data *dev_data = NULL;

	spin_lock(&iommu_entry->devlock);
	while (!list_empty(&iommu_entry->dev_list)) {
		dev_data = list_first_entry(&iommu_entry->dev_list,
				struct la_iommu_dev_data, list);
		do_detach(dev_data);
	}
	spin_unlock(&iommu_entry->devlock);
}

/*
 * domain_id_alloc - reserve a free hardware domain id on @iommu
 * @iommu: IOMMU whose domain_bitmap is used
 *
 * Return: an id in [0, MAX_DOMAIN_ID), or -1 if @iommu is NULL or every id
 * is in use. Before this fix, exhaustion returned the out-of-range value from
 * find_first_zero_bit(), which callers comparing against -1 accepted as a
 * valid id.
 */
static int domain_id_alloc(struct loongarch_iommu *iommu)
{
	int id;

	if (iommu == NULL) {
		pr_err("%s iommu is NULL", __func__);
		return -1;
	}

	spin_lock(&iommu->domain_bitmap_lock);
	id = find_first_zero_bit(iommu->domain_bitmap,
			MAX_DOMAIN_ID);
	if (id < MAX_DOMAIN_ID) {
		__set_bit(id, iommu->domain_bitmap);
	} else {
		pr_err("LA-IOMMU: Alloc domain id over max domain id\n");
		id = -1;
	}
	spin_unlock(&iommu->domain_bitmap_lock);

	return id;
}

/*
 * domain_id_free - return a domain id to @iommu's pool
 * @iommu: IOMMU that owns the id (NULL is reported and ignored)
 * @id:    id to release; values outside [0, MAX_DOMAIN_ID) are ignored
 */
static void domain_id_free(struct loongarch_iommu *iommu, int id)
{
	if (!iommu) {
		pr_err("%s iommu is NULL", __func__);
		return;
	}

	spin_lock(&iommu->domain_bitmap_lock);
	if (id >= 0 && id < MAX_DOMAIN_ID)
		__clear_bit(id, iommu->domain_bitmap);
	spin_unlock(&iommu->domain_bitmap_lock);
}

/*
 * add_domain_to_list - register @priv on @iommu's domain list
 * @iommu: IOMMU whose dom_list receives the domain (the list is per IOMMU,
 *         and it is what has_dom() inspects)
 * @priv:  domain to register, linked through priv->list
 *
 * NOTE(review): priv->list is a single node, but domain_attach_iommu() calls
 * this once for every IOMMU the domain gets an iommu_info for. Attaching one
 * domain to two IOMMUs would link the same node into two lists — confirm
 * that cannot happen.
 */
static void add_domain_to_list(struct loongarch_iommu *iommu, struct dom_info *priv)
{
	spin_lock(&iommu->dom_info_lock);
	list_add(&priv->list, &iommu->dom_list);
	spin_unlock(&iommu->dom_info_lock);
}

/*
 * del_domain_from_list - unlink @priv from @iommu's domain list
 * @iommu: IOMMU whose dom_list currently holds @priv
 * @priv:  domain previously registered with add_domain_to_list()
 */
static void del_domain_from_list(struct loongarch_iommu *iommu, struct dom_info *priv)
{
	spin_lock(&iommu->dom_info_lock);
	/* dom_info_lock serialises every change to iommu->dom_list. */
	list_del(&priv->list);
	spin_unlock(&iommu->dom_info_lock);
}

/*
 * iommu_zalloc_page - allocate one zeroed page-table node
 * @iommu: IOMMU whose page-table pool (pgtbase/pgtable_bitmap) is used
 *
 * A node has two halves: a zeroed kernel page that serves as the shadow table
 * walked by the driver, and a zeroed IOMMU_PAGE_SIZE slot inside the
 * iommu->pgtbase pool that serves as the hardware table.
 *
 * Return: the new node, or NULL if the pool is exhausted or an allocation
 * failed. In the latter case the reserved pool slot is released again.
 */
static struct spt_entry *iommu_zalloc_page(struct loongarch_iommu *iommu)
{
	struct spt_entry *node;
	void *hw_table;
	int slot;

	spin_lock(&iommu->pgtable_bitmap_lock);
	slot = find_first_zero_bit(iommu->pgtable_bitmap, iommu->maxpages);
	if (slot < iommu->maxpages)
		__set_bit(slot, iommu->pgtable_bitmap);
	spin_unlock(&iommu->pgtable_bitmap_lock);

	/* Pool exhausted: nothing was reserved. */
	if (!(slot < iommu->maxpages))
		return NULL;

	node = kmalloc(sizeof(*node), GFP_KERNEL);
	if (!node)
		goto release_slot;

	node->shadow_ptable = (unsigned long *)get_zeroed_page(GFP_KERNEL);
	if (!node->shadow_ptable) {
		pr_err("LA-IOMMU: get zeroed page err\n");
		kfree(node);
		goto release_slot;
	}

	hw_table = iommu->pgtbase + slot * IOMMU_PAGE_SIZE;
	memset(hw_table, 0x0, IOMMU_PAGE_SIZE);
	node->index = slot;
	node->gmem_ptable = hw_table;

	return node;

release_slot:
	spin_lock(&iommu->pgtable_bitmap_lock);
	__clear_bit(slot, iommu->pgtable_bitmap);
	spin_unlock(&iommu->pgtable_bitmap_lock);
	return NULL;
}

/*
 * iommu_free_page - release a node obtained from iommu_zalloc_page()
 * @iommu:       IOMMU owning the page-table pool
 * @shadw_entry: node to free; ignored unless its index is inside the pool
 *
 * The hardware slot is zeroed and returned to the pool, the shadow page is
 * freed, and the node itself is kfree()d.
 */
static void iommu_free_page(struct loongarch_iommu *iommu,
		struct spt_entry *shadw_entry)
{
	if (!(shadw_entry->index < iommu->maxpages))
		return;

	/* Hand the hardware slot back clean for its next user. */
	memset(shadw_entry->gmem_ptable, 0x0, IOMMU_PAGE_SIZE);

	spin_lock(&iommu->pgtable_bitmap_lock);
	__clear_bit(shadw_entry->index, iommu->pgtable_bitmap);
	spin_unlock(&iommu->pgtable_bitmap_lock);

	shadw_entry->index = -1;
	free_page((unsigned long)shadw_entry->shadow_ptable);
	shadw_entry->shadow_ptable = NULL;
	shadw_entry->gmem_ptable = NULL;
	kfree(shadw_entry);
}

/*
 * free_pagetable_one_level - recursively release a shadow page-table subtree
 * @iommu_entry: per-IOMMU domain state (only ->iommu is used)
 * @shd_entry:   table node to release
 * @level:       level label of @shd_entry; recursion stops at IOMMU_PT_LEVEL1
 *
 * For levels above IOMMU_PT_LEVEL1, every present slot of the node's shadow
 * table is treated as a pointer to a child struct spt_entry. Each child is
 * freed recursively, then the node itself goes back to the pool via
 * iommu_free_page().
 *
 * NOTE(review): free_pagetable() starts this walk at IOMMU_LEVEL_MAX, while
 * the map path treats the root as IOMMU_LEVEL_MAX - 1. At IOMMU_PT_LEVEL1 the
 * present/huge test is applied to the first word of the struct spt_entry
 * object itself (psentry is just shd_entry cast), not to a page-table slot.
 * In the loop, a present huge leaf would be cast to a struct spt_entry
 * pointer. Confirm the level numbering guarantees huge leaves never reach
 * that cast.
 */
static void free_pagetable_one_level(struct iommu_info *iommu_entry, struct spt_entry *shd_entry,
		int level)
{
	int i;
	unsigned long *psentry;
	struct spt_entry *shd_entry_tmp;
	struct loongarch_iommu *iommu = iommu_entry->iommu;

	psentry = (unsigned long *)shd_entry;
	if (level == IOMMU_PT_LEVEL1) {
		/* Bottom of the recursion: free this node only. */
		if (iommu_pt_present(psentry) && (!iommu_pt_huge(psentry)))
			iommu_free_page(iommu, shd_entry);
		return;
	}

	for (i = 0; i < IOMMU_PTRS_PER_LEVEL; i++) {
		psentry = shd_entry->shadow_ptable + i;
		if (!iommu_pt_present(psentry))
			continue;

		/* Shadow slots hold kernel pointers to child nodes. */
		shd_entry_tmp = (struct spt_entry *)(*psentry);
		free_pagetable_one_level(iommu_entry, shd_entry_tmp, level - 1);
	}

	iommu_free_page(iommu, shd_entry);
}

/*
 * free_pagetable - release the whole page table of a per-IOMMU domain entry
 * @iommu_entry: entry whose shadow_pgd tree is torn down
 *
 * shadow_pgd is cleared afterwards. Calling this again, or on an entry that
 * has no page table, is a no-op instead of a NULL dereference.
 */
static void free_pagetable(struct iommu_info *iommu_entry)
{
	struct spt_entry *pgd = iommu_entry->shadow_pgd;

	if (!pgd)
		return;

	free_pagetable_one_level(iommu_entry, pgd, IOMMU_LEVEL_MAX);
	iommu_entry->shadow_pgd = NULL;
}

/*
 * alloc_dom_info - allocate and initialise a driver-private domain
 *
 * Besides the zeroed dom_info itself, this allocates mmio_pgd: a zeroed,
 * order-6 flat table for the 0x10000000~0x8fffffff window. The map/unmap and
 * iova_to_pa paths index it by (iova - SZ_256M) >> LA_VIRTIO_PAGE_SHIFT.
 * dom_info_free() releases it with the same order.
 *
 * Return: the new domain, or NULL on allocation failure.
 */
static struct dom_info *alloc_dom_info(void)
{
	struct dom_info *info = kzalloc(sizeof(*info), GFP_KERNEL);

	if (!info)
		return NULL;

	info->mmio_pgd = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO, 6);
	if (!info->mmio_pgd) {
		pr_err("%s alloc virtio pgtable failed\n", __func__);
		kfree(info);
		return NULL;
	}

	INIT_LIST_HEAD(&info->iommu_devlist);
	spin_lock_init(&info->lock);
	return info;
}

/*
 * dom_info_free - release a domain created by alloc_dom_info()
 * @info: domain to free, including its order-6 mmio_pgd table
 */
static void dom_info_free(struct dom_info *info)
{
	void *table = info->mmio_pgd;

	if (table) {
		info->mmio_pgd = NULL;
		/* Order must match the allocation in alloc_dom_info(). */
		free_pages((unsigned long)table, 6);
	}

	kfree(info);
}

/*
 * la_iommu_domain_alloc - iommu_ops->domain_alloc callback
 * @type: requested domain type; only IOMMU_DOMAIN_UNMANAGED is supported
 *
 * The new domain's aperture covers the whole 64-bit IOVA space.
 *
 * Return: the embedded iommu_domain, or NULL for other types or on failure.
 */
static struct iommu_domain *la_iommu_domain_alloc(unsigned int type)
{
	struct dom_info *info;

	if (type != IOMMU_DOMAIN_UNMANAGED)
		return NULL;

	info = alloc_dom_info();
	if (!info)
		return NULL;

	info->domain.geometry.aperture_start = 0;
	info->domain.geometry.aperture_end = ~0ULL;
	info->domain.geometry.force_aperture = true;

	return &info->domain;
}

/*
 * domain_deattach_iommu - tear down one per-IOMMU entry of a domain
 * @priv:        domain that owns @iommu_entry
 * @iommu_entry: entry to destroy; left alone while devices are attached
 *
 * Releases the entry's domain id and page table, unlinks it from
 * priv->iommu_devlist, frees it and unregisters @priv from the IOMMU's domain
 * list. Takes iommu->la_iommu_pgtlock (a mutex), so the caller must be able
 * to sleep; la_iommu_domain_free() releases priv->lock around this call.
 */
void domain_deattach_iommu(struct dom_info *priv, struct iommu_info *iommu_entry)
{
	struct loongarch_iommu *iommu;

	if (!priv) {
		pr_err("%s priv is NULL", __func__);
		return;
	}

	if (!iommu_entry) {
		pr_err("%s iommu_entry is NULL", __func__);
		return;
	}

	/* Devices still reference this entry. */
	if (iommu_entry->dev_cnt != 0)
		return;

	iommu = iommu_entry->iommu;
	if (!iommu) {
		pr_err("%s iommu is NULL", __func__);
		return;
	}

	domain_id_free(iommu, iommu_entry->id);

	mutex_lock(&iommu->la_iommu_pgtlock);
	free_pagetable(iommu_entry);
	mutex_unlock(&iommu->la_iommu_pgtlock);

	spin_lock(&priv->lock);
	list_del(&iommu_entry->list);
	spin_unlock(&priv->lock);

	kfree(iommu_entry);
	del_domain_from_list(iommu, priv);
}

/*
 * la_iommu_domain_free - iommu_ops->domain_free callback
 * @domain: domain allocated by la_iommu_domain_alloc()
 *
 * For every per-IOMMU entry of the domain:
 * - detach any devices still linked to it;
 * - release its domain id and page table (domain_deattach_iommu());
 * - flush the IOTLB;
 * - disable translation on that IOMMU once no domain is registered there.
 *
 * Finally the dom_info itself is freed.
 *
 * domain_deattach_iommu() takes a mutex, so priv->lock is released around
 * that call.
 *
 * NOTE(review): while priv->lock is dropped the list may change, so the next
 * cursor saved by list_for_each_entry_safe() can go stale. Also,
 * loongarch_iommu_flush_iotlb_all() busy-waits (udelay) with priv->lock held.
 */
static void la_iommu_domain_free(struct iommu_domain *domain)
{

	struct dom_info *priv;
	struct loongarch_iommu *iommu = NULL;
	struct iommu_info *iommu_entry, *iommu_entry_temp;

	priv = to_dom_info(domain);

	spin_lock(&priv->lock);
	list_for_each_entry_safe(iommu_entry, iommu_entry_temp,
			&priv->iommu_devlist, list) {
		/* Saved now: iommu_entry is freed by domain_deattach_iommu(). */
		iommu = iommu_entry->iommu;

		/* Bring dev_cnt to 0 so domain_deattach_iommu() proceeds. */
		if (iommu_entry->dev_cnt > 0)
			cleanup_iommu_entry(iommu_entry);

		spin_unlock(&priv->lock);
		domain_deattach_iommu(priv, iommu_entry);
		spin_lock(&priv->lock);

		loongarch_iommu_flush_iotlb_all(iommu);

		/* Last domain on this IOMMU gone: stop translating. */
		if (!has_dom(iommu))
			iommu_translate_disable(iommu);

	}
	spin_unlock(&priv->lock);

	dom_info_free(priv);
}

/*
 * lookup_rlooptable - find the reverse-lookup table of a PCI segment
 * @pcisegment: PCI segment (domain) number
 *
 * Return: the matching entry from la_rlookup_iommu_list, or NULL.
 */
struct la_iommu_rlookup_entry *lookup_rlooptable(int pcisegment)
{
	struct la_iommu_rlookup_entry *cur;
	struct la_iommu_rlookup_entry *found = NULL;

	list_for_each_entry(cur, &la_rlookup_iommu_list, list) {
		if (cur->pcisegment != pcisegment)
			continue;
		found = cur;
		break;
	}

	return found;
}

/*
 * find_iommu_by_dev - find the IOMMU responsible for a PCI device
 * @pdev: device to look up
 *
 * The per-segment reverse-lookup table is indexed by the device's devfn only.
 * The bus number is not used.
 *
 * Return: the IOMMU from the table slot (possibly NULL), or NULL if the
 * segment has no table.
 */
struct loongarch_iommu *find_iommu_by_dev(struct pci_dev  *pdev)
{
	unsigned short slot = pdev->devfn & 0xff;
	int segment = pci_domain_nr(pdev->bus);
	struct la_iommu_rlookup_entry *table;

	table = lookup_rlooptable(segment);
	if (!table) {
		pr_info("%s find segment %d rlookupentry failed\n", __func__,
				segment);
		return NULL;
	}

	return table->la_iommu_rlookup_table[slot];
}

/*
 * iommu_init_device - allocate per-device IOMMU data for a PCI device
 * @dev: device being probed (assumed to be a PCI device)
 *
 * dev_data->bdf is the requester id used for the device table. For a device
 * on bus 0 it is just its devfn. Otherwise the walk climbs to the bus whose
 * parent is the root bus, and the id combines that bus's number with the
 * devfn of its upstream bridge (bus->self).
 *
 * Return: 0 on success or if already initialised, -ENOMEM on failure.
 */
static int iommu_init_device(struct device *dev)
{
	struct pci_dev *pdev = to_pci_dev(dev);
	struct pci_bus *bus = pdev->bus;
	unsigned char busnum = bus->number;
	unsigned short bdf = pdev->devfn & 0xff;
	unsigned short devid;
	struct la_iommu_dev_data *dev_data;
	struct loongarch_iommu *iommu;

	if (busnum != 0) {
		while (bus->parent->parent)
			bus = bus->parent;
		bdf = bus->self->devfn & 0xff;
	}

	if (dev->archdata.iommu) {
		pr_info("LA-IOMMU: bdf:0x%x has added\n", bdf);
		return 0;
	}

	dev_data = kzalloc(sizeof(*dev_data), GFP_KERNEL);
	if (!dev_data)
		return -ENOMEM;

	devid = PCI_DEVID(bus->number, bdf);
	dev_data->bdf = devid;

	pci_info(pdev, "%s devid %x bus %x\n", __func__, devid, busnum);
	iommu = find_iommu_by_dev(pdev);
	if (!iommu)
		pci_info(pdev, "%s find iommu failed by dev\n", __func__);

	/* count stays 0 until la_iommu_attach_dev() takes a reference. */
	dev_data->count = 0;
	dev_data->iommu = iommu;
	dev->archdata.iommu = dev_data;

	return 0;
}

/*
 * la_iommu_probe_device - iommu_ops->probe_device callback
 * @dev: device being added to the IOMMU core
 *
 * Sets up per-device data via iommu_init_device(). A failure is only logged.
 *
 * Return: always NULL; no struct iommu_device is handed back here.
 */
static struct iommu_device *la_iommu_probe_device(struct device *dev)
{
	if (iommu_init_device(dev))
		pr_err("LA-IOMMU: unable to alloc dev_data\n");

	return NULL;
}

/*
 * la_iommu_device_group - iommu_ops->device_group callback
 * @dev: device to place in a group
 *
 * PCI devices use the PCI grouping rules (pci_device_group()), which account
 * for requester-id aliases. Any other device gets generic_device_group().
 */
static struct iommu_group *la_iommu_device_group(struct device *dev)
{
	if (!dev_is_pci(dev))
		return generic_device_group(dev);

	return pci_device_group(dev);
}

/*
 * la_iommu_release_device - iommu_ops->release_device callback
 * @dev: device leaving the IOMMU core
 *
 * Clears dev->archdata.iommu and frees the data set up by
 * iommu_init_device().
 *
 * NOTE(review): if the dev_data is still linked on an iommu_info dev_list
 * (see do_attach()), freeing it here leaves a dangling list node.
 */
static void la_iommu_release_device(struct device *dev)
{
	struct la_iommu_dev_data *dev_data = dev->archdata.iommu;

	dev->archdata.iommu = NULL;
	kfree(dev_data);
}

/*
 * get_first_iommu_entry - first per-IOMMU entry of a domain
 * @priv: domain to inspect (NULL is reported)
 *
 * The list is read without taking priv->lock.
 *
 * Return: the first entry of priv->iommu_devlist, or NULL if it is empty.
 */
struct iommu_info *get_first_iommu_entry(struct dom_info *priv)
{
	if (!priv) {
		pr_err("%s priv is NULL", __func__);
		return NULL;
	}

	return list_first_entry_or_null(&priv->iommu_devlist,
			struct iommu_info, list);
}

/*
 * get_iommu_entry - find the per-IOMMU entry of a domain for @iommu
 * @priv:  domain to search (under priv->lock)
 * @iommu: IOMMU to match
 *
 * Return: the matching entry, or NULL.
 */
struct iommu_info *get_iommu_entry(struct dom_info *priv, struct loongarch_iommu *iommu)
{
	struct iommu_info *cur;
	struct iommu_info *match = NULL;

	spin_lock(&priv->lock);
	list_for_each_entry(cur, &priv->iommu_devlist, list) {
		if (cur->iommu == iommu) {
			match = cur;
			break;
		}
	}
	spin_unlock(&priv->lock);

	return match;
}

struct iommu_info *domain_attach_iommu(struct dom_info *priv, struct loongarch_iommu *iommu)
{
	unsigned long pgd_pa;
	u32 dir_ctrl, pgd_lo, pgd_hi;
	struct iommu_info *iommu_entry = NULL;
	struct spt_entry *shd_entry = NULL;

	iommu_entry = get_iommu_entry(priv, iommu);
	if (iommu_entry)
		return iommu_entry;

	iommu_entry = kzalloc(sizeof(struct iommu_info), GFP_KERNEL);
	if (iommu_entry == NULL)
		return NULL;

	INIT_LIST_HEAD(&iommu_entry->dev_list);
	iommu_entry->iommu = iommu;
	iommu_entry->id = domain_id_alloc(iommu);
	if (iommu_entry->id == -1) {
		pr_info("%s alloc id for domain failed\n", __func__);
		kfree(iommu_entry);
		return NULL;
	}

	shd_entry = iommu_zalloc_page(iommu);
	if (!shd_entry) {
		pr_info("%s shadow page entry err\n", __func__);
		domain_id_free(iommu, iommu_entry->id);
		kfree(iommu_entry);
		return NULL;
	}

	iommu_entry->shadow_pgd = shd_entry;
	dir_ctrl = (IOMMU_LEVEL_STRIDE << 26) | (IOMMU_LEVEL_SHIFT(2) << 20);
	dir_ctrl |= (IOMMU_LEVEL_STRIDE <<  16) | (IOMMU_LEVEL_SHIFT(1) << 10);
	dir_ctrl |= (IOMMU_LEVEL_STRIDE << 6) | IOMMU_LEVEL_SHIFT(0);
	pgd_pa = iommu_pgt_v2p(iommu, shd_entry->gmem_ptable);
	pgd_hi = pgd_pa >> 32;
	pgd_lo = pgd_pa & 0xffffffff;
	iommu_write_regl(iommu, LA_IOMMU_DIR_CTRL(iommu_entry->id), dir_ctrl);
	iommu_write_regl(iommu, LA_IOMMU_PGD_HI(iommu_entry->id), pgd_hi);
	iommu_write_regl(iommu, LA_IOMMU_PGD_LO(iommu_entry->id), pgd_lo);

	spin_lock(&priv->lock);
	list_add(&iommu_entry->list, &priv->iommu_devlist);
	spin_unlock(&priv->lock);

	add_domain_to_list(iommu, priv);
	pr_info("%s iommu_entry->iommu %lx id %x\n", __func__,
	       (unsigned long)iommu_entry->iommu, iommu_entry->id);

	return iommu_entry;
}

/*
 * iommu_get_devdata - find an attached device by requester id
 * @info:  domain to search
 * @iommu: IOMMU whose per-domain entry is searched
 * @bdf:   requester id to match against dev_data->bdf
 *
 * Return: the attached dev_data, or NULL if the domain has no entry for
 * @iommu or no device with that id.
 */
static struct la_iommu_dev_data *iommu_get_devdata(struct dom_info *info,
		struct loongarch_iommu *iommu, unsigned long bdf)
{
	struct iommu_info *entry = get_iommu_entry(info, iommu);
	struct la_iommu_dev_data *cur;
	struct la_iommu_dev_data *found = NULL;

	if (!entry)
		return NULL;

	spin_lock(&entry->devlock);
	list_for_each_entry(cur, &entry->dev_list, list) {
		if (cur->bdf == bdf) {
			found = cur;
			break;
		}
	}
	spin_unlock(&entry->devlock);

	return found;
}

/*
 * la_iommu_attach_dev - iommu_ops->attach_dev callback
 * @domain: unmanaged domain to attach to
 * @dev:    PCI device
 *
 * Attached devices are keyed by dev_data->bdf, the requester id stored by
 * iommu_init_device(). The lookup must use that same key: the bare devfn
 * never equals dev_data->bdf for devices off bus 0. With a mismatched key,
 * every attach of such a device re-ran do_attach() and list_add()ed an
 * already-linked node. Devices that share a requester id share one entry,
 * whose count is bumped.
 *
 * Return: 0 on success, or when the device has no dev_data or no IOMMU
 * (nothing to program). -ENOMEM if the per-IOMMU domain state cannot be
 * created, so the core does not believe an untranslated device is attached.
 */
static int la_iommu_attach_dev(struct iommu_domain *domain, struct device *dev)
{
	struct dom_info *priv = to_dom_info(domain);
	struct pci_dev  *pdev = to_pci_dev(dev);
	struct pci_bus  *bus = pdev->bus;
	unsigned char busnum = pdev->bus->number;
	struct la_iommu_dev_data *dev_data;
	struct la_iommu_dev_data *attached;
	struct loongarch_iommu *iommu;
	struct iommu_info *iommu_entry = NULL;
	unsigned short bdf;

	/* devfn of the device or its root-bus bridge; used for logging. */
	bdf = pdev->devfn & 0xff;
	if (busnum != 0) {
		while (bus->parent->parent)
			bus = bus->parent;
		bdf = bus->self->devfn & 0xff;
	}

	dev_data = (struct la_iommu_dev_data *)dev->archdata.iommu;
	if (dev_data == NULL) {
		pci_info(pdev, "%s dev_data is Invalid\n", __func__);
		return 0;
	}

	iommu = dev_data->iommu;
	if (iommu == NULL) {
		pci_info(pdev, "%s iommu is Invalid\n", __func__);
		return 0;
	}

	pci_info(pdev, "%s busnum %x bdf %x priv %lx iommu %lx\n", __func__,
			busnum, bdf, (unsigned long)priv, (unsigned long)iommu);

	/* Same key iommu_init_device() stored in dev_data->bdf. */
	attached = iommu_get_devdata(priv, iommu, dev_data->bdf);
	if (attached) {
		attached->count++;
		pci_info(pdev,
		"LA-IOMMU: bdf 0x%x devfn %x has attached, count:0x%x\n",
			bdf, pdev->devfn, attached->count);
		return 0;
	}

	iommu_entry = domain_attach_iommu(priv, iommu);
	if (iommu_entry == NULL) {
		pci_info(pdev, "domain attach iommu failed\n");
		return -ENOMEM;
	}

	dev_data->count++;
	do_attach(iommu_entry, dev_data);

	return 0;
}

/*
 * la_iommu_detach_dev - iommu_ops->detach_dev callback
 * @domain: domain the device is attached to
 * @dev:    PCI device
 *
 * Looks up the attached entry by dev_data->bdf, the same requester id that
 * iommu_init_device() stored and la_iommu_attach_dev() uses. With the bare
 * devfn as key, devices off bus 0 could never be found and so were never
 * detached. The entry's attach count is dropped. The device-table entry is
 * removed only when it reaches zero (do_detach() is a no-op before that).
 */
static void la_iommu_detach_dev(struct iommu_domain *domain,
				 struct device *dev)
{
	struct dom_info *priv = to_dom_info(domain);
	struct pci_dev *pdev = to_pci_dev(dev);
	struct pci_bus *bus = pdev->bus;
	unsigned char busnum = pdev->bus->number;
	struct la_iommu_dev_data *dev_data;
	struct la_iommu_dev_data *attached;
	struct loongarch_iommu *iommu;
	struct iommu_info *iommu_entry = NULL;
	unsigned short bdf;

	/* devfn of the device or its root-bus bridge; used for logging. */
	bdf = pdev->devfn & 0xff;
	if (busnum != 0) {
		while (bus->parent->parent)
			bus = bus->parent;
		bdf = bus->self->devfn & 0xff;
	}

	dev_data = (struct la_iommu_dev_data *)dev->archdata.iommu;
	if (dev_data == NULL) {
		pci_info(pdev, "%s dev_data is Invalid\n", __func__);
		return;
	}

	iommu = dev_data->iommu;
	if (iommu == NULL) {
		pci_info(pdev, "%s iommu is Invalid\n", __func__);
		return;
	}

	/* Same key iommu_init_device() stored in dev_data->bdf. */
	attached = iommu_get_devdata(priv, iommu, dev_data->bdf);
	if (attached == NULL) {
		pci_info(pdev, "LA-IOMMU: bdf 0x%x devfn %x dev_data is NULL\n",
			bdf, pdev->devfn & 0xff);
		return;
	}

	iommu = attached->iommu;
	attached->count--;
	iommu_entry = get_iommu_entry(priv, iommu);
	if (iommu_entry == NULL) {
		pci_info(pdev, "%s get iommu_entry failed\n", __func__);
		return;
	}

	spin_lock(&iommu_entry->devlock);
	do_detach(attached);
	spin_unlock(&iommu_entry->devlock);

	pci_info(pdev, "%s iommu devid  %x sigment %x\n", __func__,
			iommu->devid, iommu->segment);
}

/*
 * iommu_get_spte - walk the shadow table for @iova down to @level
 * @entry: root shadow node
 * @iova:  address to translate
 * @level: lowest level to descend to
 *
 * The walk starts at IOMMU_LEVEL_MAX - 1 and stops early at a non-present or
 * huge slot.
 *
 * Return: the shadow slot where the walk stopped, or NULL if @level is above
 * the top level.
 */
static unsigned long *iommu_get_spte(struct spt_entry *entry, unsigned long iova,
		int level)
{
	unsigned long *slot = NULL;
	int lvl;

	if (level >= IOMMU_LEVEL_MAX)
		return NULL;

	for (lvl = IOMMU_LEVEL_MAX - 1; lvl >= level; lvl--) {
		slot = iommu_shadow_offset(entry, iova, lvl);
		if (!iommu_pt_present(slot) || iommu_pt_huge(slot))
			break;

		/* Non-leaf shadow slots hold pointers to child nodes. */
		entry = (struct spt_entry *)(*slot);
	}

	return slot;
}

/*
 * _iommu_alloc_ptable - make sure a slot points to a child table
 * @iommu:    IOMMU owning the page-table pool
 * @psentry:  shadow slot
 * @phwentry: matching hardware slot
 *
 * If the shadow slot is empty, a new node is allocated. The shadow slot then
 * receives the kernel pointer to it, and the hardware slot receives the
 * physical address of its hardware table with IOMMU_PTE_RW set. A slot that
 * is already present is left untouched.
 *
 * Return: 0, or -ENOMEM if the node could not be allocated.
 */
static int _iommu_alloc_ptable(struct loongarch_iommu *iommu,
		unsigned long *psentry, unsigned long *phwentry)
{
	struct spt_entry *child;
	unsigned long hw_addr;

	if (iommu_pt_present(psentry))
		return 0;

	child = iommu_zalloc_page(iommu);
	if (!child) {
		pr_err("LA-IOMMU: new_shd_entry alloc err\n");
		return -ENOMEM;
	}

	*psentry = (unsigned long)child;

	hw_addr = iommu_pgt_v2p(iommu, (struct iommu_pte *)child->gmem_ptable);
	*phwentry = (hw_addr & IOMMU_PAGE_MASK) | IOMMU_PTE_RW;

	return 0;
}

/*
 * iommu_ptw_map - map [start, end) to @pa in one level of the page table
 * @iommu:     IOMMU owning the page-table pool
 * @shd_entry: table node for @level
 * @start:     first IOVA (inclusive)
 * @end:       last IOVA (exclusive)
 * @pa:        physical address that @start maps to
 * @level:     current level; IOMMU_PT_LEVEL0 writes leaf PTEs
 *
 * Every entry is written twice: into the hardware table (gmem_ptable) and
 * into the shadow table that the driver walks. At IOMMU_PT_LEVEL1, a slot
 * that is fully covered (step == IOMMU_HPAGE_SIZE) and is empty or already
 * huge becomes a huge leaf. Otherwise a child table is allocated if needed
 * and the walk recurses.
 *
 * Return: number of bytes mapped. A count shorter than end - start means a
 * table allocation failed somewhere below. Mapping stops at the first failure
 * (including one inside the recursive call) and the part already written is
 * left in place.
 *
 * NOTE(review): a partial mapping over an existing huge leaf makes
 * _iommu_alloc_ptable() skip the allocation. The huge PTE value is then used
 * as a struct spt_entry pointer. Confirm callers never do this.
 */
static size_t iommu_ptw_map(struct loongarch_iommu *iommu, struct spt_entry *shd_entry,
		unsigned long start, unsigned long end, phys_addr_t pa, int level)
{
	unsigned long next, old, step;
	unsigned long *psentry, *phwentry;
	unsigned long pte;
	size_t mapped;
	int ret, huge;

	old = start;
	psentry = iommu_shadow_offset(shd_entry, start, level);
	phwentry = iommu_ptable_offset(shd_entry->gmem_ptable, start, level);
	if (level == IOMMU_PT_LEVEL0) {
		/* Leaf level: one PTE per IOMMU page, mirrored in both tables. */
		pa = pa & IOMMU_PAGE_MASK;
		do {
			pte =  pa | IOMMU_PTE_RW;
			*phwentry = pte;
			*psentry = pte;
			psentry++;
			phwentry++;
			start += IOMMU_PAGE_SIZE;
			pa += IOMMU_PAGE_SIZE;
		} while (start < end);

		return start - old;
	}

	do {
		next = iommu_ptable_end(start, end, level);
		step = next - start;

		huge = 0;
		if ((level == IOMMU_PT_LEVEL1) && (step == IOMMU_HPAGE_SIZE))
			if (!iommu_pt_present(psentry) || iommu_pt_huge(psentry))
				huge = 1;

		if (huge) {
			pte =  (pa & IOMMU_HPAGE_MASK) | IOMMU_PTE_RW | IOMMU_PTE_HP;
			*phwentry = pte;
			*psentry = pte;
		} else {
			ret = _iommu_alloc_ptable(iommu, psentry, phwentry);
			if (ret != 0)
				break;
			mapped = iommu_ptw_map(iommu, (struct spt_entry *)*psentry, start, next, pa,
					level - 1);
			/* A short inner map is a failure: report it upwards. */
			if (mapped < step)
				break;
		}

		psentry++;
		phwentry++;
		pa += step;
		start = next;
	} while (start < end);
	return start - old;
}

/*
 * dev_map_page - map a range in one per-IOMMU page table
 * @iommu_entry: per-IOMMU domain state
 * @start:       first IOVA
 * @pa:          physical address @start maps to
 * @size:        length in bytes
 *
 * Runs under iommu->la_iommu_pgtlock (a mutex). The IOTLB is flushed when
 * at least one domain is registered on the IOMMU.
 *
 * Return: 0, or -EFAULT if fewer than @size bytes were mapped.
 */
static int dev_map_page(struct iommu_info *iommu_entry, unsigned long start,
			phys_addr_t pa, size_t size)
{
	struct loongarch_iommu *iommu = iommu_entry->iommu;
	phys_addr_t end = start + size;
	size_t mapped;
	int ret;

	mutex_lock(&iommu->la_iommu_pgtlock);
	mapped = iommu_ptw_map(iommu, iommu_entry->shadow_pgd, start, end, pa,
			       IOMMU_LEVEL_MAX - 1);
	ret = (mapped == size) ? 0 : -EFAULT;

	if (has_dom(iommu))
		loongarch_iommu_flush_iotlb_all(iommu);
	mutex_unlock(&iommu->la_iommu_pgtlock);

	return ret;
}

/*
 * iommu_ptw_unmap - clear the entries covering [start, end) at one level
 * @iommu:     IOMMU owning the page table
 * @shd_entry: table node for @level
 * @start:     first IOVA (inclusive)
 * @end:       last IOVA (exclusive)
 * @level:     current level; IOMMU_PT_LEVEL0 clears leaf PTEs
 *
 * Both the hardware slot and the shadow slot are zeroed. Intermediate tables
 * are not freed here. A huge leaf is cleared as a whole even if the range
 * covers only part of it; that case is logged with pr_err().
 *
 * Return: number of bytes walked (start advanced minus the original start).
 */
static size_t iommu_ptw_unmap(struct loongarch_iommu *iommu, struct spt_entry *shd_entry,
		unsigned long start, unsigned long end, int level)
{
	unsigned long next, old;
	unsigned long *psentry, *phwentry;

	old = start;
	psentry = iommu_shadow_offset(shd_entry, start, level);
	phwentry = iommu_ptable_offset(shd_entry->gmem_ptable, start, level);
	if (level == IOMMU_PT_LEVEL0) {
		do {
			*phwentry++ = 0;
			*psentry++ = 0;
			start += IOMMU_PAGE_SIZE;
		} while (start < end);
	} else {
		do {
			next = iommu_ptable_end(start, end, level);
			/* "continue" jumps to the loop condition below, which
			 * still advances psentry/phwentry/start.
			 */
			if (!iommu_pt_present(psentry))
				continue;

			if (iommu_pt_huge(psentry)) {
				if ((next - start) != IOMMU_HPAGE_SIZE)
					pr_err("Map pte on hugepage not supported now\n");
				*phwentry = 0;
				*psentry = 0;
			} else
				iommu_ptw_unmap(iommu, (struct spt_entry *)*psentry, start, next,
						level - 1);
		} while (psentry++, phwentry++, start = next, start < end);
	}

	return start - old;
}

/*
 * iommu_map_page - map [start, start + size) to @pa in a domain
 * @priv:  domain
 * @start: first IOVA
 * @pa:    physical address @start maps to
 * @size:  length in bytes
 * @prot:  unused
 * @gfp:   unused
 *
 * IOVAs in 0x10000000~0x8fffffff go into the flat mmio_pgd table. Any other
 * IOVA is mapped in the page table of every IOMMU the domain is attached
 * through.
 *
 * The flat-table loop clamps its last step. A size that is not a multiple
 * of IOMMU_PAGE_SIZE can no longer wrap the unsigned counter and run far past
 * the end of mmio_pgd.
 *
 * Return: 0, or the OR of the dev_map_page() error codes.
 *
 * NOTE(review): only @start is range-checked, so a range that crosses
 * LOONGSON_HIGHMEM_START can index past the end of mmio_pgd. Also,
 * dev_map_page() takes a mutex while priv->lock (a spinlock) is held here,
 * i.e. it may sleep in atomic context.
 */
static int iommu_map_page(struct dom_info *priv, unsigned long start,
			phys_addr_t pa, size_t size, int prot, gfp_t gfp)
{
	unsigned long *pte;
	int ret = 0;
	struct iommu_info *iommu_entry = NULL;

	/* 0x10000000~0x8fffffff */
	if ((start >= SZ_256M) && (start < LOONGSON_HIGHMEM_START)) {
		start -= SZ_256M;
		pte = (unsigned long *)priv->mmio_pgd;
		while (size > 0) {
			size_t step = size < IOMMU_PAGE_SIZE ? size : IOMMU_PAGE_SIZE;

			pte[start >> LA_VIRTIO_PAGE_SHIFT] =
					pa & LA_VIRTIO_PAGE_MASK;
			size -= step;
			start += step;
			pa += step;
		}
		return 0;
	}

	spin_lock(&priv->lock);
	list_for_each_entry(iommu_entry, &priv->iommu_devlist, list) {
		ret |= dev_map_page(iommu_entry, start, pa, size);
	}
	spin_unlock(&priv->lock);

	return ret;
}

/*
 * iommu_unmap_page - unmap a range from one per-IOMMU page table
 * @iommu_entry: per-IOMMU domain state
 * @start:       first IOVA
 * @size:        length in bytes
 *
 * Runs under iommu->la_iommu_pgtlock (a mutex). The IOTLB is flushed when
 * at least one domain is registered on the IOMMU.
 *
 * Return: number of bytes walked by iommu_ptw_unmap().
 */
static size_t iommu_unmap_page(struct iommu_info *iommu_entry, unsigned long start, size_t size)
{
	struct loongarch_iommu *iommu = iommu_entry->iommu;
	unsigned long end = start + size;
	size_t unmapped;

	mutex_lock(&iommu->la_iommu_pgtlock);
	unmapped = iommu_ptw_unmap(iommu, iommu_entry->shadow_pgd, start, end,
				   IOMMU_LEVEL_MAX - 1);

	if (has_dom(iommu))
		loongarch_iommu_flush_iotlb_all(iommu);
	mutex_unlock(&iommu->la_iommu_pgtlock);

	return unmapped;
}

/*
 * domain_unmap_page - unmap [start, start + size) from a domain
 * @info:  domain
 * @start: first IOVA
 * @size:  length in bytes
 *
 * IOVAs in 0x10000000~0x8fffffff are cleared in the flat mmio_pgd table in
 * 0x4000-byte steps. The last step is clamped, so a size that is not a
 * multiple of 0x4000 can no longer wrap the unsigned counter and overrun the
 * table. Any other IOVA is unmapped from every per-IOMMU page table of the
 * domain.
 *
 * Return: bytes unmapped; for page tables, the count from the last entry
 * visited.
 *
 * NOTE(review): iommu_unmap_page() takes a mutex while info->lock (a
 * spinlock) is held here, i.e. it may sleep in atomic context.
 */
static size_t domain_unmap_page(struct dom_info *info, unsigned long start, size_t size)
{
	unsigned long *pte;
	size_t unmap_len = 0;
	struct iommu_info *entry;

	/* 0x10000000~0x8fffffff */
	if ((start >= SZ_256M) && (start < LOONGSON_HIGHMEM_START)) {
		start -= SZ_256M;
		pte = (unsigned long *)info->mmio_pgd;
		while (size > 0) {
			size_t step = size < 0x4000 ? size : 0x4000;

			pte[start >> LA_VIRTIO_PAGE_SHIFT] = 0;
			size -= step;
			unmap_len += step;
			start += step;
		}

		return unmap_len;
	}

	spin_lock(&info->lock);
	list_for_each_entry(entry, &info->iommu_devlist, list)
		unmap_len = iommu_unmap_page(entry, start, size);
	spin_unlock(&info->lock);

	return unmap_len;
}

/*
 * la_iommu_map - iommu_ops->map callback
 * @domain: target domain
 * @iova:   first IOVA
 * @pa:     physical address @iova maps to
 * @len:    length in bytes
 * @prot:   protection flags (forwarded; currently ignored downstream)
 * @gfp:    allocation context of the caller
 *
 * The caller's @gfp is forwarded instead of being replaced with GFP_KERNEL.
 * NOTE(review): iommu_map_page() does not use it yet, and the table
 * allocations below it still use GFP_KERNEL.
 *
 * Return: 0 on success, negative errno otherwise.
 */
static int la_iommu_map(struct iommu_domain *domain, unsigned long iova,
			 phys_addr_t pa, size_t len, int prot, gfp_t gfp)
{
	struct dom_info *priv = to_dom_info(domain);

	return iommu_map_page(priv, iova, pa, len, prot, gfp);
}

/*
 * la_iommu_unmap - iommu_ops->unmap callback
 * @domain:       domain to unmap from
 * @iova:         first IOVA
 * @size:         length in bytes
 * @iotlb_gather: unused; flushing is done eagerly by the unmap path
 *
 * The byte count stays a size_t end to end. Routing it through an int
 * truncated lengths of 2 GiB and more.
 *
 * Return: number of bytes unmapped.
 */
static size_t la_iommu_unmap(struct iommu_domain *domain, unsigned long iova,
				size_t size, struct iommu_iotlb_gather *iotlb_gather)
{
	struct dom_info *priv = to_dom_info(domain);

	return domain_unmap_page(priv, iova, size);
}

/*
 * la_iommu_iova_to_pa - translate an IOVA through a domain
 * @domain: domain to look in
 * @iova:   address to translate
 *
 * IOVAs in 0x10000000~0x8fffffff are looked up in the flat mmio_pgd table,
 * indexed with the same LA_VIRTIO_PAGE_SHIFT that the map path uses (not a
 * hard-coded 14). Any other IOVA is resolved through the shadow table of the
 * domain's first per-IOMMU entry, honouring huge leaves.
 *
 * Return: the physical address, or 0 when @iova is not mapped. This is the
 * iova_to_phys convention. Negative errno values cast to phys_addr_t look
 * like valid addresses to callers.
 */
static phys_addr_t la_iommu_iova_to_pa(struct iommu_domain *domain,
						dma_addr_t iova)
{
	struct dom_info *priv = to_dom_info(domain);
	unsigned long *psentry, *pte;
	unsigned long pa, offset, tmpva, page_size, page_mask;
	struct spt_entry *entry;
	struct loongarch_iommu *iommu;
	struct iommu_info *iommu_entry = NULL;

	/* 0x10000000~0x8fffffff */
	if ((iova >= SZ_256M) && (iova < LOONGSON_HIGHMEM_START)) {
		tmpva = iova & LA_VIRTIO_PAGE_MASK;
		pte = (unsigned long *)priv->mmio_pgd;
		offset = iova & ((1ULL << LA_VIRTIO_PAGE_SHIFT) - 1);
		pa = pte[(tmpva - SZ_256M) >> LA_VIRTIO_PAGE_SHIFT];
		/* An empty slot means "not mapped". */
		if (!pa)
			return 0;

		return pa + offset;
	}

	iommu_entry = get_first_iommu_entry(priv);
	if (iommu_entry == NULL) {
		pr_err("%s iova:0x%llx iommu_entry is invalid\n",
				__func__, iova);
		return 0;
	}

	iommu = iommu_entry->iommu;

	mutex_lock(&iommu->la_iommu_pgtlock);
	entry = iommu_entry->shadow_pgd;
	psentry = iommu_get_spte(entry, iova, IOMMU_PT_LEVEL0);
	mutex_unlock(&iommu->la_iommu_pgtlock);

	if (!psentry || !iommu_pt_present(psentry)) {
		pr_warn_once("LA-IOMMU: shadow pte is null or not present with iova %llx\n", iova);
		return 0;
	}

	if (iommu_pt_huge(psentry)) {
		page_size = IOMMU_HPAGE_SIZE;
		page_mask = IOMMU_HPAGE_MASK;
	} else {
		page_size = IOMMU_PAGE_SIZE;
		page_mask = IOMMU_PAGE_MASK;
	}

	pa = *psentry & page_mask;
	pa |= (iova & (page_size - 1));
	return (phys_addr_t)pa;
}

/*
 * la_iommu_iova_to_phys - iommu_ops->iova_to_phys callback
 * @domain: domain to look in
 * @iova:   address to translate
 *
 * Thin wrapper around la_iommu_iova_to_pa().
 */
static phys_addr_t la_iommu_iova_to_phys(struct iommu_domain *domain,
					dma_addr_t iova)
{
	return la_iommu_iova_to_pa(domain, iova);
}

static void la_iommu_flush_iotlb_all(struct iommu_domain *domain)
{
	struct dom_info *priv = to_dom_info(domain);
	struct iommu_info *iommu_entry;
	struct loongarch_iommu *iommu;
	int ret;

	spin_lock(&priv->lock);
	list_for_each_entry(iommu_entry, &priv->iommu_devlist, list) {
		iommu = iommu_entry->iommu;

		ret = loongarch_iommu_flush_iotlb_all(iommu);
	}
	spin_unlock(&priv->lock);
}

/*
 * iommu_ops::iotlb_sync callback. No per-range tracking is done: the
 * gather is ignored and the whole IOTLB of the domain is flushed.
 */
static void la_iommu_flush_iotlb_sync(struct iommu_domain *domain,
				      struct iommu_iotlb_gather *iotlb_gather)
{
	la_iommu_flush_iotlb_all(domain);
}

/* IOMMU core callbacks installed on the PCI bus by la_iommu_init_api(). */
static struct iommu_ops la_iommu_ops = {
	/* capability query and per-device lifecycle */
	.capable	= la_iommu_capable,
	.probe_device	= la_iommu_probe_device,
	.release_device	= la_iommu_release_device,
	.device_group	= la_iommu_device_group,

	/* domain lifecycle and attachment */
	.domain_alloc	= la_iommu_domain_alloc,
	.domain_free	= la_iommu_domain_free,
	.attach_dev	= la_iommu_attach_dev,
	.detach_dev	= la_iommu_detach_dev,

	/* mapping and translation */
	.map		= la_iommu_map,
	.unmap		= la_iommu_unmap,
	.iova_to_phys	= la_iommu_iova_to_phys,
	.pgsize_bitmap	= LA_IOMMU_PGSIZE,

	/* IOTLB maintenance */
	.flush_iotlb_all = la_iommu_flush_iotlb_all,
	.iotlb_sync	= la_iommu_flush_iotlb_sync,
};

/* Install la_iommu_ops on pci_bus_type; returns bus_set_iommu()'s result. */
static inline int la_iommu_init_api(void)
{
	return bus_set_iommu(&pci_bus_type, &la_iommu_ops);
}

struct loongarch_iommu *loongarch_get_iommu(struct pci_dev *pdev)
{
	int pcisegment;
	unsigned short devid;
	struct loongarch_iommu *iommu = NULL;

	devid = pdev->devfn & 0xff;
	pcisegment = pci_domain_nr(pdev->bus);

	list_for_each_entry(iommu, &la_iommu_list, list) {
		if ((iommu->segment == pcisegment) &&
		    (iommu->devid == devid)) {
			return iommu;
		}
	}

	return NULL;
}

/*
 * loongarch_iommu_probe - bind an IOMMU PCI function to its loongarch_iommu
 * @pdev: the IOMMU PCI function
 * @ent:  matching entry of loongson_iommu_pci_tbl (unused)
 *
 * Reserves and maps BAR0 (iommu->membase) and BAR2 (iommu->pgtbase, whose
 * size determines iommu->maxpages), allocates the domain, device-table and
 * page-table bitmaps, and installs la_iommu_ops on the PCI bus.
 *
 * Return: 0 on success, -ENODEV if no IOMMU is described for @pdev,
 * -ENOMEM if a region cannot be reserved or mapped or a bitmap cannot be
 * allocated. Everything acquired before a failure is released again.
 */
static int loongarch_iommu_probe(struct pci_dev *pdev,
				const struct pci_device_id *ent)
{
	struct loongarch_iommu *iommu;
	resource_size_t base, size;
	resource_size_t pgt_base, pgt_size;
	int bitmap_sz;
	int ret = -ENOMEM;

	iommu = loongarch_get_iommu(pdev);
	if (iommu == NULL) {
		pci_info(pdev, "%s can't find iommu\n", __func__);
		return -ENODEV;
	}

	base = pci_resource_start(pdev, 0);
	size = pci_resource_len(pdev, 0);
	if (!request_mem_region(base, size, "loongarch_iommu")) {
		pci_err(pdev, "base %llx size %llx can't reserve mmio registers\n",
			(unsigned long long)base, (unsigned long long)size);
		return -ENOMEM;
	}

	iommu->membase = ioremap(base, size);
	if (iommu->membase == NULL) {
		pci_info(pdev, "%s iommu pci dev bar0 is NULL\n", __func__);
		goto err_release_bar0;
	}

	pgt_base = pci_resource_start(pdev, 2);
	pgt_size = pci_resource_len(pdev, 2);
	if (!request_mem_region(pgt_base, pgt_size, "loongarch_iommu")) {
		pci_err(pdev, "can't reserve mmio registers\n");
		goto err_unmap_bar0;
	}

	iommu->pgtbase = ioremap(pgt_base, pgt_size);
	if (iommu->pgtbase == NULL)
		goto err_release_bar2;

	iommu->maxpages = pgt_size / IOMMU_PAGE_SIZE;
	pr_info("iommu membase %p pgtbase %p pgtsize %llx maxpages %lx\n", iommu->membase,
		iommu->pgtbase, (unsigned long long)pgt_size, iommu->maxpages);

	/* bitmap_zalloc() takes a number of BITS, one per tracked object. */
	bitmap_sz = MAX_DOMAIN_ID;
	iommu->domain_bitmap = bitmap_zalloc(bitmap_sz, GFP_KERNEL);
	if (iommu->domain_bitmap == NULL) {
		pr_err("LA-IOMMU: domain bitmap alloc err bitmap_sz:%d\n",
								bitmap_sz);
		goto err_unmap_bar2;
	}

	bitmap_sz = MAX_ATTACHED_DEV_ID;
	iommu->devtable_bitmap = bitmap_zalloc(bitmap_sz, GFP_KERNEL);
	if (iommu->devtable_bitmap == NULL) {
		pr_err("LA-IOMMU: devtable bitmap alloc err bitmap_sz:%d\n",
								bitmap_sz);
		goto err_free_domain;
	}

	bitmap_sz = iommu->maxpages;
	iommu->pgtable_bitmap = bitmap_zalloc(bitmap_sz, GFP_KERNEL);
	if (iommu->pgtable_bitmap == NULL) {
		pr_err("LA-IOMMU: pgtable bitmap alloc err bitmap_sz:%d\n",
								bitmap_sz);
		goto err_free_devtable;
	}

	/*
	 * Not treated as fatal: with more than one IOMMU on pci_bus_type,
	 * later calls may fail because ops are already installed.
	 * NOTE(review): confirm against this kernel's bus_set_iommu().
	 */
	la_iommu_init_api();

	return 0;

err_free_devtable:
	bitmap_free(iommu->devtable_bitmap);
	iommu->devtable_bitmap = NULL;
err_free_domain:
	bitmap_free(iommu->domain_bitmap);
	iommu->domain_bitmap = NULL;
err_unmap_bar2:
	iounmap(iommu->pgtbase);
	iommu->pgtbase = NULL;
err_release_bar2:
	release_mem_region(pgt_base, pgt_size);
err_unmap_bar0:
	iounmap(iommu->membase);
	iommu->membase = NULL;
err_release_bar0:
	release_mem_region(base, size);
	return ret;
}

/*
 * loongarch_iommu_remove - undo loongarch_iommu_probe() for @pdev
 *
 * Frees the three allocation bitmaps, unmaps BAR0/BAR2 and releases the
 * memory regions that probe reserved for them. Pointers are cleared so
 * stale users see NULL rather than freed or unmapped memory.
 */
static void loongarch_iommu_remove(struct pci_dev *pdev)
{
	struct loongarch_iommu *iommu = loongarch_get_iommu(pdev);

	if (iommu == NULL)
		return;

	/* bitmap_free(NULL) is a no-op, so no NULL checks are needed. */
	bitmap_free(iommu->domain_bitmap);
	iommu->domain_bitmap = NULL;

	bitmap_free(iommu->devtable_bitmap);
	iommu->devtable_bitmap = NULL;

	bitmap_free(iommu->pgtable_bitmap);
	iommu->pgtable_bitmap = NULL;

	if (iommu->pgtbase != NULL) {
		iounmap(iommu->pgtbase);
		iommu->pgtbase = NULL;
		release_mem_region(pci_resource_start(pdev, 2),
				   pci_resource_len(pdev, 2));
	}

	if (iommu->membase != NULL) {
		iounmap(iommu->membase);
		iommu->membase = NULL;
		release_mem_region(pci_resource_start(pdev, 0),
				   pci_resource_len(pdev, 0));
	}
}

/*
 * Verify the ACPI checksum of @table: all table->length bytes must sum to
 * zero modulo 256.
 *
 * Return: 0 if the checksum is valid, -ENODEV otherwise.
 */
static int __init check_ivrs_checksum(struct acpi_table_header *table)
{
	const u8 *byte = (const u8 *)table;
	const u8 *stop = byte + table->length;
	u8 sum = 0;

	while (byte < stop)
		sum += *byte++;

	if (sum) {
		/* ACPI table corrupt */
		pr_err("IVRS invalid checksum\n");
		return -ENODEV;
	}

	return 0;
}

/*
 * Allocate the rlookup (device -> IOMMU) table for PCI segment @pcisegment
 * and add it to la_rlookup_iommu_list. The table is zeroed and is
 * rlookup_table_size bytes long (rounded up to whole pages).
 *
 * Return: the new entry, or NULL if either allocation fails.
 */
struct la_iommu_rlookup_entry *create_rlookup_entry(int pcisegment)
{
	struct la_iommu_rlookup_entry *entry;
	void *table;

	entry = kzalloc(sizeof(*entry), GFP_KERNEL);
	if (!entry)
		return NULL;

	entry->pcisegment = pcisegment;

	/* IOMMU rlookup table - find the IOMMU for a specific device */
	table = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
					 get_order(rlookup_table_size));
	if (!table) {
		kfree(entry);
		return NULL;
	}

	entry->la_iommu_rlookup_table = table;
	list_add(&entry->list, &la_rlookup_iommu_list);

	return entry;
}

/*
 * Record @iommu as the translating IOMMU of requester ID @devid in the
 * rlookup table of @iommu's PCI segment, creating that table on first use.
 * If the table cannot be allocated, the device is silently left unmapped.
 */
static void __init set_iommu_for_device(struct loongarch_iommu *iommu,
		u16 devid)
{
	struct la_iommu_rlookup_entry *entry = lookup_rlooptable(iommu->segment);

	if (!entry) {
		entry = create_rlookup_entry(iommu->segment);
		if (!entry)
			return;
	}

	entry->la_iommu_rlookup_table[devid] = iommu;
}

/*
 * Size in bytes of the fixed IVHD header for @h->type: 24 for type 0x10,
 * 40 for types 0x11 and 0x40, and 0 for any other (unsupported) type.
 */
static inline u32 get_ivhd_header_size(struct ivhd_header *h)
{
	if (h->type == IVHD_HEAD_TYPE10)
		return 24;

	if (h->type == IVHD_HEAD_TYPE11 || h->type == IVHD_HEAD_TYPE40)
		return 40;

	return 0;
}

/* Raise la_iommu_last_bdf to @devid when @devid is larger. */
static inline void update_last_devid(u16 devid)
{
	if (devid <= la_iommu_last_bdf)
		return;

	la_iommu_last_bdf = devid;
}

/*
 * Length in bytes of the IVHD device entry starting at @ivhd.
 *
 * - type < 0x80: the top two bits of the entry's first byte select a
 *   length of 4, 8, 16 or 32 bytes;
 * - IVHD_DEV_ACPI_HID: 22 fixed bytes plus the UID length stored at
 *   offset 21;
 * - any other type: 0 (length unknown; callers must not advance by it).
 */
static inline int ivhd_entry_length(u8 *ivhd)
{
	u32 type = ((struct ivhd_entry *)ivhd)->type;

	if (type == IVHD_DEV_ACPI_HID)
		return ivhd[21] + 22;

	if (type >= 0x80)
		return 0;

	return 0x04 << (ivhd[0] >> 6);
}

/*
 * find_last_devid_from_ivhd - raise la_iommu_last_bdf to cover one IVHD
 * @h: IVHD block to scan
 *
 * Walks the device entries that follow the fixed IVHD header. DEV_ALL
 * raises the limit to MAX_BDF_NUM; DEV_SELECT, DEV_RANGE_END, DEV_ALIAS
 * and DEV_EXT_SELECT raise it to the entry's devid.
 *
 * Return: 0 on success, -EINVAL for an unsupported IVHD type or for a device
 * entry whose length cannot be determined. Advancing by a zero length would
 * otherwise loop forever on such an entry.
 */
static int __init find_last_devid_from_ivhd(struct ivhd_header *h)
{
	u8 *p = (void *)h, *end = (void *)h;
	struct ivhd_entry *dev;
	u32 ivhd_size = get_ivhd_header_size(h);
	int len;

	if (!ivhd_size) {
		pr_err("la-iommu: Unsupported IVHD type %#x\n", h->type);
		return -EINVAL;
	}

	p += ivhd_size;
	end += h->length;

	while (p < end) {
		dev = (struct ivhd_entry *)p;
		switch (dev->type) {
		case IVHD_DEV_ALL:
			/* Use maximum BDF value for DEV_ALL */
			update_last_devid(MAX_BDF_NUM);
			break;
		case IVHD_DEV_SELECT:
		case IVHD_DEV_RANGE_END:
		case IVHD_DEV_ALIAS:
		case IVHD_DEV_EXT_SELECT:
			/* all the above subfield types refer to device ids */
			update_last_devid(dev->devid);
			break;
		default:
			break;
		}

		len = ivhd_entry_length(p);
		if (len <= 0) {
			pr_err("la-iommu: Unsupported IVHD entry type %#x\n",
			       dev->type);
			return -EINVAL;
		}
		p += len;
	}

	WARN_ON(p != end);

	return 0;
}

/*
 * find_last_devid_acpi - compute la_iommu_last_bdf from the IVRS table
 * @table: IVRS table (its checksum is expected to be verified by the caller)
 *
 * Scans every IVHD block after the IVRS header. Only blocks whose type
 * equals la_iommu_target_ivhd_type are examined for device ids.
 *
 * Return: 0 if the blocks tile the table exactly, -EINVAL if a block has
 * length 0 or overruns the table, or the error from
 * find_last_devid_from_ivhd().
 */
static int __init find_last_devid_acpi(struct acpi_table_header *table)
{
	u8 *cursor = (u8 *)table + IVRS_HEADER_LENGTH;
	u8 *limit = (u8 *)table + table->length;

	while (cursor < limit) {
		struct ivhd_header *hdr = (struct ivhd_header *)cursor;

		if (hdr->type == la_iommu_target_ivhd_type) {
			int err = find_last_devid_from_ivhd(hdr);

			if (err)
				return err;
		}

		if (!hdr->length)
			break;

		cursor += hdr->length;
	}

	return (cursor == limit) ? 0 : -EINVAL;
}

/*
 * init_iommu_from_acpi - record which devices @iommu translates
 * @iommu: IOMMU being initialised
 * @h:     its IVHD block in the IVRS table
 *
 * Walks the device entries after the fixed IVHD header and stores @iommu in
 * the per-segment rlookup table for every requester ID named by DEV_ALL,
 * DEV_SELECT and DEV_SELECT_RANGE_START/DEV_RANGE_END entries. Other entry
 * types are skipped.
 *
 * Return: 0 on success, -EINVAL for an unsupported IVHD type, an empty IVHD
 * block, or a device entry whose length cannot be determined. Advancing by
 * a zero length would otherwise loop forever.
 */
static int __init init_iommu_from_acpi(struct loongarch_iommu *iommu,
					struct ivhd_header *h)
{
	u8 *p = (u8 *)h;
	u8 *end = p;
	u16 devid = 0, devid_start = 0;
	u32 dev_i;
	struct ivhd_entry *e;
	u32 ivhd_size;
	int len;

	ivhd_size = get_ivhd_header_size(h);
	if (!ivhd_size) {
		pr_err("loongarch iommu: Unsupported IVHD type %#x\n", h->type);
		return -EINVAL;
	}

	if (h->length == 0)
		return -EINVAL;

	/* Device entries start right after the fixed-size IVHD header. */
	p += ivhd_size;
	end += h->length;

	while (p < end) {
		e = (struct ivhd_entry *)p;
		switch (e->type) {
		case IVHD_DEV_ALL:
			for (dev_i = 0; dev_i <= la_iommu_last_bdf; ++dev_i)
				set_iommu_for_device(iommu, dev_i);
			break;
		case IVHD_DEV_SELECT:

			pr_info("  DEV_SELECT\t\t\t devid: %02x:%02x.%x\n",
				    PCI_BUS_NUM(e->devid),
				    PCI_SLOT(e->devid),
				    PCI_FUNC(e->devid));

			devid = e->devid;
			set_iommu_for_device(iommu, devid);
			break;
		case IVHD_DEV_SELECT_RANGE_START:

			pr_info(
			"  DEV_SELECT_RANGE_START\t devid: %02x:%02x.%x\n",
				    PCI_BUS_NUM(e->devid),
				    PCI_SLOT(e->devid),
				    PCI_FUNC(e->devid));

			/* Remembered until the matching DEV_RANGE_END. */
			devid_start = e->devid;
			break;
		case IVHD_DEV_RANGE_END:

			pr_info("  DEV_RANGE_END\t\t devid: %02x:%02x.%x\n",
				    PCI_BUS_NUM(e->devid),
				    PCI_SLOT(e->devid),
				    PCI_FUNC(e->devid));

			devid = e->devid;
			for (dev_i = devid_start; dev_i <= devid; ++dev_i)
				set_iommu_for_device(iommu, dev_i);
			break;
		default:
			break;
		}

		len = ivhd_entry_length(p);
		if (len <= 0) {
			pr_err("loongarch iommu: Unsupported IVHD entry type %#x\n",
			       e->type);
			return -EINVAL;
		}
		p += len;
	}

	return 0;
}

/*
 * init_iommu_one - initialise one IOMMU from its IVHD block
 * @iommu: zeroed IOMMU structure, owned by the caller
 * @h:     IVHD block describing it
 *
 * Initialises the locks and domain list, appends @iommu to la_iommu_list,
 * copies its requester ID and PCI segment from @h, and fills the rlookup
 * table from the IVHD device entries. Finally it clears the IOMMU's own
 * slot so it is not considered to translate itself. No hardware is touched
 * here.
 *
 * Return: 0 on success or the error from init_iommu_from_acpi(). On error
 * @iommu is unlinked from la_iommu_list again, because the caller frees it.
 * Rlookup slots already pointing at it are expected to be discarded by the
 * driver-init error path, which calls free_iommu_rlookup_entry().
 */
static int __init init_iommu_one(struct loongarch_iommu *iommu,
		struct ivhd_header *h)
{
	int ret;
	struct la_iommu_rlookup_entry *rlookupentry = NULL;

	spin_lock_init(&iommu->domain_bitmap_lock);
	spin_lock_init(&iommu->dom_info_lock);
	spin_lock_init(&iommu->pgtable_bitmap_lock);
	mutex_init(&iommu->la_iommu_pgtlock);

	/* Add IOMMU to internal data structures */
	INIT_LIST_HEAD(&iommu->dom_list);

	list_add_tail(&iommu->list, &la_iommu_list);

	/*
	 * Copy data from ACPI table entry to the iommu struct
	 */
	iommu->devid   = h->devid;
	iommu->segment = h->pci_seg;

	ret = init_iommu_from_acpi(iommu, h);
	if (ret) {
		pr_err("%s init iommu from acpi failed\n", __func__);
		/* Don't leave a soon-to-be-freed iommu on the global list. */
		list_del(&iommu->list);
		return ret;
	}

	/*
	 * Make sure IOMMU is not considered to translate itself.
	 * The IVRS table tells us so, but this is a lie!
	 * Slots above la_iommu_last_bdf were never filled (and may lie past
	 * the end of the table), so only clear an in-range slot.
	 */
	rlookupentry = lookup_rlooptable(iommu->segment);
	if (rlookupentry != NULL && iommu->devid <= la_iommu_last_bdf)
		rlookupentry->la_iommu_rlookup_table[iommu->devid] = NULL;

	return 0;
}

/*
 * init_iommu_all - create a loongarch_iommu for each target IVHD block
 * @table: IVRS table
 *
 * For every IVHD block whose type equals la_iommu_target_ivhd_type, a
 * zeroed loongarch_iommu is allocated and set up by init_iommu_one(). It is
 * freed again if that fails.
 *
 * Return: 0 if the blocks tile the table exactly, -ENOMEM on allocation
 * failure, the error from init_iommu_one(), or -EINVAL if a zero-length or
 * overrunning block is found.
 */
static int __init init_iommu_all(struct acpi_table_header *table)
{
	u8 *cursor = (u8 *)table + IVRS_HEADER_LENGTH;
	u8 *limit = (u8 *)table + table->length;

	while (cursor < limit) {
		struct ivhd_header *hdr = (struct ivhd_header *)cursor;
		struct loongarch_iommu *iommu;
		int err;

		if (!hdr->length)
			break;

		if (hdr->type != la_iommu_target_ivhd_type) {
			cursor += hdr->length;
			continue;
		}

		pr_info("device: %02x:%02x.%01x seg: %d\n",
			    PCI_BUS_NUM(hdr->devid), PCI_SLOT(hdr->devid),
			    PCI_FUNC(hdr->devid), hdr->pci_seg);

		iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
		if (!iommu)
			return -ENOMEM;

		err = init_iommu_one(iommu, hdr);
		if (err) {
			kfree(iommu);
			pr_info("%s init iommu failed\n", __func__);
			return err;
		}

		cursor += hdr->length;
	}

	return (cursor == limit) ? 0 : -EINVAL;
}

/**
 * get_highest_supported_ivhd_type - Look up the appropriate IVHD type
 * @ivrs          Pointer to the IVRS header
 *
 * This function search through all IVDB of the maximum supported IVHD
 */
static u8 get_highest_supported_ivhd_type(struct acpi_table_header *ivrs)
{
	u8 *base = (u8 *)ivrs;
	struct ivhd_header *ivhd = (struct ivhd_header *)
					(base + IVRS_HEADER_LENGTH);
	u8 last_type = ivhd->type;
	u16 devid = ivhd->devid;

	while (((u8 *)ivhd - base < ivrs->length) &&
	       (ivhd->type <= ACPI_IVHD_TYPE_MAX_SUPPORTED) &&
	       (ivhd->length > 0)) {
		u8 *p = (u8 *) ivhd;

		if (ivhd->devid == devid)
			last_type = ivhd->type;
		ivhd = (struct ivhd_header *)(p + ivhd->length);
	}

	return last_type;
}

/*
 * Bytes needed for a table with one @entry_size-byte slot per requester ID
 * in [0, la_iommu_last_bdf], rounded up to a power-of-two number of pages.
 */
static inline unsigned long tbl_size(int entry_size)
{
	int order = get_order(((int)la_iommu_last_bdf + 1) * entry_size);

	return 1UL << (PAGE_SHIFT + order);
}

/*
 * loongarch_iommu_ivrs_init - set up IOMMUs from the ACPI IVRS table
 *
 * Validates the table checksum, selects the IVHD type to parse, sizes the
 * rlookup tables from the highest requester ID and then creates every
 * IOMMU. The table reference is always dropped before returning.
 *
 * Return: 0 on success, -ENODEV if the table cannot be obtained or its
 * checksum is bad, or the error from the parsing helpers.
 */
static int __init loongarch_iommu_ivrs_init(void)
{
	struct acpi_table_header *ivrs_base;
	acpi_status status;
	int ret;

	/*
	 * Any failure (not only AE_NOT_FOUND) leaves ivrs_base unset, so it
	 * must not be used or passed to acpi_put_table().
	 */
	status = acpi_get_table("IVRS", 0, &ivrs_base);
	if (ACPI_FAILURE(status)) {
		pr_info("%s get ivrs table failed\n", __func__);
		return -ENODEV;
	}

	/*
	 * Validate checksum here so we don't need to do it when
	 * we actually parse the table
	 */
	ret = check_ivrs_checksum(ivrs_base);
	if (ret)
		goto out;

	la_iommu_target_ivhd_type = get_highest_supported_ivhd_type(ivrs_base);
	pr_info("Using IVHD type %#x\n", la_iommu_target_ivhd_type);

	/*
	 * First parse ACPI tables to find the largest Bus/Dev/Func
	 * we need to handle. Upon this information the shared data
	 * structures for the IOMMUs in the system will be allocated
	 */
	ret = find_last_devid_acpi(ivrs_base);
	if (ret) {
		pr_err("%s find last devid failed\n", __func__);
		goto out;
	}

	rlookup_table_size = tbl_size(RLOOKUP_TABLE_ENTRY_SIZE);

	/*
	 * now the data structures are allocated and basically initialized
	 * start the real acpi table scan
	 */
	ret = init_iommu_all(ivrs_base);

out:
	/* Don't leak any ACPI memory */
	acpi_put_table(ivrs_base);

	return ret;
}

/*
 * loongarch_iommu_ivrs_init_stub - fallback when the IVRS table is unusable
 *
 * Creates a single IOMMU on PCI segment 0 with requester ID 0xd0 (00:1a.0)
 * and routes every requester ID up to MAX_BDF_NUM to it, except the
 * IOMMU's own ID.
 *
 * Return: 0 on success, -ENOMEM if the IOMMU or its rlookup table cannot be
 * allocated. In the latter case @iommu is already on la_iommu_list and is
 * released by the caller's free_iommu_rlookup_entry().
 */
static int __init loongarch_iommu_ivrs_init_stub(void)
{
	/* Requester ID (bus 0, devfn 0xd0) assumed for the IOMMU itself. */
	const u16 stub_iommu_devid = 0xd0;
	struct loongarch_iommu *iommu;
	struct la_iommu_rlookup_entry *rlookupentry;
	u32 dev_i;

	/* Use maximum BDF value for DEV_ALL */
	update_last_devid(MAX_BDF_NUM);

	rlookup_table_size = tbl_size(RLOOKUP_TABLE_ENTRY_SIZE);

	iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
	if (iommu == NULL)
		return -ENOMEM;

	spin_lock_init(&iommu->domain_bitmap_lock);
	spin_lock_init(&iommu->dom_info_lock);
	spin_lock_init(&iommu->pgtable_bitmap_lock);
	mutex_init(&iommu->la_iommu_pgtlock);

	/* Add IOMMU to internal data structures */
	INIT_LIST_HEAD(&iommu->dom_list);

	list_add_tail(&iommu->list, &la_iommu_list);

	iommu->devid = stub_iommu_devid;
	iommu->segment = 0;

	/*
	 * Allocate the segment's rlookup table once, up front.
	 * set_iommu_for_device() silently ignores allocation failure, which
	 * would leave every device without an IOMMU while reporting success.
	 */
	rlookupentry = lookup_rlooptable(iommu->segment);
	if (rlookupentry == NULL)
		rlookupentry = create_rlookup_entry(iommu->segment);
	if (rlookupentry == NULL) {
		pr_err("%s rlookup table alloc failed\n", __func__);
		return -ENOMEM;
	}

	for (dev_i = 0; dev_i <= la_iommu_last_bdf; ++dev_i)
		set_iommu_for_device(iommu, dev_i);

	/* Make sure IOMMU is not considered to translate itself. */
	rlookupentry->la_iommu_rlookup_table[iommu->devid] = NULL;

	return 0;
}

static void free_iommu_rlookup_entry(void)
{
	struct loongarch_iommu *iommu = NULL;
	struct la_iommu_rlookup_entry *rlookupentry = NULL;

	while (!list_empty(&la_iommu_list)) {
		iommu = list_first_entry(&la_iommu_list, struct loongarch_iommu, list);
		list_del(&iommu->list);
		kfree(iommu);
	}

	while (!list_empty(&la_rlookup_iommu_list)) {
		rlookupentry = list_first_entry(&la_rlookup_iommu_list,
				struct la_iommu_rlookup_entry, list);

		list_del(&rlookupentry->list);
		if (rlookupentry->la_iommu_rlookup_table != NULL) {
			free_pages(
			(unsigned long)rlookupentry->la_iommu_rlookup_table,
			get_order(rlookup_table_size));

			rlookupentry->la_iommu_rlookup_table = NULL;
		}

		kfree(rlookupentry);
	}
}

/*
 * Parse "loongarch_iommu=" from the kernel command line.
 *
 * @str is a comma-separated list. A token beginning with "on" clears
 * loongarch_iommu_disable, a token beginning with "off" sets it, and other
 * tokens are ignored.
 *
 * Returns 1 ("handled"), as the __setup() contract requires. Returning 0
 * makes the kernel treat the option as unknown and pass it on to init.
 */
static int __init la_iommu_setup(char *str)
{
	if (!str)
		return 1;

	while (*str) {
		if (!strncmp(str, "on", 2)) {
			loongarch_iommu_disable = 0;
			pr_info("IOMMU enabled\n");
		} else if (!strncmp(str, "off", 3)) {
			loongarch_iommu_disable = 1;
			pr_info("IOMMU disabled\n");
		}
		/* Skip to the next comma-separated token. */
		str += strcspn(str, ",");
		while (*str == ',')
			str++;
	}

	return 1;
}
__setup("loongarch_iommu=", la_iommu_setup);

/* PCI IDs handled by this driver: vendor 0x0014, device 0x7a1f. */
static const struct pci_device_id loongson_iommu_pci_tbl[] = {
	{ PCI_DEVICE(0x0014, 0x7a1f) },
	{ }	/* terminating entry */
};

/* PCI driver binding the functions listed in loongson_iommu_pci_tbl. */
static struct pci_driver loongarch_iommu_driver = {
	.name		= "loongarch-iommu",
	.probe		= loongarch_iommu_probe,
	.remove		= loongarch_iommu_remove,
	.id_table	= loongson_iommu_pci_tbl,
};

/*
 * Module init. Unless disabled on the command line, set up IOMMUs from the
 * IVRS table, falling back to the built-in stub if that fails, then register
 * the PCI driver.
 *
 * Return: 0 on success or when disabled; otherwise the stub or
 * pci_register_driver() error. All IOMMU and rlookup state is released on
 * every failure path.
 */
static int __init loongarch_iommu_driver_init(void)
{
	int ret;

	if (loongarch_iommu_disable)
		return 0;

	ret = loongarch_iommu_ivrs_init();
	if (ret != 0) {
		free_iommu_rlookup_entry();
		pr_err("Failed to init iommu by ivrs\n");
		ret = loongarch_iommu_ivrs_init_stub();
		if (ret != 0) {
			free_iommu_rlookup_entry();
			pr_err("Failed to init iommu by stub\n");
			return ret;
		}
	}

	ret = pci_register_driver(&loongarch_iommu_driver);
	if (ret != 0) {
		pr_err("Failed to register IOMMU driver\n");
		/* Nothing will use the IOMMU/rlookup state any more. */
		free_iommu_rlookup_entry();
		return ret;
	}

	return 0;
}

/*
 * Module exit: undo loongarch_iommu_driver_init().
 *
 * The PCI driver is unregistered first. Its .remove callback looks each
 * IOMMU up on la_iommu_list to release bitmaps and mappings, so the list
 * must still be populated when it runs. Only then are the IOMMU structures
 * and rlookup tables freed.
 */
static void __exit loongarch_iommu_driver_exit(void)
{
	if (loongarch_iommu_disable)
		return;

	pci_unregister_driver(&loongarch_iommu_driver);
	free_iommu_rlookup_entry();
}

module_init(loongarch_iommu_driver_init);
module_exit(loongarch_iommu_driver_exit);
