Commit 315cc066 authored by Michel Lespinasse's avatar Michel Lespinasse Committed by Linus Torvalds
Browse files

augmented rbtree: add new RB_DECLARE_CALLBACKS_MAX macro

Add RB_DECLARE_CALLBACKS_MAX, which generates augmented rbtree callbacks
for the case where the augmented value is a scalar whose definition
follows a max(f(node)) pattern.  This actually covers all present uses of
RB_DECLARE_CALLBACKS, and saves some (source) code duplication in the
various RBCOMPUTE function definitions.

[walken@google.com: fix mm/vmalloc.c]
  Link: http://lkml.kernel.org/r/CANN689FXgK13wDYNh1zKxdipeTuALG4eKvKpsdZqKFJ-rvtGiQ@mail.gmail.com
[walken@google.com: re-add check to check_augmented()]
  Link: http://lkml.kernel.org/r/20190727022027.GA86863@google.com
Link: http://lkml.kernel.org/r/20190703040156.56953-3-walken@google.com


Signed-off-by: default avatarMichel Lespinasse <walken@google.com>
Acked-by: default avatarPeter Zijlstra (Intel) <peterz@infradead.org>
Cc: David Howells <dhowells@redhat.com>
Cc: Davidlohr Bueso <dbueso@suse.de>
Cc: Uladzislau Rezki <urezki@gmail.com>
Signed-off-by: default avatarAndrew Morton <akpm@linux-foundation.org>
Signed-off-by: default avatarLinus Torvalds <torvalds@linux-foundation.org>
parent 444b8a83
Loading
Loading
Loading
Loading
+3 −16
Original line number Diff line number Diff line
@@ -54,23 +54,10 @@ static u64 get_subtree_max_end(struct rb_node *node)
	return ret;
}

static u64 compute_subtree_max_end(struct memtype *data)
{
	u64 max_end = data->end, child_max_end;

	child_max_end = get_subtree_max_end(data->rb.rb_right);
	if (child_max_end > max_end)
		max_end = child_max_end;

	child_max_end = get_subtree_max_end(data->rb.rb_left);
	if (child_max_end > max_end)
		max_end = child_max_end;

	return max_end;
}
#define NODE_END(node) ((node)->end)

RB_DECLARE_CALLBACKS(static, memtype_rb_augment_cb, struct memtype, rb,
		     u64, subtree_max_end, compute_subtree_max_end)
RB_DECLARE_CALLBACKS_MAX(static, memtype_rb_augment_cb,
			 struct memtype, rb, u64, subtree_max_end, NODE_END)

/* Find the first (lowest start addr) overlapping range from rb tree */
static struct memtype *memtype_rb_lowest_match(struct rb_root *root,
+3 −26
Original line number Diff line number Diff line
@@ -13,33 +13,10 @@ sector_t interval_end(struct rb_node *node)
	return this->end;
}

/**
 * compute_subtree_last  -  compute end of @node
 *
 * The end of an interval is the highest (start + (size >> 9)) value of this
 * node and of its children.  Called for @node and its parents whenever the end
 * may have changed.
 */
static inline sector_t
compute_subtree_last(struct drbd_interval *node)
{
	sector_t max = node->sector + (node->size >> 9);

	if (node->rb.rb_left) {
		sector_t left = interval_end(node->rb.rb_left);
		if (left > max)
			max = left;
	}
	if (node->rb.rb_right) {
		sector_t right = interval_end(node->rb.rb_right);
		if (right > max)
			max = right;
	}
	return max;
}
#define NODE_END(node) ((node)->sector + ((node)->size >> 9))

RB_DECLARE_CALLBACKS(static, augment_callbacks, struct drbd_interval, rb,
		     sector_t, end, compute_subtree_last);
RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
			 struct drbd_interval, rb, sector_t, end, NODE_END);

/**
 * drbd_insert_interval  -  insert a new interval into a tree
+2 −20
Original line number Diff line number Diff line
@@ -30,26 +30,8 @@
									      \
/* Callbacks for augmented rbtree insert and remove */			      \
									      \
static inline ITTYPE ITPREFIX ## _compute_subtree_last(ITSTRUCT *node)	      \
{									      \
	ITTYPE max = ITLAST(node), subtree_last;			      \
	if (node->ITRB.rb_left) {					      \
		subtree_last = rb_entry(node->ITRB.rb_left,		      \
					ITSTRUCT, ITRB)->ITSUBTREE;	      \
		if (max < subtree_last)					      \
			max = subtree_last;				      \
	}								      \
	if (node->ITRB.rb_right) {					      \
		subtree_last = rb_entry(node->ITRB.rb_right,		      \
					ITSTRUCT, ITRB)->ITSUBTREE;	      \
		if (max < subtree_last)					      \
			max = subtree_last;				      \
	}								      \
	return max;							      \
}									      \
									      \
RB_DECLARE_CALLBACKS(static, ITPREFIX ## _augment, ITSTRUCT, ITRB,	      \
		     ITTYPE, ITSUBTREE, ITPREFIX ## _compute_subtree_last)    \
RB_DECLARE_CALLBACKS_MAX(static, ITPREFIX ## _augment,			      \
			 ITSTRUCT, ITRB, ITTYPE, ITSUBTREE, ITLAST)	      \
									      \
/* Insert / remove interval nodes from the tree */			      \
									      \
+35 −1
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ rb_insert_augmented_cached(struct rb_node *node,
}

/*
 * Template for declaring augmented rbtree callbacks
 * Template for declaring augmented rbtree callbacks (generic case)
 *
 * RBSTATIC:    'static' or empty
 * RBNAME:      name of the rb_augment_callbacks structure
@@ -107,6 +107,40 @@ RBSTATIC const struct rb_augment_callbacks RBNAME = { \
	.rotate = RBNAME ## _rotate					\
};

/*
 * Template for declaring augmented rbtree callbacks,
 * computing RBAUGMENTED scalar as max(RBCOMPUTE(node)) for all subtree nodes.
 *
 * RBSTATIC:    'static' or empty
 * RBNAME:      name of the rb_augment_callbacks structure
 * RBSTRUCT:    struct type of the tree nodes
 * RBFIELD:     name of struct rb_node field within RBSTRUCT
 * RBTYPE:      type of the RBAUGMENTED field
 * RBAUGMENTED: name of RBTYPE field within RBSTRUCT holding data for subtree
 * RBCOMPUTE:   name of function that returns the per-node RBTYPE scalar
 */

#define RB_DECLARE_CALLBACKS_MAX(RBSTATIC, RBNAME, RBSTRUCT, RBFIELD,	      \
				 RBTYPE, RBAUGMENTED, RBCOMPUTE)	      \
static inline RBTYPE RBNAME ## _compute_max(RBSTRUCT *node)		      \
{									      \
	RBSTRUCT *child;						      \
	RBTYPE max = RBCOMPUTE(node);					      \
	if (node->RBFIELD.rb_left) {					      \
		child = rb_entry(node->RBFIELD.rb_left, RBSTRUCT, RBFIELD);   \
		if (child->RBAUGMENTED > max)				      \
			max = child->RBAUGMENTED;			      \
	}								      \
	if (node->RBFIELD.rb_right) {					      \
		child = rb_entry(node->RBFIELD.rb_right, RBSTRUCT, RBFIELD);  \
		if (child->RBAUGMENTED > max)				      \
			max = child->RBAUGMENTED;			      \
	}								      \
	return max;							      \
}									      \
RB_DECLARE_CALLBACKS(RBSTATIC, RBNAME, RBSTRUCT, RBFIELD,		      \
		     RBTYPE, RBAUGMENTED, RBNAME ## _compute_max)


#define	RB_RED		0
#define	RB_BLACK	1
+17 −20
Original line number Diff line number Diff line
@@ -77,26 +77,10 @@ static inline void erase_cached(struct test_node *node, struct rb_root_cached *r
}


static inline u32 augment_recompute(struct test_node *node)
{
	u32 max = node->val, child_augmented;
	if (node->rb.rb_left) {
		child_augmented = rb_entry(node->rb.rb_left, struct test_node,
					   rb)->augmented;
		if (max < child_augmented)
			max = child_augmented;
	}
	if (node->rb.rb_right) {
		child_augmented = rb_entry(node->rb.rb_right, struct test_node,
					   rb)->augmented;
		if (max < child_augmented)
			max = child_augmented;
	}
	return max;
}
#define NODE_VAL(node) ((node)->val)

RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb,
		     u32, augmented, augment_recompute)
RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
			 struct test_node, rb, u32, augmented, NODE_VAL)

static void insert_augmented(struct test_node *node,
			     struct rb_root_cached *root)
@@ -238,7 +222,20 @@ static void check_augmented(int nr_nodes)
	check(nr_nodes);
	for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
		struct test_node *node = rb_entry(rb, struct test_node, rb);
		WARN_ON_ONCE(node->augmented != augment_recompute(node));
		u32 subtree, max = node->val;
		if (node->rb.rb_left) {
			subtree = rb_entry(node->rb.rb_left, struct test_node,
					   rb)->augmented;
			if (max < subtree)
				max = subtree;
		}
		if (node->rb.rb_right) {
			subtree = rb_entry(node->rb.rb_right, struct test_node,
					   rb)->augmented;
			if (max < subtree)
				max = subtree;
		}
		WARN_ON_ONCE(node->augmented != max);
	}
}

Loading