Commit 3c27a36f authored by Marcelo Diop-Gonzalez's avatar Marcelo Diop-Gonzalez Committed by Greg Kroah-Hartman
Browse files

staging: vc04_services: use kref + RCU to reference count services



Currently reference counts are implemented by locking service_spinlock
and then incrementing the service's ->ref_count field, calling
kfree() when the last reference has been dropped. But at the same
time, there's code in multiple places that dereferences pointers
to services without having a reference, so there could be a race there.

It should be possible to avoid taking any lock in unlock_service()
or service_release() because we are setting a single array element
to NULL, and on service creation, a mutex is locked before looking
for a NULL spot to put the new service in.

Using a struct kref and RCU-delaying the freeing of services fixes
this race condition while still making it possible to skip
grabbing a reference in many places. Also it avoids the need to
acquire a single spinlock when e.g. taking a reference on
state->services[i] when somebody else is in the middle of taking
a reference on state->services[j].

Signed-off-by: default avatarMarcelo Diop-Gonzalez <marcgonzalez@google.com>
Link: https://lore.kernel.org/r/3bf6f1ec6ace64d7072025505e165b8dd18b25ca.1581532523.git.marcgonzalez@google.com


Signed-off-by: default avatarGreg Kroah-Hartman <gregkh@linuxfoundation.org>
parent 0e35fa61
Loading
Loading
Loading
Loading
+19 −6
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include <linux/platform_device.h>
#include <linux/compat.h>
#include <linux/dma-mapping.h>
#include <linux/rcupdate.h>
#include <soc/bcm2835/raspberrypi-firmware.h>

#include "vchiq_core.h"
@@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context)
	/* There is no list of instances, so instead scan all services,
		marking those that have been dumped. */

	rcu_read_lock();
	for (i = 0; i < state->unused_service; i++) {
		struct vchiq_service *service = state->services[i];
		struct vchiq_service *service;
		struct vchiq_instance *instance;

		service = rcu_dereference(state->services[i]);
		if (!service || service->base.callback != service_callback)
			continue;

@@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context)
		if (instance)
			instance->mark = 0;
	}
	rcu_read_unlock();

	for (i = 0; i < state->unused_service; i++) {
		struct vchiq_service *service = state->services[i];
		struct vchiq_service *service;
		struct vchiq_instance *instance;
		int err;

		if (!service || service->base.callback != service_callback)
		rcu_read_lock();
		service = rcu_dereference(state->services[i]);
		if (!service || service->base.callback != service_callback) {
			rcu_read_unlock();
			continue;
		}

		instance = service->instance;
		if (!instance || instance->mark)
		if (!instance || instance->mark) {
			rcu_read_unlock();
			continue;
		}
		rcu_read_unlock();

		len = snprintf(buf, sizeof(buf),
			       "Instance %pK: pid %d,%s completions %d/%d",
@@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context)
			       instance->completion_insert -
			       instance->completion_remove,
			       MAX_COMPLETIONS);

		err = vchiq_dump(dump_context, buf, len + 1);
		if (err)
			return err;
@@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
	if (active_services > MAX_SERVICES)
		only_nonzero = 1;

	rcu_read_lock();
	for (i = 0; i < active_services; i++) {
		struct vchiq_service *service_ptr = state->services[i];
		struct vchiq_service *service_ptr =
			rcu_dereference(state->services[i]);

		if (!service_ptr)
			continue;
@@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
		if (found >= MAX_SERVICES)
			break;
	}
	rcu_read_unlock();

	read_unlock_bh(&arm_state->susp_res_lock);

+113 −109
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
/* Copyright (c) 2010-2012 Broadcom. All rights reserved. */

#include <linux/kref.h>
#include <linux/rcupdate.h>

#include "vchiq_core.h"

#define VCHIQ_SLOT_HANDLER_STACK 8192
@@ -54,7 +57,6 @@ int vchiq_core_log_level = VCHIQ_LOG_DEFAULT;
int vchiq_core_msg_log_level = VCHIQ_LOG_DEFAULT;
int vchiq_sync_log_level = VCHIQ_LOG_DEFAULT;

static DEFINE_SPINLOCK(service_spinlock);
DEFINE_SPINLOCK(bulk_waiter_spinlock);
static DEFINE_SPINLOCK(quota_spinlock);

@@ -136,44 +138,41 @@ find_service_by_handle(unsigned int handle)
{
	struct vchiq_service *service;

	spin_lock(&service_spinlock);
	rcu_read_lock();
	service = handle_to_service(handle);
	if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
	    service->handle == handle) {
		WARN_ON(service->ref_count == 0);
		service->ref_count++;
	} else
		service = NULL;
	spin_unlock(&service_spinlock);

	if (!service)
	    service->handle == handle &&
	    kref_get_unless_zero(&service->ref_count)) {
		service = rcu_pointer_handoff(service);
		rcu_read_unlock();
		return service;
	}
	rcu_read_unlock();
	vchiq_log_info(vchiq_core_log_level,
		       "Invalid service handle 0x%x", handle);

	return service;
	return NULL;
}

struct vchiq_service *
find_service_by_port(struct vchiq_state *state, int localport)
{
	struct vchiq_service *service = NULL;

	if ((unsigned int)localport <= VCHIQ_PORT_MAX) {
		spin_lock(&service_spinlock);
		service = state->services[localport];
		if (service && service->srvstate != VCHIQ_SRVSTATE_FREE) {
			WARN_ON(service->ref_count == 0);
			service->ref_count++;
		} else
			service = NULL;
		spin_unlock(&service_spinlock);
	}
		struct vchiq_service *service;

	if (!service)
		rcu_read_lock();
		service = rcu_dereference(state->services[localport]);
		if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
		    kref_get_unless_zero(&service->ref_count)) {
			service = rcu_pointer_handoff(service);
			rcu_read_unlock();
			return service;
		}
		rcu_read_unlock();
	}
	vchiq_log_info(vchiq_core_log_level,
		       "Invalid port %d", localport);

	return service;
	return NULL;
}

struct vchiq_service *
@@ -182,22 +181,20 @@ find_service_for_instance(struct vchiq_instance *instance,
{
	struct vchiq_service *service;

	spin_lock(&service_spinlock);
	rcu_read_lock();
	service = handle_to_service(handle);
	if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
	    service->handle == handle &&
	    service->instance == instance) {
		WARN_ON(service->ref_count == 0);
		service->ref_count++;
	} else
		service = NULL;
	spin_unlock(&service_spinlock);

	if (!service)
	    service->instance == instance &&
	    kref_get_unless_zero(&service->ref_count)) {
		service = rcu_pointer_handoff(service);
		rcu_read_unlock();
		return service;
	}
	rcu_read_unlock();
	vchiq_log_info(vchiq_core_log_level,
		       "Invalid service handle 0x%x", handle);

	return service;
	return NULL;
}

struct vchiq_service *
@@ -206,23 +203,21 @@ find_closed_service_for_instance(struct vchiq_instance *instance,
{
	struct vchiq_service *service;

	spin_lock(&service_spinlock);
	rcu_read_lock();
	service = handle_to_service(handle);
	if (service &&
	    (service->srvstate == VCHIQ_SRVSTATE_FREE ||
	     service->srvstate == VCHIQ_SRVSTATE_CLOSED) &&
	    service->handle == handle &&
	    service->instance == instance) {
		WARN_ON(service->ref_count == 0);
		service->ref_count++;
	} else
		service = NULL;
	spin_unlock(&service_spinlock);

	if (!service)
	    service->instance == instance &&
	    kref_get_unless_zero(&service->ref_count)) {
		service = rcu_pointer_handoff(service);
		rcu_read_unlock();
		return service;
	}
	rcu_read_unlock();
	vchiq_log_info(vchiq_core_log_level,
		       "Invalid service handle 0x%x", handle);

	return service;
}

@@ -233,19 +228,19 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
	struct vchiq_service *service = NULL;
	int idx = *pidx;

	spin_lock(&service_spinlock);
	rcu_read_lock();
	while (idx < state->unused_service) {
		struct vchiq_service *srv = state->services[idx++];
		struct vchiq_service *srv;

		srv = rcu_dereference(state->services[idx++]);
		if (srv && srv->srvstate != VCHIQ_SRVSTATE_FREE &&
		    srv->instance == instance) {
			service = srv;
			WARN_ON(service->ref_count == 0);
			service->ref_count++;
		    srv->instance == instance &&
		    kref_get_unless_zero(&srv->ref_count)) {
			service = rcu_pointer_handoff(srv);
			break;
		}
	}
	spin_unlock(&service_spinlock);
	rcu_read_unlock();

	*pidx = idx;

@@ -255,43 +250,34 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
void
lock_service(struct vchiq_service *service)
{
	spin_lock(&service_spinlock);
	WARN_ON(!service);
	if (service) {
		WARN_ON(service->ref_count == 0);
		service->ref_count++;
	if (!service) {
		WARN(1, "%s service is NULL\n", __func__);
		return;
	}
	spin_unlock(&service_spinlock);
	kref_get(&service->ref_count);
}

void
unlock_service(struct vchiq_service *service)
static void service_release(struct kref *kref)
{
	spin_lock(&service_spinlock);
	if (!service) {
		WARN(1, "%s: service is NULL\n", __func__);
		goto unlock;
	}
	if (!service->ref_count) {
		WARN(1, "%s: ref_count is zero\n", __func__);
		goto unlock;
	}
	service->ref_count--;
	if (!service->ref_count) {
	struct vchiq_service *service =
		container_of(kref, struct vchiq_service, ref_count);
	struct vchiq_state *state = service->state;

	WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
		state->services[service->localport] = NULL;
	} else {
		service = NULL;
	}
unlock:
	spin_unlock(&service_spinlock);

	if (service && service->userdata_term)
	rcu_assign_pointer(state->services[service->localport], NULL);
	if (service->userdata_term)
		service->userdata_term(service->base.userdata);
	kfree_rcu(service, rcu);
}

	kfree(service);
void
unlock_service(struct vchiq_service *service)
{
	if (!service) {
		WARN(1, "%s: service is NULL\n", __func__);
		return;
	}
	kref_put(&service->ref_count, service_release);
}

int
@@ -310,9 +296,14 @@ vchiq_get_client_id(unsigned int handle)
void *
vchiq_get_service_userdata(unsigned int handle)
{
	struct vchiq_service *service = handle_to_service(handle);
	void *userdata;
	struct vchiq_service *service;

	return service ? service->base.userdata : NULL;
	rcu_read_lock();
	service = handle_to_service(handle);
	userdata = service ? service->base.userdata : NULL;
	rcu_read_unlock();
	return userdata;
}

static void
@@ -460,19 +451,23 @@ get_listening_service(struct vchiq_state *state, int fourcc)

	WARN_ON(fourcc == VCHIQ_FOURCC_INVALID);

	rcu_read_lock();
	for (i = 0; i < state->unused_service; i++) {
		struct vchiq_service *service = state->services[i];
		struct vchiq_service *service;

		service = rcu_dereference(state->services[i]);
		if (service &&
		    service->public_fourcc == fourcc &&
		    (service->srvstate == VCHIQ_SRVSTATE_LISTENING ||
		     (service->srvstate == VCHIQ_SRVSTATE_OPEN &&
		      service->remoteport == VCHIQ_PORT_FREE))) {
			lock_service(service);
		      service->remoteport == VCHIQ_PORT_FREE)) &&
		    kref_get_unless_zero(&service->ref_count)) {
			service = rcu_pointer_handoff(service);
			rcu_read_unlock();
			return service;
		}
	}

	rcu_read_unlock();
	return NULL;
}

@@ -482,15 +477,20 @@ get_connected_service(struct vchiq_state *state, unsigned int port)
{
	int i;

	rcu_read_lock();
	for (i = 0; i < state->unused_service; i++) {
		struct vchiq_service *service = state->services[i];
		struct vchiq_service *service =
			rcu_dereference(state->services[i]);

		if (service && service->srvstate == VCHIQ_SRVSTATE_OPEN &&
		    service->remoteport == port) {
			lock_service(service);
		    service->remoteport == port &&
		    kref_get_unless_zero(&service->ref_count)) {
			service = rcu_pointer_handoff(service);
			rcu_read_unlock();
			return service;
		}
	}
	rcu_read_unlock();
	return NULL;
}

@@ -2260,7 +2260,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
			   vchiq_userdata_term userdata_term)
{
	struct vchiq_service *service;
	struct vchiq_service **pservice = NULL;
	struct vchiq_service __rcu **pservice = NULL;
	struct vchiq_service_quota *service_quota;
	int i;

@@ -2272,7 +2272,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
	service->base.callback = params->callback;
	service->base.userdata = params->userdata;
	service->handle        = VCHIQ_SERVICE_HANDLE_INVALID;
	service->ref_count     = 1;
	kref_init(&service->ref_count);
	service->srvstate      = VCHIQ_SRVSTATE_FREE;
	service->userdata_term = userdata_term;
	service->localport     = VCHIQ_PORT_FREE;
@@ -2298,7 +2298,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
	mutex_init(&service->bulk_mutex);
	memset(&service->stats, 0, sizeof(service->stats));

	/* Although it is perfectly possible to use service_spinlock
	/* Although it is perfectly possible to use a spinlock
	** to protect the creation of services, it is overkill as it
	** disables interrupts while the array is searched.
	** The only danger is of another thread trying to create a
@@ -2316,17 +2316,17 @@ vchiq_add_service_internal(struct vchiq_state *state,

	if (srvstate == VCHIQ_SRVSTATE_OPENING) {
		for (i = 0; i < state->unused_service; i++) {
			struct vchiq_service *srv = state->services[i];

			if (!srv) {
			if (!rcu_access_pointer(state->services[i])) {
				pservice = &state->services[i];
				break;
			}
		}
	} else {
		rcu_read_lock();
		for (i = (state->unused_service - 1); i >= 0; i--) {
			struct vchiq_service *srv = state->services[i];
			struct vchiq_service *srv;

			srv = rcu_dereference(state->services[i]);
			if (!srv)
				pservice = &state->services[i];
			else if ((srv->public_fourcc == params->fourcc)
@@ -2339,6 +2339,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
				break;
			}
		}
		rcu_read_unlock();
	}

	if (pservice) {
@@ -2350,7 +2351,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
			(state->id * VCHIQ_MAX_SERVICES) |
			service->localport;
		handle_seq += VCHIQ_MAX_STATES * VCHIQ_MAX_SERVICES;
		*pservice = service;
		rcu_assign_pointer(*pservice, service);
		if (pservice == &state->services[state->unused_service])
			state->unused_service++;
	}
@@ -2416,10 +2417,10 @@ vchiq_open_service_internal(struct vchiq_service *service, int client_id)
			   (service->srvstate != VCHIQ_SRVSTATE_OPENSYNC)) {
			if (service->srvstate != VCHIQ_SRVSTATE_CLOSEWAIT)
				vchiq_log_error(vchiq_core_log_level,
						"%d: osi - srvstate = %s (ref %d)",
						"%d: osi - srvstate = %s (ref %u)",
						service->state->id,
						srvstate_names[service->srvstate],
						service->ref_count);
						kref_read(&service->ref_count));
			status = VCHIQ_ERROR;
			VCHIQ_SERVICE_STATS_INC(service, error_count);
			vchiq_release_service_internal(service);
@@ -3425,10 +3426,13 @@ int vchiq_dump_service_state(void *dump_context, struct vchiq_service *service)
	char buf[80];
	int len;
	int err;
	unsigned int ref_count;

	/*Don't include the lock just taken*/
	ref_count = kref_read(&service->ref_count) - 1;
	len = scnprintf(buf, sizeof(buf), "Service %u: %s (ref %u)",
			service->localport, srvstate_names[service->srvstate],
			service->ref_count - 1); /*Don't include the lock just taken*/
			ref_count);

	if (service->srvstate != VCHIQ_SRVSTATE_FREE) {
		char remoteport[30];
+8 −4
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@
#include <linux/mutex.h>
#include <linux/completion.h>
#include <linux/kthread.h>
#include <linux/kref.h>
#include <linux/rcupdate.h>
#include <linux/wait.h>

#include "vchiq_cfg.h"
@@ -251,7 +253,8 @@ struct vchiq_slot_info {
struct vchiq_service {
	struct vchiq_service_base base;
	unsigned int handle;
	unsigned int ref_count;
	struct kref ref_count;
	struct rcu_head rcu;
	int srvstate;
	vchiq_userdata_term userdata_term;
	unsigned int localport;
@@ -464,7 +467,7 @@ struct vchiq_state {
		int error_count;
	} stats;

	struct vchiq_service *services[VCHIQ_MAX_SERVICES];
	struct vchiq_service __rcu *services[VCHIQ_MAX_SERVICES];
	struct vchiq_service_quota service_quotas[VCHIQ_MAX_SERVICES];
	struct vchiq_slot_info slot_info[VCHIQ_MAX_SLOTS];

@@ -545,12 +548,13 @@ request_poll(struct vchiq_state *state, struct vchiq_service *service,
static inline struct vchiq_service *
handle_to_service(unsigned int handle)
{
	int idx = handle & (VCHIQ_MAX_SERVICES - 1);
	struct vchiq_state *state = vchiq_states[(handle / VCHIQ_MAX_SERVICES) &
		(VCHIQ_MAX_STATES - 1)];

	if (!state)
		return NULL;

	return state->services[handle & (VCHIQ_MAX_SERVICES - 1)];
	return rcu_dereference(state->services[idx]);
}

extern struct vchiq_service *