primes_of_bounded_norm

307 days ago by andrew.ohana

%cython
# Enumerate the prime ideals of bounded norm in the quadratic field Q(a),
# a^2 - a - 1 = 0 (i.e. Q(sqrt(5))).  Internally a prime ideal is encoded as a
# pair (p, c): c == 0 stands for the ideal (p); otherwise it stands for
# (p, a - c), where c is a root of x^2 - x - 1 modulo p.
include 'sage/libs/pari/decl.pxi'
from sage.libs.pari.gen import pari
from libc.stdint cimport uint8_t, uint_fast8_t, uint32_t, uint_fast32_t, uint_fast64_t

cdef extern from "pari/pari.h":
    # Advances its first argument to the next prime using PARI's table of
    # prime differences (the second argument is the cursor into that table).
    cdef void NEXT_PRIME_VIADIFF(uint32_t, uint_fast8_t *)

# ---------------------------------------------------------------------------
# Lookup tables (module-level C arrays, filled once at import time).
# ---------------------------------------------------------------------------

# Residues mod 30 that are coprime to 30 and congruent to +-1 mod 5: the only
# classes a split rational prime > 5 can lie in.
cdef uint_fast8_t[4] unitTab
unitTab[0] = 1u
unitTab[1] = 11u
unitTab[2] = 19u
unitTab[3] = 29u

# diffTab[k] is the step from class unitTab[k] to the next class
# (wrapping from 29 to 1 + 30).
cdef uint_fast8_t[4] diffTab
diffTab[0] = 10u
diffTab[1] = 8u
diffTab[2] = 10u
diffTab[3] = 2u

# invTab[u] = u^-1 mod 30 for the units u of ZZ/30; other entries are unused.
cdef uint_fast8_t[30] invTab
invTab[1] = 1u
invTab[7] = 13u
invTab[11] = 11u
invTab[13] = 7u
invTab[17] = 23u
invTab[19] = 19u
invTab[23] = 17u
invTab[29] = 29u

cdef uint_fast32_t tempItr, tempVar

# shiftTab[b] = 2^b: single-bit masks for the 32-bit sieve words.
cdef uint_fast32_t[32] shiftTab
for tempItr in range(32u):
    shiftTab[tempItr] = (<uint_fast32_t>1u) << tempItr

# twoDiv[x] = number of trailing zero bits of the byte x.  twoDiv[0] is 8 so
# that a zero low byte means "shift a whole byte and look again".
cdef uint_fast8_t[256] twoDiv
twoDiv[0] = 8u
for tempItr in range(1u, 256u):
    twoDiv[tempItr] = 0u
    tempVar = tempItr
    while not tempVar & 1u:
        twoDiv[tempItr] += 1u
        tempVar >>= 1u


cdef uint_fast32_t exp_mod(uint_fast64_t b, uint_fast32_t e, uint_fast32_t p):
    """
    Return b^e mod p, but only partially reduced: the result is < 2^32 and
    may still be >= p, so callers finish with ``% p``.

    Requires b < 2^32 and p < 2^32.  Intermediates are reduced only when they
    exceed 2^32 - 1, which is enough to keep every product below 2^64.
    """
    cdef uint_fast64_t q
    if e & 1u:
        q = b
    else:
        q = 1ull
    e >>= 1u
    while e:
        b *= b
        if b > 4294967295ull:
            b %= p
        if e & 1u:
            q *= b
            if q > 4294967295ull:
                q %= p
        e >>= 1u
    if q > 4294967295ull:
        q %= p
    return q


cdef uint_fast32_t non_residue(uint_fast32_t p):
    """
    Return the smallest odd prime l (walking PARI's prime table) for which p
    is a quadratic non-residue mod l, tested with Euler's criterion.

    For a prime p = 1 (mod 4), quadratic reciprocity gives (l/p) = (p/l) = -1,
    so l is also a non-residue modulo p.  That is the property sqrt5_mod needs;
    it only calls this for p = 1 (mod 8).
    NOTE(review): assumes PARI's prime table is long enough to reach such an l.
    """
    cdef uint8_t *pariPrimePtr = <uint8_t *>diffptr
    cdef uint32_t pariP = 0u
    NEXT_PRIME_VIADIFF(pariP, pariPrimePtr)  # pariP = 2; skipped on purpose
    while True:
        NEXT_PRIME_VIADIFF(pariP, pariPrimePtr)
        if exp_mod(p, pariP >> 1u, pariP) % pariP > 1u:
            return pariP


cdef uint_fast32_t sqrt5_mod(uint_fast32_t p):
    """
    Return a square root of 5 modulo the odd prime p, as an integer in [0, p).

    Requires 5 to be a quadratic residue mod p (p = +-1 mod 5) and p < 2^32.
    The method depends on p:
      * p = 3 (mod 4):  5^((p+1)/4);
      * p = 5 (mod 8):  Atkin's formula a*v*(i - 1), with v = (2a)^((p-5)/8),
                        i = 2a*v^2 and a = 5;
      * p = 9 (mod 16): the same formula, twisted by a non-residue d when
                        (2a)^((p-1)/8) squares to 1 rather than -1;
      * otherwise:      Tonelli-Shanks.
    """
    cdef uint_fast64_t q, z, d, dp, v, res
    cdef uint_fast8_t r, m

    if p & 3u == 3u:
        return exp_mod(5ull, (p >> 2u) + 1u, p) % p

    if p & 7u == 5u:
        q = exp_mod(10ull, p >> 3u, p)           # v = 10^((p-5)/8)
        z = q * q
        if z > 1844674407370955161ull:           # keep z*10 below 2^64
            z %= p
        z *= 10ull                                # i = 10*v^2
        z -= 1ull                                 # i - 1 (i != 0 mod p)
        if z > 4294967295ull:
            z %= p
        q *= 5ull                                 # a*v
        if q > 4294967295ull:
            q %= p
        return q * z % p

    if p & 15u == 9u:
        q = exp_mod(10ull, p >> 4u, p)           # s = 10^((p-9)/16)
        if q > 4294967295ull:
            q %= p
        z = q * q
        if z > 1844674407370955161ull:
            z %= p
        z *= 10ull                                # i = 10*s^2 = 10^((p-1)/8)
        if z > 4294967295ull:
            z %= p
        q *= 5ull                                 # a*s
        if q > 4294967295ull:
            q %= p
        if z * z % p == 1ull:
            # i^2 = 1: multiply i by D = d^((p-1)/4) (D^2 = -1) and a*s by
            # d^((p-1)/8) so that the same closing formula applies.
            d = non_residue(p)
            dp = exp_mod(d, p >> 3u, p)
            q *= dp
            if q > 4294967295ull:
                q %= p
            dp *= dp
            if dp > 4294967295ull:
                dp %= p
            z *= dp
            if z > 4294967295ull:
                z %= p
        z -= 1ull
        return q * z % p

    # Tonelli-Shanks for p = 1 (mod 16).  Write p - 1 = 2^S * Q with Q odd,
    # stripping whole zero bytes at a time via twoDiv.
    r = 4u
    q = p >> r
    while not q & 1u:
        r += twoDiv[q & 0xFFu]
        q = p >> r
    r -= 2u                                       # r = S - 2
    v = exp_mod(non_residue(p), q, p)             # c = z^Q, order exactly 2^S
    d = exp_mod(5ull, q >> 1u, p)                 # 5^((Q-1)/2)
    res = 5ull * d                                # R = 5^((Q+1)/2)
    if res > 4294967295ull:
        res %= p
    d *= d
    if d > 3689348814741910323ull:                # keep d*5 below 2^64
        d %= p
    d *= 5ull
    d %= p                                        # t = 5^Q
    while d != 1u:
        # If t has order 2^i, this leaves m = S - 1 - i.  Since c is never
        # updated, b = c_original^(2^m) matches the textbook update.
        m = r
        dp = d * d % p
        while dp != 1u:
            dp *= dp
            dp %= p
            m -= 1u
        z = exp_mod(v, shiftTab[m], p)
        res *= z
        if res > 4294967295ull:
            res %= p
        z *= z
        if z > 4294967295ull:
            z %= p
        d *= z
        d %= p
    return res % p


def get_primes(uint32_t p):
    """
    Return the two prime ideals above a split rational prime p (p = +-1 mod 5)
    as ``(p, c)`` pairs, where c runs over both roots of x^2 - x - 1 mod p.

    The root larger than p/2 is given by its negative representative, and
    that pair is listed first.
    """
    # 64-bit arithmetic: ((p+1)/2) * q can reach about 2^63, far beyond the
    # 32-bit range the original computed it in.
    cdef uint_fast64_t q = <uint_fast64_t>sqrt5_mod(p) + 1u
    cdef uint_fast64_t w = ((<uint_fast64_t>p + 1u) >> 1u) * q % p  # (1+sqrt5)/2
    q = <uint_fast64_t>p + 1u - w                                  # other root, 1 - w
    if (q << 1u) > p:
        return [(int(p), int(q) - int(p)), (int(p), int(w))]
    return [(int(p), int(w) - int(p)), (int(p), int(q))]


def _tuples_to_ideals(pairs):
    """
    Convert ``(p, c)`` pairs into ideals of Q(a), a^2 - a - 1 = 0.

    ``(p, 0)`` becomes the ideal ``(p)``; any other ``(p, c)`` becomes
    ``(p, a - c)``.
    """
    R = sage.rings.polynomial.polynomial_ring_constructor.PolynomialRing(sage.rings.integer_ring.ZZ, 'x')
    K = sage.rings.number_field.number_field.NumberField(R([-1, -1, 1]), names='a')
    a = K.gen(0)
    return [K.ideal(p, a - c) if c else K.ideal(p) for p, c in pairs]


def primes_of_bounded_norm(uint_fast32_t B, py_tuples = None):
    """
    Return the prime ideals of norm at most ``B`` in Q(a), a^2 - a - 1 = 0,
    ordered by non-decreasing norm.

    INPUT:

    - ``B`` -- the norm bound; must be less than 2^32.
    - ``py_tuples`` -- if true, return the raw ``(p, c)`` pairs.  The pair
      ``(p, 0)`` encodes ``(p)`` and any other ``(p, c)`` encodes
      ``(p, a - c)``.  Otherwise, including the default ``None``, return
      Sage ideals.

    Split primes up to ``low`` (isqrt(B) rounded up past the next multiple of
    30) come from PARI's prime table.  Split primes above ``low`` come from a
    mod-30 wheel sieve restricted to the residue classes +-1 mod 5.

    Raises ``OverflowError`` if ``B >= 2^32`` and ``MemoryError`` if a
    buffer cannot be allocated.
    """
    cdef mpz_t t
    cdef uint_fast32_t low = 0u
    cdef uint_fast32_t piSqrtB, D, nwords, i, j, p, off, t30
    cdef uint_fast64_t n
    cdef uint_fast8_t k
    cdef uint8_t *pariPrimePtr
    cdef uint32_t pariP
    cdef uint32_t *smallP
    cdef uint32_t *prime
    cdef uint32_t *bound
    cdef uint32_t *inertP
    cdef uint32_t *inertPrimeBound
    cdef uint32_t *buf[4]
    cdef list ret

    if B > 4294967295ull:
        # get_primes takes a uint32_t and exp_mod's overflow guards assume
        # every modulus is below 2^32.
        raise OverflowError("B must be less than 2^32")

    if B <= 30u:
        # Too small for the sieve window; the answer is tabulated.
        ret = []
        if B >= 4u:
            ret.append((2, 0))
        if B >= 5u:
            ret.append((5, -2))
        if B >= 9u:
            ret.append((3, 0))
        if B >= 11u:
            ret.extend([(11, -3), (11, 4)])
        if B >= 19u:
            ret.extend([(19, -4), (19, 5)])
        if B >= 29u:
            ret.extend([(29, -5), (29, 6)])
        if py_tuples:
            return ret
        return _tuples_to_ideals(ret)

    # low = isqrt(B), rounded up to the next multiple of 30 (strictly above).
    mpz_init(t)
    mpz_set_ui(t, B)
    mpz_sqrt(t, t)
    mpz_export(&low, NULL, -1, sizeof(uint_fast32_t), 0, 0, t)
    mpz_clear(t)
    low += 30u - low % 30u
    piSqrtB = pari(low).primepi()

    smallP = <uint32_t *>sage_malloc(piSqrtB * sizeof(uint32_t))
    if smallP == NULL:
        raise MemoryError
    for k in range(4u):
        buf[k] = NULL
    try:
        # smallP <- every prime <= low, read from PARI's prime-difference table.
        pariPrimePtr = <uint8_t *>diffptr
        pariP = 0u
        prime = smallP
        bound = smallP + piSqrtB
        while prime != bound:
            NEXT_PRIME_VIADIFF(pariP, pariPrimePtr)
            prime[0] = pariP
            prime += 1

        # Bit j of buf[k] is set iff low + 30*j + unitTab[k] has a prime factor
        # in [7, sqrt(B)].  D is a multiple of 32 and large enough to reach past B.
        D = <uint_fast32_t>((<uint_fast64_t>(B - low) + 960u - (B - low) % 960u) // 30u)
        nwords = D >> 5u
        for k in range(4u):
            buf[k] = <uint32_t *>sage_malloc(nwords * sizeof(uint32_t))
            if buf[k] == NULL:
                raise MemoryError
            for j in range(nwords):
                buf[k][j] = 0u

        # Sieve with 7 <= p <= sqrt(B).  The wheel already excludes 2, 3, 5.
        # The explicit end-of-array check matters: low can exceed isqrt(B)
        # with no prime in between, so every entry may satisfy p*p <= B.
        prime = smallP + 3
        while prime != bound:
            p = prime[0]
            if <uint_fast64_t>p * p > B:
                break
            prime += 1
            for k in range(4u):
                i = unitTab[k]
                # off in [0, p): low + off is the first multiple of p >= low.
                off = p - 1u - (low - 1u) % p
                # Smallest t30 >= 0 with off + p*t30 == i (mod 30).  Working
                # with (i - off) mod 30 directly yields 0, not 30, when they agree.
                t30 = invTab[p % 30u] * ((i + 30u - off % 30u) % 30u) % 30u
                j = (off + p * t30) // 30u
                while j < D:
                    buf[k][j >> 5u] |= shiftTab[j & 31u]
                    j += p

        ret = [(2, 0), (5, -2), (3, 0)]
        # Inert primes p (p = +-2 mod 5) have norm p^2.  Those with p^2 <= B
        # are queued in the already-consumed front of smallP (inertPrimeBound
        # never catches up with prime).  Each is emitted once a split prime's
        # norm exceeds p^2, which keeps the output ordered by norm.
        inertP = smallP
        inertPrimeBound = smallP
        prime = smallP + 3
        while prime != bound:
            p = prime[0]
            prime += 1
            if p % 5u == 1u or p % 5u == 4u:
                while inertP != inertPrimeBound and <uint_fast64_t>inertP[0] * inertP[0] < p:
                    ret.append((int(inertP[0]), 0))
                    inertP += 1
                ret.extend(get_primes(p))
            elif <uint_fast64_t>p * p <= B:
                inertPrimeBound[0] = p
                inertPrimeBound += 1

        # Walk the candidates above low in increasing order; unmarked ones are
        # prime.  n is 64-bit so n + diffTab[k] cannot wrap near 2^32.
        n = low + 1u
        for j in range(D):
            for k in range(4u):
                if n > B:
                    break
                if not buf[k][j >> 5u] & shiftTab[j & 31u]:
                    while inertP != inertPrimeBound and <uint_fast64_t>inertP[0] * inertP[0] < n:
                        ret.append((int(inertP[0]), 0))
                        inertP += 1
                    ret.extend(get_primes(n))
                n += diffTab[k]
            if n > B:
                break

        # Remaining queued inert primes have norm above every split prime found.
        while inertP != inertPrimeBound:
            ret.append((int(inertP[0]), 0))
            inertP += 1
    finally:
        # Released on every path, including MemoryError or errors raised
        # from get_primes; the original leaked every buf[k].
        for k in range(4u):
            sage_free(buf[k])
        sage_free(smallP)

    if py_tuples:
        return ret
    return _tuples_to_ideals(ret)
def williams_primes_of_bounded_norm(B):
    """
    Reference implementation: return the prime ideals of norm at most ``B``
    in Q(a), a^2 - a - 1 = 0, sorted by (norm, ideal).

    It factors every rational prime p <= B with ``K.primes_above`` and keeps
    the ideals whose norm does not exceed ``B``.
    """
    R = sage.rings.polynomial.polynomial_ring_constructor.PolynomialRing(sage.rings.integer_ring.ZZ, 'x')
    K = sage.rings.number_field.number_field.NumberField(R([-1, -1, 1]), names='a')
    # Build the (norm, ideal) list in one pass.  The original used
    # sum(list_of_lists, []), which re-copies the accumulator on every
    # addition (quadratic), and it also computed each norm twice.
    pairs = []
    for p in primes(B + 1):
        for P in K.primes_above(p):
            N = P.norm()
            if N <= B:
                pairs.append((N, P))
    pairs.sort()
    return [P for _, P in pairs]
       
L = williams_primes_of_bounded_norm(10**4)
R = primes_of_bounded_norm(10**4)
# NOTE(review): the two set-based checks rely on ideal hashing agreeing with
# ideal equality; the last check compares with == (via `in`) only.
print(set(L) == set(R))
print(str(set(L)) == str(set(R)))
# Mutual containment: every ideal of each list occurs in the other.
equal = all(ideal in R for ideal in L) and all(ideal in L for ideal in R)
print(equal)
       
False
False
False
False
False
False
# Benchmark the reference implementation (one K.primes_above call per prime).
timeit("williams_primes_of_bounded_norm(10**4)")
       
5 loops, best of 3: 1.57 s per loop
5 loops, best of 3: 1.57 s per loop
# Benchmark the sieve-based implementation, including construction of Sage ideals.
timeit("primes_of_bounded_norm(10**4)")
       
5 loops, best of 3: 100 ms per loop
5 loops, best of 3: 100 ms per loop
# Benchmark the sieve-based implementation returning raw (p, c) pairs only.
timeit("primes_of_bounded_norm(10**4,py_tuples=True)")
       
625 loops, best of 3: 292 µs per loop
625 loops, best of 3: 292 µs per loop
# Time one larger run (B = 10^6) with Sage's `time` directive.
time L = primes_of_bounded_norm(10**6) 
       
Time: CPU 12.11 s, Wall: 12.11 s
Time: CPU 12.11 s, Wall: 12.11 s