Commit bec7dcbc authored by Linus Torvalds's avatar Linus Torvalds
Browse files
Pull probes fixes from Masami Hiramatsu:

 - fprobe: remove fprobe_hlist_node when module unloading

   When a fprobe target module is removed, the fprobe_hlist_node should
   be removed from the fprobe's hash table to prevent it from being
   accidentally reused if another module is loaded at the same address.

 - fprobe: lock module while registering fprobe

   The module containing the function to be probed is locked using a
   reference counter until the fprobe registration is complete, which
   prevents use after free.

 - fprobe-events: fix possible UAF on modules

   Basically the same as above, but in the fprobe-events layer we also
   need to take the module reference count when we find the tracepoint
   in the module.

* tag 'probes-fixes-v6.14' of git://git.kernel.org/pub/scm/linux/kernel/git/trace/linux-trace:
  tracing: fprobe: Cleanup fprobe hash when module unloading
  tracing: fprobe events: Fix possible UAF on modules
  tracing: fprobe: Fix to lock module while registering fprobe
parents e37f72b3 a3dc2983
Loading
Loading
Loading
Loading
+149 −21
Original line number Diff line number Diff line
@@ -89,8 +89,11 @@ static bool delete_fprobe_node(struct fprobe_hlist_node *node)
{
	lockdep_assert_held(&fprobe_mutex);

	/* Avoid double deleting */
	if (READ_ONCE(node->fp) != NULL) {
		WRITE_ONCE(node->fp, NULL);
		hlist_del_rcu(&node->hlist);
	}
	return !!find_first_fprobe_node(node->addr);
}

@@ -411,6 +414,102 @@ static void fprobe_graph_remove_ips(unsigned long *addrs, int num)
		ftrace_set_filter_ips(&fprobe_graph_ops.ops, addrs, num, 1, 0);
}

#ifdef CONFIG_MODULES

/* Initial capacity, in entries, used when a fprobe_addr_list is first allocated. */
#define FPROBE_IPS_BATCH_INIT 8
/* instruction pointer address list */
struct fprobe_addr_list {
	/* Number of addresses stored so far; also the next free slot in @addrs. */
	int index;
	/* Allocated capacity of @addrs, in entries. */
	int size;
	/* Heap-allocated array holding the collected addresses. */
	unsigned long *addrs;
};

/*
 * Append @addr to @alist, doubling the backing array as soon as it becomes
 * full so that the next append always has a free slot.
 *
 * Returns 0 on success. Returns -ENOMEM when the list is already full (an
 * earlier expansion failed) or when the expansion triggered by this call
 * fails; in the latter case @addr has still been recorded in the list.
 */
static int fprobe_addr_list_add(struct fprobe_addr_list *alist, unsigned long addr)
{
	unsigned long *grown;
	int old_size = alist->size;

	/* No free slot left: a previous expansion must have failed. */
	if (alist->index >= old_size)
		return -ENOMEM;

	alist->addrs[alist->index] = addr;
	alist->index++;

	if (alist->index >= old_size) {
		/* That was the last free slot; grow the array right away. */
		grown = kcalloc(old_size * 2, sizeof(*grown), GFP_KERNEL);
		if (!grown)
			return -ENOMEM;

		memcpy(grown, alist->addrs, old_size * sizeof(*grown));
		kfree(alist->addrs);
		alist->addrs = grown;
		alist->size = old_size * 2;
	}

	return 0;
}

static void fprobe_remove_node_in_module(struct module *mod, struct hlist_head *head,
					struct fprobe_addr_list *alist)
{
	struct fprobe_hlist_node *node;
	int ret = 0;

	hlist_for_each_entry_rcu(node, head, hlist) {
		if (!within_module(node->addr, mod))
			continue;
		if (delete_fprobe_node(node))
			continue;
		/*
		 * If failed to update alist, just continue to update hlist.
		 * Therefore, at list user handler will not hit anymore.
		 */
		if (!ret)
			ret = fprobe_addr_list_add(alist, node->addr);
	}
}

/*
 * Handle module unloading to manage fprobe_ip_table.
 *
 * On MODULE_STATE_GOING, every fprobe node whose address belongs to the
 * departing module is removed from the hash table, and the addresses that are
 * no longer probed by anyone are removed from the fprobe_graph_ops filter.
 * All other module states are ignored. Always returns NOTIFY_DONE.
 */
static int fprobe_module_callback(struct notifier_block *nb,
				  unsigned long val, void *data)
{
	struct fprobe_addr_list alist = {.size = FPROBE_IPS_BATCH_INIT};
	struct module *mod = data;
	int i;

	if (val != MODULE_STATE_GOING)
		return NOTIFY_DONE;

	alist.addrs = kcalloc(alist.size, sizeof(*alist.addrs), GFP_KERNEL);
	/* If failed to alloc memory, we can not remove ips from hash. */
	if (!alist.addrs)
		return NOTIFY_DONE;

	mutex_lock(&fprobe_mutex);
	for (i = 0; i < FPROBE_IP_TABLE_SIZE; i++)
		fprobe_remove_node_in_module(mod, &fprobe_ip_table[i], &alist);

	/*
	 * Remove whatever was recorded, even if expanding the list failed
	 * part-way (in which case index == size and the list is only partial):
	 * the first alist.index entries are always valid addresses, and leaving
	 * them in the filter would keep stale module addresses around.
	 */
	if (alist.index > 0)
		ftrace_set_filter_ips(&fprobe_graph_ops.ops,
				      alist.addrs, alist.index, 1, 0);
	mutex_unlock(&fprobe_mutex);

	kfree(alist.addrs);

	return NOTIFY_DONE;
}

/* Module notifier block that dispatches module state changes to fprobe_module_callback(). */
static struct notifier_block fprobe_module_nb = {
	.notifier_call = fprobe_module_callback,
	.priority = 0,
};

/*
 * Register fprobe_module_nb as a module notifier. Run as an early initcall;
 * returns the result of register_module_notifier().
 */
static int __init init_fprobe_module(void)
{
	return register_module_notifier(&fprobe_module_nb);
}
early_initcall(init_fprobe_module);
#endif

static int symbols_cmp(const void *a, const void *b)
{
	const char **str_a = (const char **) a;
@@ -445,6 +544,7 @@ struct filter_match_data {
	size_t index;
	size_t size;
	unsigned long *addrs;
	struct module **mods;
};

static int filter_match_callback(void *data, const char *name, unsigned long addr)
@@ -458,30 +558,47 @@ static int filter_match_callback(void *data, const char *name, unsigned long add
	if (!ftrace_location(addr))
		return 0;

	if (match->addrs)
		match->addrs[match->index] = addr;
	if (match->addrs) {
		struct module *mod = __module_text_address(addr);

		if (mod && !try_module_get(mod))
			return 0;

		match->mods[match->index] = mod;
		match->addrs[match->index] = addr;
	}
	match->index++;
	return match->index == match->size;
}

/*
 * Make IP list from the filter/no-filter glob patterns.
 * Return the number of matched symbols, or -ENOENT.
 * Return the number of matched symbols, or errno.
 * If @addrs == NULL, this just counts the number of matched symbols. If @addrs
 * is passed with an array, we also need to pass a @mods array of the same size
 * to increment the module refcount for each symbol.
 * This means we also need to call `module_put` for each element of @mods after
 * using the @addrs.
 */
static int ip_list_from_filter(const char *filter, const char *notfilter,
			       unsigned long *addrs, size_t size)
static int get_ips_from_filter(const char *filter, const char *notfilter,
			       unsigned long *addrs, struct module **mods,
			       size_t size)
{
	struct filter_match_data match = { .filter = filter, .notfilter = notfilter,
		.index = 0, .size = size, .addrs = addrs};
		.index = 0, .size = size, .addrs = addrs, .mods = mods};
	int ret;

	if (addrs && !mods)
		return -EINVAL;

	ret = kallsyms_on_each_symbol(filter_match_callback, &match);
	if (ret < 0)
		return ret;
	if (IS_ENABLED(CONFIG_MODULES)) {
		ret = module_kallsyms_on_each_symbol(NULL, filter_match_callback, &match);
		if (ret < 0)
			return ret;
	}

	return match.index ?: -ENOENT;
}
@@ -543,24 +660,35 @@ static int fprobe_init(struct fprobe *fp, unsigned long *addrs, int num)
 */
int register_fprobe(struct fprobe *fp, const char *filter, const char *notfilter)
{
	unsigned long *addrs;
	int ret;
	unsigned long *addrs __free(kfree) = NULL;
	struct module **mods __free(kfree) = NULL;
	int ret, num;

	if (!fp || !filter)
		return -EINVAL;

	ret = ip_list_from_filter(filter, notfilter, NULL, FPROBE_IPS_MAX);
	if (ret < 0)
		return ret;
	num = get_ips_from_filter(filter, notfilter, NULL, NULL, FPROBE_IPS_MAX);
	if (num < 0)
		return num;

	addrs = kcalloc(ret, sizeof(unsigned long), GFP_KERNEL);
	addrs = kcalloc(num, sizeof(*addrs), GFP_KERNEL);
	if (!addrs)
		return -ENOMEM;
	ret = ip_list_from_filter(filter, notfilter, addrs, ret);
	if (ret > 0)

	mods = kcalloc(num, sizeof(*mods), GFP_KERNEL);
	if (!mods)
		return -ENOMEM;

	ret = get_ips_from_filter(filter, notfilter, addrs, mods, num);
	if (ret < 0)
		return ret;

	ret = register_fprobe_ips(fp, addrs, ret);

	kfree(addrs);
	for (int i = 0; i < num; i++) {
		if (mods[i])
			module_put(mods[i]);
	}
	return ret;
}
EXPORT_SYMBOL_GPL(register_fprobe);
+17 −9
Original line number Diff line number Diff line
@@ -919,10 +919,16 @@ static void __find_tracepoint_module_cb(struct tracepoint *tp, struct module *mo
	struct __find_tracepoint_cb_data *data = priv;

	if (!data->tpoint && !strcmp(data->tp_name, tp->name)) {
		data->tpoint = tp;
		if (!data->mod)
		/* If module is not specified, try getting module refcount. */
		if (!data->mod && mod) {
			/* If failed to get refcount, ignore this tracepoint. */
			if (!try_module_get(mod))
				return;

			data->mod = mod;
		}
		data->tpoint = tp;
	}
}

static void __find_tracepoint_cb(struct tracepoint *tp, void *priv)
@@ -933,7 +939,11 @@ static void __find_tracepoint_cb(struct tracepoint *tp, void *priv)
		data->tpoint = tp;
}

/* Find a tracepoint from kernel and module. */
/*
 * Find a tracepoint from kernel and module. If the tracepoint is in a module,
 * the module's refcount is incremented and the module is returned as *@tp_mod.
 * Thus, if it is not NULL, the caller must call module_put(*tp_mod) after using
 * the tracepoint.
 */
static struct tracepoint *find_tracepoint(const char *tp_name,
					  struct module **tp_mod)
{
@@ -962,7 +972,10 @@ static void reenable_trace_fprobe(struct trace_fprobe *tf)
	}
}

/* Find a tracepoint from specified module. */
/*
 * Find a tracepoint from specified module. In this case, this does not get the
 * module's refcount. The caller must ensure the module is not freed.
 */
static struct tracepoint *find_tracepoint_in_module(struct module *mod,
						    const char *tp_name)
{
@@ -1169,11 +1182,6 @@ static int trace_fprobe_create_internal(int argc, const char *argv[],
	if (is_tracepoint) {
		ctx->flags |= TPARG_FL_TPOINT;
		tpoint = find_tracepoint(symbol, &tp_mod);
		/* lock module until register this tprobe. */
		if (tp_mod && !try_module_get(tp_mod)) {
			tpoint = NULL;
			tp_mod = NULL;
		}
		if (tpoint) {
			ctx->funcname = kallsyms_lookup(
				(unsigned long)tpoint->probestub,