Enumerate all Primes to n - Elements of Programming Interviews (EPI)

Problem:

Write a function that takes an integer \(n\) and returns a list of all primes from 2 to \(n\). For example, if \(n=10\) we should return [2,3,5,7].

Solution:

The brute-force approach is to iterate over every candidate and check whether it is prime. The version below starts with 2, then iterates over the odd numbers and trial-divides each one by the primes found so far, stopping once \(p^2\) exceeds the candidate. If no prime divides the candidate, we append it to our list of primes.

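from typing import List


def generate_primes(n: int) -> List[int]:
    """Enumerate all primes up to and including n by trial division."""
    if n < 2:
        return []
    primes = [2]
    # Once 2 is recorded, only odd candidates can be prime.
    for candidate in range(3, n + 1, 2):
        is_prime = True
        for p in primes:
            if candidate % p == 0:
                is_prime = False
                break
            if p**2 > candidate:
                # No prime up to sqrt(candidate) divides it, so it is prime.
                break
        if is_prime:
            primes.append(candidate)
    return primes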

There is a better solution that takes advantage of the fact that we want every prime up to \(n\), not just a single primality test. Whenever we find a prime, we can rule out all of its multiples up to \(n\) and never check them. This approach trades space for time. We keep an array that tracks which numbers can still be prime. We only track the odd numbers from 3 to \(n\), so index i in the array represents the integer 2*i + 3, and index 0 represents 3. We initialize every entry to True, and whenever we encounter a prime we remove all of its multiples by storing False at their indices. This method is called sieving.
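To make the index mapping concrete, here is a minimal sketch. The helper names `index_to_value` and `value_to_index` are hypothetical and used only for illustration; they are not part of the EPI solution.

```python
def index_to_value(i: int) -> int:
    """Odd number represented by slot i of the odd-only array."""
    return 2 * i + 3


def value_to_index(v: int) -> int:
    """Slot that holds the odd number v >= 3 (inverse of index_to_value)."""
    return (v - 3) // 2


n = 20
size = (n - 3) // 2 + 1  # number of odd values in [3, n]
print([index_to_value(i) for i in range(size)])
# [3, 5, 7, 9, 11, 13, 15, 17, 19]
```

Storing only the odd numbers halves the size of the array, at the cost of this small translation between indices and values.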

There is another improvement we can make to this approach. Instead of removing the multiples of a prime \(p\) starting just after \(p\) itself, we can start from \(p^2\). Any smaller multiple \(kp\) with \(k < p\) has a prime factor smaller than \(p\), so a previous sieving iteration has already removed it. In code:

from typing import List


def generate_primes(n: int) -> List[int]:
    """Enumerate all primes up to and including n with an odd-only sieve."""
    if n < 2:
        return []
    size = (n - 3) // 2 + 1
    primes = [2]  # Stores the primes from 2 to n.
    # is_prime[i] records whether (2i + 3) is prime.
    # For example, is_prime[0] represents 3, is_prime[1] represents 5,
    # is_prime[2] represents 7, etc.
    # Initially set each to True. Then use sieving to eliminate nonprimes.
    is_prime = [True] * size
    for i in range(size):
        if is_prime[i]:
            p = i * 2 + 3
            primes.append(p)
            # Sieving from p^2, where p^2 = (4i^2 + 12i + 9). The index in is_prime
            # is (2i^2 + 6i + 3) because is_prime[i] represents 2i + 3.
            #
            # Python integers have arbitrary precision, so p^2 cannot overflow;
            # in a fixed-width language j would need a 64-bit type.
            # A step of p in index space is a step of 2p in value space,
            # so only the odd multiples of p are visited.
            for j in range(2 * i**2 + 6 * i + 3, size, p):
                is_prime[j] = False
    return primes
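The claim that sieving can safely start from \(p^2\) can be checked by hand for a small prime. The sketch below is only an illustration: `smallest_prime_factor` is a hypothetical helper, not part of the EPI solution, and \(p = 7\) is an arbitrary choice.

```python
def smallest_prime_factor(m: int) -> int:
    """Smallest prime factor of m >= 2, found by trial division."""
    d = 2
    while d * d <= m:
        if m % d == 0:
            return d
        d += 1
    return m


p = 7
# Every multiple k*p with 2 <= k < p has a prime factor smaller than p,
# so it is handled before the sieve ever reaches p.
for k in range(2, p):
    m = k * p
    print(m, "is covered by the smaller prime", smallest_prime_factor(m))
# p*p is the first multiple of p whose smallest prime factor is p itself.
print(p * p, "is first removed by", smallest_prime_factor(p * p))
```

In the odd-only array the even multiples (14, 28, 42) are never stored at all, so the only multiples of 7 below 49 that occupy a slot are 21 and 35, which the passes for 3 and 5 have already marked.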