Commit 67eccf15 authored by Benjamin Tissoires's avatar Benjamin Tissoires
Browse files

HID: add source argument to HID low level functions

This allows to know who actually sent what when we process the request
to the device.
This will be useful for a BPF firewall program to allow or not requests
coming from a dedicated hidraw node client.

Link: https://patch.msgid.link/20240626-hid_hw_req_bpf-v2-2-cfd60fb6c79f@kernel.org


Acked-by: default avatarJiri Kosina <jkosina@suse.com>
Signed-off-by: default avatarBenjamin Tissoires <bentiss@kernel.org>
parent ebae0b2a
Loading
Loading
Loading
Loading
+7 −5
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ EXPORT_SYMBOL(hid_ops);

u8 *
dispatch_hid_bpf_device_event(struct hid_device *hdev, enum hid_report_type type, u8 *data,
			      u32 *size, int interrupt)
			      u32 *size, int interrupt, u64 source)
{
	struct hid_bpf_ctx_kern ctx_kern = {
		.ctx = {
@@ -50,7 +50,7 @@ dispatch_hid_bpf_device_event(struct hid_device *hdev, enum hid_report_type type
	rcu_read_lock();
	list_for_each_entry_rcu(e, &hdev->bpf.prog_list, list) {
		if (e->hid_device_event) {
			ret = e->hid_device_event(&ctx_kern.ctx, type);
			ret = e->hid_device_event(&ctx_kern.ctx, type, source);
			if (ret < 0) {
				rcu_read_unlock();
				return ERR_PTR(ret);
@@ -359,7 +359,8 @@ hid_bpf_hw_request(struct hid_bpf_ctx *ctx, __u8 *buf, size_t buf__sz,
					      dma_data,
					      size,
					      rtype,
					      reqtype);
					      reqtype,
					      (__u64)ctx);

	if (ret > 0)
		memcpy(buf, dma_data, ret);
@@ -398,7 +399,8 @@ hid_bpf_hw_output_report(struct hid_bpf_ctx *ctx, __u8 *buf, size_t buf__sz)

	ret = hid_ops->hid_hw_output_report(hdev,
						dma_data,
						size);
						size,
						(__u64)ctx);

	kfree(dma_data);
	return ret;
@@ -429,7 +431,7 @@ hid_bpf_input_report(struct hid_bpf_ctx *ctx, enum hid_report_type type, u8 *buf

	hdev = (struct hid_device *)ctx->hid; /* discard const */

	return hid_ops->hid_input_report(hdev, type, buf, size, 0);
	return hid_ops->hid_input_report(hdev, type, buf, size, 0, (__u64)ctx);
}
__bpf_kfunc_end_defs();

+1 −1
Original line number Diff line number Diff line
@@ -257,7 +257,7 @@ static void hid_bpf_unreg(void *kdata)
	hid_put_device(hdev);
}

static int __hid_bpf_device_event(struct hid_bpf_ctx *ctx, enum hid_report_type type)
static int __hid_bpf_device_event(struct hid_bpf_ctx *ctx, enum hid_report_type type, __u64 source)
{
	return 0;
}
+53 −32
Original line number Diff line number Diff line
@@ -2025,19 +2025,9 @@ int hid_report_raw_event(struct hid_device *hid, enum hid_report_type type, u8 *
}
EXPORT_SYMBOL_GPL(hid_report_raw_event);

/**
 * hid_input_report - report data from lower layer (usb, bt...)
 *
 * @hid: hid device
 * @type: HID report type (HID_*_REPORT)
 * @data: report contents
 * @size: size of data parameter
 * @interrupt: distinguish between interrupt and control transfers
 *
 * This is data entry for lower layers.
 */
int hid_input_report(struct hid_device *hid, enum hid_report_type type, u8 *data, u32 size,
		     int interrupt)

static int __hid_input_report(struct hid_device *hid, enum hid_report_type type,
			      u8 *data, u32 size, int interrupt, u64 source)
{
	struct hid_report_enum *report_enum;
	struct hid_driver *hdrv;
@@ -2057,7 +2047,7 @@ int hid_input_report(struct hid_device *hid, enum hid_report_type type, u8 *data
	report_enum = hid->report_enum + type;
	hdrv = hid->driver;

	data = dispatch_hid_bpf_device_event(hid, type, data, &size, interrupt);
	data = dispatch_hid_bpf_device_event(hid, type, data, &size, interrupt, source);
	if (IS_ERR(data)) {
		ret = PTR_ERR(data);
		goto unlock;
@@ -2092,6 +2082,23 @@ int hid_input_report(struct hid_device *hid, enum hid_report_type type, u8 *data
	up(&hid->driver_input_lock);
	return ret;
}

/**
 * hid_input_report - report data from lower layer (usb, bt...)
 *
 * @hid: hid device
 * @type: HID report type (HID_*_REPORT)
 * @data: report contents
 * @size: size of data parameter
 * @interrupt: distinguish between interrupt and control transfers
 *
 * This is data entry for lower layers.
 */
int hid_input_report(struct hid_device *hid, enum hid_report_type type, u8 *data, u32 size,
		     int interrupt)
{
	return __hid_input_report(hid, type, data, size, interrupt, 0);
}
EXPORT_SYMBOL_GPL(hid_input_report);

bool hid_match_one_id(const struct hid_device *hdev,
@@ -2392,6 +2399,24 @@ void hid_hw_request(struct hid_device *hdev,
}
EXPORT_SYMBOL_GPL(hid_hw_request);

int __hid_hw_raw_request(struct hid_device *hdev,
			 unsigned char reportnum, __u8 *buf,
			 size_t len, enum hid_report_type rtype,
			 enum hid_class_request reqtype,
			 __u64 source)
{
	unsigned int max_buffer_size = HID_MAX_BUFFER_SIZE;

	if (hdev->ll_driver->max_buffer_size)
		max_buffer_size = hdev->ll_driver->max_buffer_size;

	if (len < 1 || len > max_buffer_size || !buf)
		return -EINVAL;

	return hdev->ll_driver->raw_request(hdev, reportnum, buf, len,
					    rtype, reqtype);
}

/**
 * hid_hw_raw_request - send report request to device
 *
@@ -2409,6 +2434,12 @@ EXPORT_SYMBOL_GPL(hid_hw_request);
int hid_hw_raw_request(struct hid_device *hdev,
		       unsigned char reportnum, __u8 *buf,
		       size_t len, enum hid_report_type rtype, enum hid_class_request reqtype)
{
	return __hid_hw_raw_request(hdev, reportnum, buf, len, rtype, reqtype, 0);
}
EXPORT_SYMBOL_GPL(hid_hw_raw_request);

int __hid_hw_output_report(struct hid_device *hdev, __u8 *buf, size_t len, __u64 source)
{
	unsigned int max_buffer_size = HID_MAX_BUFFER_SIZE;

@@ -2418,10 +2449,11 @@ int hid_hw_raw_request(struct hid_device *hdev,
	if (len < 1 || len > max_buffer_size || !buf)
		return -EINVAL;

	return hdev->ll_driver->raw_request(hdev, reportnum, buf, len,
					    rtype, reqtype);
	if (hdev->ll_driver->output_report)
		return hdev->ll_driver->output_report(hdev, buf, len);

	return -ENOSYS;
}
EXPORT_SYMBOL_GPL(hid_hw_raw_request);

/**
 * hid_hw_output_report - send output report to device
@@ -2434,18 +2466,7 @@ EXPORT_SYMBOL_GPL(hid_hw_raw_request);
 */
int hid_hw_output_report(struct hid_device *hdev, __u8 *buf, size_t len)
{
	unsigned int max_buffer_size = HID_MAX_BUFFER_SIZE;

	if (hdev->ll_driver->max_buffer_size)
		max_buffer_size = hdev->ll_driver->max_buffer_size;

	if (len < 1 || len > max_buffer_size || !buf)
		return -EINVAL;

	if (hdev->ll_driver->output_report)
		return hdev->ll_driver->output_report(hdev, buf, len);

	return -ENOSYS;
	return __hid_hw_output_report(hdev, buf, len, 0);
}
EXPORT_SYMBOL_GPL(hid_hw_output_report);

@@ -2972,9 +2993,9 @@ EXPORT_SYMBOL_GPL(hid_check_keys_pressed);
#ifdef CONFIG_HID_BPF
static struct hid_ops __hid_ops = {
	.hid_get_report = hid_get_report,
	.hid_hw_raw_request = hid_hw_raw_request,
	.hid_hw_output_report = hid_hw_output_report,
	.hid_input_report = hid_input_report,
	.hid_hw_raw_request = __hid_hw_raw_request,
	.hid_hw_output_report = __hid_hw_output_report,
	.hid_input_report = __hid_input_report,
	.owner = THIS_MODULE,
	.bus_type = &hid_bus_type,
};
+5 −5
Original line number Diff line number Diff line
@@ -140,7 +140,7 @@ static ssize_t hidraw_send_report(struct file *file, const char __user *buffer,

	if ((report_type == HID_OUTPUT_REPORT) &&
	    !(dev->quirks & HID_QUIRK_NO_OUTPUT_REPORTS_ON_INTR_EP)) {
		ret = hid_hw_output_report(dev, buf, count);
		ret = __hid_hw_output_report(dev, buf, count, (__u64)file);
		/*
		 * compatibility with old implementation of USB-HID and I2C-HID:
		 * if the device does not support receiving output reports,
@@ -150,8 +150,8 @@ static ssize_t hidraw_send_report(struct file *file, const char __user *buffer,
			goto out_free;
	}

	ret = hid_hw_raw_request(dev, buf[0], buf, count, report_type,
				HID_REQ_SET_REPORT);
	ret = __hid_hw_raw_request(dev, buf[0], buf, count, report_type,
				   HID_REQ_SET_REPORT, (__u64)file);

out_free:
	kfree(buf);
@@ -227,8 +227,8 @@ static ssize_t hidraw_get_report(struct file *file, char __user *buffer, size_t
		goto out_free;
	}

	ret = hid_hw_raw_request(dev, report_number, buf, count, report_type,
				 HID_REQ_GET_REPORT);
	ret = __hid_hw_raw_request(dev, report_number, buf, count, report_type,
				   HID_REQ_GET_REPORT, (__u64)file);

	if (ret < 0)
		goto out_free;
+6 −0
Original line number Diff line number Diff line
@@ -1125,6 +1125,12 @@ int __must_check hid_hw_open(struct hid_device *hdev);
void hid_hw_close(struct hid_device *hdev);
void hid_hw_request(struct hid_device *hdev,
		    struct hid_report *report, enum hid_class_request reqtype);
int __hid_hw_raw_request(struct hid_device *hdev,
			 unsigned char reportnum, __u8 *buf,
			 size_t len, enum hid_report_type rtype,
			 enum hid_class_request reqtype,
			 __u64 source);
int __hid_hw_output_report(struct hid_device *hdev, __u8 *buf, size_t len, __u64 source);
int hid_hw_raw_request(struct hid_device *hdev,
		       unsigned char reportnum, __u8 *buf,
		       size_t len, enum hid_report_type rtype,
Loading