lists.openwall.net   lists  /  announce  owl-users  owl-dev  john-users  john-dev  passwdqc-users  yescrypt  popa3d-users  /  oss-security  kernel-hardening  musl  sabotage  tlsify  passwords  /  crypt-dev  xvendor  /  Bugtraq  Full-Disclosure  linux-kernel  linux-netdev  linux-ext4  linux-hardening  linux-cve-announce  PHC 
Open Source and information security mailing list archives
 
Hash Suite: Windows password security audit tool. GUI, reports in PDF.
[<prev] [next>] [day] [month] [year] [list]
Date:   Thu, 12 Mar 2020 11:38:10 +0000
From:   George Spelvin <lkml@....ORG>
To:     chris@...is-wilson.co.uk
Cc:     Lukas Wunner <lukas@...ner.de>,
        Joonas Lahtinen <joonas.lahtinen@...ux.intel.com>,
        Daniel Vetter <daniel.vetter@...ll.ch>,
        linux-kernel@...r.kernel.org, lkml@....org
Subject: [CODE] lib/math/prime_numbers.c: halve memory usage

I ran across the prime number generator, and after wondering WTF it was 
doing in the kernel (it turns out it's used by some of the GPU self-tests
and not routine kernel operation), noticed it could use some optimization.

In particular, a tiny bit of special-case code to handle the prime 2
lets you limit the bitmap to odd numbers, which halves the memory usage.

There are ways to do better; if you special-case 2, 3 and 5, then you
can halve memory usage again because all other primes fall into 8 residue 
classes mod 30, but that comes with a significant code-size increase.

Two other optimizations I folded in:
- Start clearing the sieve at p**2, since any smaller composite
  has smaller factors than p and has already been excluded.
- In slow_is_prime_number(), do trial division starting with the
  smallest divisors, since they are most likely to terminate
  the loop.

What follows is my current code, in a test harness for user-space testing.
The "#if 0" sections comment out kernel-only code, and the #else
sections provide the necessary stubs to let the code be compiled
as a standalone executable.  Changing that to #if 1 produces a kernel
source file.

Comments?

One change I'm thinking of making is choosing the array sizes  
so sizeof(struct primes) + bitmap_size(primes->sz) is a power of
2 in size.  That's slightly more than doubling each iteration.

--- 8< --- cut here --- 8< ---
// SPDX-License-Identifier: GPL-2.0-only
#define pr_fmt(fmt) "prime numbers: " fmt "\n"

# if 0
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/prime_numbers.h>
#include <linux/slab.h>
#else
#include <assert.h>
#include <limits.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#if ULONG_MAX == 0xffffffff
#define BITS_PER_LONG 32
#else
#define BITS_PER_LONG 64
#endif
#define BITS_TO_LONGS(x) ((size_t)((x) + BITS_PER_LONG - 1)/BITS_PER_LONG)
struct rcu_head { };

#define DEFINE_MUTEX(lock) struct rcu_head lock;
#define mutex_lock(lock) (void)(lock)
#define mutex_unlock(lock) (void)(lock)
#define __rcu /*nothing*/
#define RCU_INITIALIZER(x) (x)
#define rcu_read_lock() (void)0
#define rcu_read_unlock() (void)0
#define rcu_dereference(x) (x)
#define rcu_dereference_protected(p, c) (p)
#define rcu_assign_pointer(ptr, new) ((ptr) = (new))

static inline unsigned long __fls(unsigned long word)
{
        return (sizeof(word) * 8) - 1 - __builtin_clzl(word);
}
static inline unsigned long __ffs(unsigned long word)
{
        return __builtin_ctzl(word);
}


static unsigned long int_sqrt(unsigned long x)
{
        unsigned long b, m, y = 0;

        if (x <= 1)
                return x;

        m = 1UL << (__fls(x) & ~1UL);
        while (m != 0) {
                b = y + m;
                y >>= 1;

                if (x >= b) {
                        x -= b;
                        y += m;
                }
                m >>= 2;
        }

        return y;
}

#define BUG_ON(x) assert(!(x))
#define round_up(x, y) ((((x) - 1) | ((unsigned long)(y) - 1)) + 1)
#define round_down(x, y) ((x)  & ~((unsigned long)(y) - 1))
#define roundup(x, y) (((x) + ((y) - 1)) / (y) * (y))
#define kmalloc(size, flags) malloc(size)
#define kfree free
#define kfree_rcu(ptr, rcu) free(ptr)

#define BIT(x) (1ul << (x))
#define BIT_MASK(b) BIT((b) & (BITS_PER_LONG - 1))
#define BIT_WORD(b) ((b) / BITS_PER_LONG)
static inline void __clear_bit(int nr, volatile unsigned long *addr)
{
        unsigned long mask = BIT_MASK(nr);
        unsigned long *p = ((unsigned long *)addr) + BIT_WORD(nr);

        *p &= ~mask;
}
static inline int test_bit(int nr, const volatile unsigned long *addr)
{
        return 1UL & (addr[BIT_WORD(nr)] >> (nr & (BITS_PER_LONG-1)));
}


static inline void bitmap_fill(unsigned long *dst, unsigned int nbits)
{
        unsigned int len = BITS_TO_LONGS(nbits) * sizeof(unsigned long);
        memset(dst, 0xff, len);
}

static inline void bitmap_copy(unsigned long *dst, const unsigned long *src,
                        unsigned int nbits)
{
        unsigned int len = BITS_TO_LONGS(nbits) * sizeof(unsigned long);
        memcpy(dst, src, len);
}

#define BITMAP_FIRST_WORD_MASK(start) (~0UL << ((start) & (BITS_PER_LONG - 1)))
#define unlikely(x)     __builtin_expect(x, 0)
#define likely(x)       __builtin_expect(!!(x), 1)

static unsigned long min(unsigned long x, unsigned long y)
{
	return x < y ? x : y;
}
static inline unsigned long find_next_bit(const unsigned long *addr1,
                unsigned long nbits, unsigned long start)
{
        unsigned long tmp;

        if (unlikely(start >= nbits))
                return nbits;

        tmp = addr1[start / BITS_PER_LONG];

        /* Handle 1st word. */
        tmp &= BITMAP_FIRST_WORD_MASK(start);
        start = round_down(start, BITS_PER_LONG);

        while (!tmp) {
                start += BITS_PER_LONG;
                if (start >= nbits)
                        return nbits;

                tmp = addr1[start / BITS_PER_LONG];
        }

        return min(start + __ffs(tmp), nbits);
}

#define EXPORT_SYMBOL(x) /*nothing*/

#endif

#define bitmap_size(nbits) (BITS_TO_LONGS(nbits) * sizeof(unsigned long))

struct primes {
	struct rcu_head rcu;
	unsigned long last, sz;
	unsigned long primes[];
};

/*
 * The bitmap stores odd numbers.  Bit i corresponds to
 * the integer 2*i+1.  sz is the number of bits in the array.
 * last is the index of the last set bit (greatest prime).
 */
static const struct primes small_primes = {
	.last = (127-1)/2,
	.sz = 64,
	.primes = {
		BIT((3 - 1)/2) |
		BIT((5 - 1)/2) |
		BIT((7 - 1)/2) |
		BIT((11 - 1)/2) |
		BIT((13 - 1)/2) |
		BIT((17 - 1)/2) |
		BIT((19 - 1)/2) |
		BIT((23 - 1)/2) |
		BIT((29 - 1)/2) |
		BIT((31 - 1)/2) |
		BIT((37 - 1)/2) |
		BIT((41 - 1)/2) |
		BIT((43 - 1)/2) |
		BIT((47 - 1)/2) |
		BIT((53 - 1)/2) |
		BIT((59 - 1)/2) |
#if BITS_PER_LONG == 64
		BIT((61 - 1)/2) |
		BIT((67 - 1)/2) |
		BIT((71 - 1)/2) |
		BIT((73 - 1)/2) |
		BIT((79 - 1)/2) |
		BIT((83 - 1)/2) |
		BIT((89 - 1)/2) |
		BIT((97 - 1)/2) |
		BIT((101 - 1)/2) |
		BIT((103 - 1)/2) |
		BIT((107 - 1)/2) |
		BIT((109 - 1)/2) |
		BIT((113 - 1)/2) |
		BIT((127 - 1)/2)
#elif BITS_PER_LONG == 32
		BIT((61 - 1)/2),
		BIT((67 - 65)/2) |
		BIT((71 - 65)/2) |
		BIT((73 - 65)/2) |
		BIT((79 - 65)/2) |
		BIT((83 - 65)/2) |
		BIT((89 - 65)/2) |
		BIT((97 - 65)/2) |
		BIT((101 - 65)/2) |
		BIT((103 - 65)/2) |
		BIT((107 - 65)/2) |
		BIT((109 - 65)/2) |
		BIT((113 - 65)/2) |
		BIT((127 - 65)/2)
#else
#error "unhandled BITS_PER_LONG"
#endif
	}
};

static DEFINE_MUTEX(lock);
static const struct primes __rcu *primes = RCU_INITIALIZER(&small_primes);

static unsigned long selftest_max;

/*
 * Fallback test if we can't allocate a sieve.
 * The argument is guaranteed odd, and larger than
 * the small_primes limit.
 */
static bool slow_is_prime_number(unsigned long x)
{
	unsigned i, y = int_sqrt(x);

	assert(0);

	for (i = 3; i < y; i += 2)
		if (x % i == 0)
			return false;
	return true;
}

/*
 * Fallback test if we can't allocate a sieve.
 * The argument is guaranteed odd, and larger than
 * the small_primes limit.
 */
static unsigned long slow_next_prime_number(unsigned long x)
{
	while (x < ULONG_MAX && !slow_is_prime_number(x += 2))
		;

	return x;
}

/*
 * Return the index of the last set bit in the region.  The bit is
 * guaranteed to exist, so there's no need to check for array underrun.
 */
static unsigned long
find_last_bit(const unsigned long *addr, unsigned long size)
{
	unsigned long x, idx = BIT_WORD(size);

	do
		x = addr[--idx];
	while (!x);
	return idx * BITS_PER_LONG + __fls(x);
}

/*
 * Generate a new, larger, prime bitmap.
 * @x: the index of the number we're trying to look up.
 *
 * Called with rcu_read_lock held.  Drops it internally.
 * Returns true with it reacquired on success.
 * Returns false with it dropped on failure.
 */
static bool expand_to_next_prime(unsigned long x)
{
	const struct primes *p;
	struct primes *new;
	unsigned long sz, y;

	rcu_read_unlock();

	/* Betrand's Postulate (or Chebyshev's theorem) states that if n > 3,
	 * there is always at least one prime p between n and 2n - 2.
	 * Equivalently, if n > 1, then there is always at least one prime p
	 * such that n < p < 2n.
	 *
	 * http://mathworld.wolfram.com/BertrandsPostulate.html
	 * https://en.wikipedia.org/wiki/Bertrand's_postulate
	 */
	sz = 2 * x;
	if (sz < x)
		return false;

	sz = round_up(sz, BITS_PER_LONG);
	new = kmalloc(sizeof(*new) + bitmap_size(sz),
		      GFP_KERNEL | __GFP_NOWARN);
	if (!new)
		return false;

	mutex_lock(&lock);
	p = rcu_dereference_protected(primes, lockdep_is_held(&lock));
	if (x < p->last) {	/* Recheck under lock */
		kfree(new);
		goto unlock;
	}

	new->sz = sz;
	bitmap_fill(new->primes, sz);
	bitmap_copy(new->primes, p->primes, p->sz);
	/*
	 * This is where all the magic happens.  For each prime z in
	 * the bitmap, starting at max(z**2, p->sz*2+1), clear the bits
	 * corresponding to multiples of z.  When z**2 exceeds the new
	 * size, we're done.
	 */
	y = 0;
	for (;;) {
		unsigned long z, m;

		y = find_next_bit(new->primes, sz, y + 1);
		m = 2 * y * (y + 1);	/* The index of (2*y+1)**2 */

		if (m >= sz)
			break;
		z = 2 * y + 1;	/* The prime corresponding to index y */
		if (m < p->sz)
			m = roundup(p->sz - y, z) + y;
		do {
			__clear_bit(m, new->primes);
			m += z;
		} while (m < sz);
	}
	new->last = find_last_bit(new->primes, sz);

	BUG_ON(new->last <= x);

	rcu_assign_pointer(primes, new);
	if (p != &small_primes)
		kfree_rcu((struct primes *)p, rcu);

unlock:
	mutex_unlock(&lock);
	rcu_read_lock();
	return true;
}

static void free_primes(void)
{
	const struct primes *p;

	mutex_lock(&lock);
	p = rcu_dereference_protected(primes, lockdep_is_held(&lock));
	if (p != &small_primes) {
		rcu_assign_pointer(primes, &small_primes);
		kfree_rcu((struct primes *)p, rcu);
	}
	mutex_unlock(&lock);
}

/**
 * next_prime_number - return the next prime number
 * @x: the starting point for searching to test
 *
 * A prime number is an integer greater than 1 that is only divisible by
 * itself and 1.  The set of prime numbers is computed using the sieve of
 * Eratosthenes (on finding a prime, all multiples of that prime are removed
 * from the set) enabling a fast lookup of the next prime number larger than
 * @x. If the sieve fails (memory limitation), the search falls back to using
 * slow trial-divison, up to the value of ULONG_MAX (which is reported as the
 * final prime as a sentinel).
 *
 * Returns: the next prime number larger than @x
 */
unsigned long next_prime_number(unsigned long x)
{
	const struct primes *p;

	if (x < 2)
		return 2;
	x = (x - 1) >> 1;

	rcu_read_lock();
	while (x >= (p = rcu_dereference(primes))->last)
		if (!expand_to_next_prime(x))
			return slow_next_prime_number(2*x+3);
	x = find_next_bit(p->primes, p->last, x + 1);
	rcu_read_unlock();

	return 2*x+1;
}
EXPORT_SYMBOL(next_prime_number);

/**
 * is_prime_number - test whether the given number is prime
 * @x: the number to test
 *
 * A prime number is an integer greater than 1 that is only divisible by
 * itself and 1. Internally a cache of prime numbers is kept (to speed up
 * searching for sequential primes, see next_prime_number()), but if the number
 * falls outside of that cache, its primality is tested using trial-divison.
 *
 * Returns: true if @x is prime, false for composite numbers.
 */
bool is_prime_number(unsigned long x)
{
	const struct primes *p;
	bool result;

	if (x % 2 == 0)
		return x == 2;
	x >>= 1;

	rcu_read_lock();
	while (x >= (p = rcu_dereference(primes))->sz)
		if (!expand_to_next_prime(x))
			return slow_is_prime_number(2*x+1);
	result = test_bit(x, p->primes);
	rcu_read_unlock();

	return result;
}
EXPORT_SYMBOL(is_prime_number);

#if 0
static void dump_primes(void)
{
	const struct primes *p;
	char *buf;

	buf = kmalloc(PAGE_SIZE, GFP_KERNEL);

	rcu_read_lock();
	p = rcu_dereference(primes);

	if (buf)
		bitmap_print_to_pagebuf(true, buf, p->primes, p->sz);
	pr_info("primes.{last=%lu, .sz=%lu, .primes[]=...x%lx} = %s",
		p->last, p->sz, p->primes[BITS_TO_LONGS(p->sz) - 1], buf);

	rcu_read_unlock();

	kfree(buf);
}

static int selftest(unsigned long max)
{
	unsigned long x, last;

	if (!max)
		return 0;

	for (last = 0, x = 2; x < max; x++) {
		bool slow = slow_is_prime_number(x);
		bool fast = is_prime_number(x);

		if (slow != fast) {
			pr_err("inconsistent result for is-prime(%lu): slow=%s, fast=%s!",
			       x, slow ? "yes" : "no", fast ? "yes" : "no");
			goto err;
		}

		if (!slow)
			continue;

		if (next_prime_number(last) != x) {
			pr_err("incorrect result for next-prime(%lu): expected %lu, got %lu",
			       last, x, next_prime_number(last));
			goto err;
		}
		last = x;
	}

	pr_info("selftest(%lu) passed, last prime was %lu", x, last);
	return 0;

err:
	dump_primes();
	return -EINVAL;
}

static int __init primes_init(void)
{
	return selftest(selftest_max);
}

static void __exit primes_exit(void)
{
	free_primes();
}

module_init(primes_init);
module_exit(primes_exit);

module_param_named(selftest, selftest_max, ulong, 0400);

MODULE_AUTHOR("Intel Corporation");
MODULE_LICENSE("GPL");
#else
int
main(void)
{
	unsigned long x;

	for (x = 0; x < 1<<24; x = next_prime_number(x))
		printf("%lu\n", x);
	free_primes();
	return 0;
}
#endif

Powered by blists - more mailing lists

Powered by Openwall GNU/*/Linux Powered by OpenVZ