Commit 99ff06dc authored by Steve Sistare's avatar Steve Sistare Committed by Jason Gunthorpe
Browse files

iommufd: Generalize iopt_pages address

The starting address in iopt_pages is currently a __user *uptr.
Generalize to allow other types of addresses.  Refactor iopt_alloc_pages()
and iopt_map_user_pages() into address-type specific and common functions.

Link: https://patch.msgid.link/r/1729861919-234514-4-git-send-email-steven.sistare@oracle.com


Suggested-by: default avatarNicolin Chen <nicolinc@nvidia.com>
Signed-off-by: default avatarSteve Sistare <steven.sistare@oracle.com>
Reviewed-by: default avatarJason Gunthorpe <jgg@nvidia.com>
Reviewed-by: default avatarNicolin Chen <nicolinc@nvidia.com>
Reviewed-by: default avatarKevin Tian <kevin.tian@intel.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@nvidia.com>
parent 32383c08
Loading
Loading
Loading
Loading
+34 −21
Original line number Diff line number Diff line
@@ -384,6 +384,34 @@ int iopt_map_pages(struct io_pagetable *iopt, struct list_head *pages_list,
	return rc;
}

static int iopt_map_common(struct iommufd_ctx *ictx, struct io_pagetable *iopt,
			   struct iopt_pages *pages, unsigned long *iova,
			   unsigned long length, unsigned long start_byte,
			   int iommu_prot, unsigned int flags)
{
	struct iopt_pages_list elm = {};
	LIST_HEAD(pages_list);
	int rc;

	elm.pages = pages;
	elm.start_byte = start_byte;
	if (ictx->account_mode == IOPT_PAGES_ACCOUNT_MM &&
	    elm.pages->account_mode == IOPT_PAGES_ACCOUNT_USER)
		elm.pages->account_mode = IOPT_PAGES_ACCOUNT_MM;
	elm.length = length;
	list_add(&elm.next, &pages_list);

	rc = iopt_map_pages(iopt, &pages_list, length, iova, iommu_prot, flags);
	if (rc) {
		if (elm.area)
			iopt_abort_area(elm.area);
		if (elm.pages)
			iopt_put_pages(elm.pages);
		return rc;
	}
	return 0;
}

/**
 * iopt_map_user_pages() - Map a user VA to an iova in the io page table
 * @ictx: iommufd_ctx the iopt is part of
@@ -408,29 +436,14 @@ int iopt_map_user_pages(struct iommufd_ctx *ictx, struct io_pagetable *iopt,
			unsigned long length, int iommu_prot,
			unsigned int flags)
{
	struct iopt_pages_list elm = {};
	LIST_HEAD(pages_list);
	int rc;
	struct iopt_pages *pages;

	elm.pages = iopt_alloc_pages(uptr, length, iommu_prot & IOMMU_WRITE);
	if (IS_ERR(elm.pages))
		return PTR_ERR(elm.pages);
	if (ictx->account_mode == IOPT_PAGES_ACCOUNT_MM &&
	    elm.pages->account_mode == IOPT_PAGES_ACCOUNT_USER)
		elm.pages->account_mode = IOPT_PAGES_ACCOUNT_MM;
	elm.start_byte = uptr - elm.pages->uptr;
	elm.length = length;
	list_add(&elm.next, &pages_list);
	pages = iopt_alloc_user_pages(uptr, length, iommu_prot & IOMMU_WRITE);
	if (IS_ERR(pages))
		return PTR_ERR(pages);

	rc = iopt_map_pages(iopt, &pages_list, length, iova, iommu_prot, flags);
	if (rc) {
		if (elm.area)
			iopt_abort_area(elm.area);
		if (elm.pages)
			iopt_put_pages(elm.pages);
		return rc;
	}
	return 0;
	return iopt_map_common(ictx, iopt, pages, iova, length,
			       uptr - pages->uptr, iommu_prot, flags);
}

struct iova_bitmap_fn_arg {
+10 −3
Original line number Diff line number Diff line
@@ -175,6 +175,10 @@ enum {
	IOPT_PAGES_ACCOUNT_MM = 2,
};

enum iopt_address_type {
	IOPT_ADDRESS_USER = 0,
};

/*
 * This holds a pinned page list for multiple areas of IO address space. The
 * pages always originate from a linear chunk of userspace VA. Multiple
@@ -195,7 +199,10 @@ struct iopt_pages {
	struct task_struct *source_task;
	struct mm_struct *source_mm;
	struct user_struct *source_user;
	void __user *uptr;
	enum iopt_address_type type;
	union {
		void __user *uptr;		/* IOPT_ADDRESS_USER */
	};
	bool writable:1;
	u8 account_mode;

@@ -206,8 +213,8 @@ struct iopt_pages {
	struct rb_root_cached domains_itree;
};

struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
				    bool writable);
struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
					 unsigned long length, bool writable);
void iopt_release_pages(struct kref *kref);
static inline void iopt_put_pages(struct iopt_pages *pages)
{
+23 −8
Original line number Diff line number Diff line
@@ -1139,11 +1139,11 @@ static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
	return 0;
}

struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
static struct iopt_pages *iopt_alloc_pages(unsigned long start_byte,
					   unsigned long length,
					   bool writable)
{
	struct iopt_pages *pages;
	unsigned long end;

	/*
	 * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
@@ -1152,9 +1152,6 @@ struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
		return ERR_PTR(-EINVAL);

	if (check_add_overflow((unsigned long)uptr, length, &end))
		return ERR_PTR(-EOVERFLOW);

	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
	if (!pages)
		return ERR_PTR(-ENOMEM);
@@ -1164,8 +1161,7 @@ struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
	mutex_init(&pages->mutex);
	pages->source_mm = current->mm;
	mmgrab(pages->source_mm);
	pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
	pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
	pages->npages = DIV_ROUND_UP(length + start_byte, PAGE_SIZE);
	pages->access_itree = RB_ROOT_CACHED;
	pages->domains_itree = RB_ROOT_CACHED;
	pages->writable = writable;
@@ -1179,6 +1175,25 @@ struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
	return pages;
}

struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
					 unsigned long length, bool writable)
{
	struct iopt_pages *pages;
	unsigned long end;
	void __user *uptr_down =
		(void __user *) ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);

	if (check_add_overflow((unsigned long)uptr, length, &end))
		return ERR_PTR(-EOVERFLOW);

	pages = iopt_alloc_pages(uptr - uptr_down, length, writable);
	if (IS_ERR(pages))
		return pages;
	pages->uptr = uptr_down;
	pages->type = IOPT_ADDRESS_USER;
	return pages;
}

void iopt_release_pages(struct kref *kref)
{
	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);