Note: in mtree_range_walk() the pre-patch scan loop only read
pivots[offset] while offset < end (short-circuit). The rewritten loop
below keeps that guard, so pivots[end] (which may lie past the pivot
array) is never read. When the loop stops at offset == end, my_pivot
is not consumed, because the following check is "offset < end && my_pivot".

diff --git a/lib/maple_tree.c b/lib/maple_tree.c
index ee1ff0c59fd7..472ebc43d9fe 100644
--- a/lib/maple_tree.c
+++ b/lib/maple_tree.c
@@ -58,6 +58,7 @@
 #include
 #include
 #include
+#include
 
 #define CREATE_TRACE_POINTS
 #include
@@ -301,7 +302,7 @@ static inline struct maple_node *mas_mn(const struct ma_state *mas)
  */
 static inline void mte_set_node_dead(struct maple_enode *mn)
 {
-	mte_to_node(mn)->parent = ma_parent_ptr(mte_to_node(mn));
+	WRITE_ONCE(mte_to_node(mn)->parent, ma_parent_ptr(mte_to_node(mn)));
 	smp_wmb(); /* Needed for RCU */
 }
 
@@ -521,7 +522,7 @@ static inline unsigned int mte_parent_slot(const struct maple_enode *enode)
 static inline struct maple_node *mte_parent(const struct maple_enode *enode)
 {
 	return (void *)((unsigned long)
-			(mte_to_node(enode)->parent) & ~MAPLE_NODE_MASK);
+			(READ_ONCE(mte_to_node(enode)->parent)) & ~MAPLE_NODE_MASK);
 }
 
 /*
@@ -536,7 +537,7 @@ static inline bool ma_dead_node(const struct maple_node *node)
 
 	/* Do not reorder reads from the node prior to the parent check */
 	smp_rmb();
-	parent = (void *)((unsigned long) node->parent & ~MAPLE_NODE_MASK);
+	parent = (void *)((unsigned long) READ_ONCE(node->parent) & ~MAPLE_NODE_MASK);
 	return (parent == node);
 }
 
@@ -1699,6 +1700,7 @@ static inline void mas_adopt_children(struct ma_state *mas,
 	do {
 		child = mas_slot_locked(mas, slots, offset);
 		mas_set_parent(mas, child, parent, offset);
+		smp_wmb(); /* Needed for RCU */
 	} while (offset--);
 }
 
@@ -2775,6 +2777,7 @@ static inline void *mtree_range_walk(struct ma_state *mas)
 	unsigned char end;
 	unsigned long max, min;
 	unsigned long prev_max, prev_min;
+	unsigned long my_pivot, mas_index;
 
 	next = mas->node;
 	min = mas->min;
@@ -2789,22 +2792,27 @@ static inline void *mtree_range_walk(struct ma_state *mas)
 		if (unlikely(ma_dead_node(node)))
 			goto dead_node;
 
-		if (pivots[offset] >= mas->index) {
+		my_pivot = READ_ONCE(pivots[offset]);
+		mas_index = READ_ONCE(mas->index);
+
+		if (my_pivot >= mas_index) {
 			prev_max = max;
 			prev_min = min;
-			max = pivots[offset];
+			max = my_pivot;
 			goto next;
 		}
 
 		do {
 			offset++;
-		} while ((offset < end) && (pivots[offset] < mas->index));
+			my_pivot = (offset < end) ? READ_ONCE(pivots[offset]) : 0;
+		} while ((offset < end) && (my_pivot < mas_index));
 
 		prev_min = min;
-		min = pivots[offset - 1] + 1;
+		min = READ_ONCE(pivots[offset - 1]) + 1;
 		prev_max = max;
-		if (likely(offset < end && pivots[offset]))
-			max = pivots[offset];
+
+		if (likely(offset < end && my_pivot))
+			max = my_pivot;
 
 next:
 		slots = ma_slots(node, type);