FK20 CUDA
fk20_poly2hext_fft.cu
Go to the documentation of this file.
1 // bls12_381: Arithmetic for BLS12-381
2 // Copyright 2022-2023 Dag Arne Osvik
3 // Copyright 2022-2023 Luan Cardoso dos Santos
4 
5 #include "fr.cuh"
6 #include "g1.cuh"
7 #include "fk20.cuh"
8 
// Global-memory scratch for the per-row FFT: 512 fr_t (512 * 32 bytes = 16 KiB)
// per thread block, for at most 512 thread blocks (rows).
// NOTE(review): this is a single file-scope buffer, so two launches of
// fk20_poly2hext_fft running concurrently (e.g. on different streams) would
// share it and race.
static __device__ fr_t fr_tmp[512*512];

/**
 * @brief polynomial + xext_fft -> hext_fft
 *
 * One thread block per row. Row b reads polynomial[4096*b .. 4096*b+4095]
 * and writes hext_fft[512*b .. 512*b+511].
 *
 * For each round i = 0..15 the block builds a 512-element vector c_i in its
 * slice of fr_tmp:
 *   c_i[(t + 257) % 512] = polynomial[16*t + 15 - i]   for t = 1..255
 *   c_i[j]               = 0                           for j = 1..257
 * (so c_i[0] = polynomial[4095 - i] and c_i[258..511] hold the other values).
 * It transforms c_i in place with fr_fft and computes, for j = 0..511,
 *   hext_fft[j] = sum over i of FFT(c_i)[j] * xext_fft[512*i + j].
 * Coefficients polynomial[0..15] of each row are never read.
 *
 * Launch requirements. The kernel returns without writing anything if any
 * of these is violated:
 *   - 1-D grid, gridDim.x = number of rows, at most 512 (capacity of fr_tmp);
 *   - blockDim = (256, 1, 1); thread t produces output columns t and t+256.
 *
 * @param[out] hext_fft   512 G1 points per row.
 * @param[in]  polynomial 4096 Fr coefficients per row.
 * @param[in]  xext_fft   16*512 = 8192 G1 points, shared by all rows.
 */
__global__ void fk20_poly2hext_fft(g1p_t *hext_fft, const fr_t *polynomial, const g1p_t xext_fft[8192]) {

    // Launch-configuration guards. Every condition is uniform across the grid,
    // so either all threads return here or none do (no divergent barriers).
    //
    // gridDim.x is the number of rows. Each block uses 512 entries of fr_tmp,
    // so a grid larger than fr_tmp's capacity would write past its end.
    if (gridDim.x  > sizeof(fr_tmp) / (512 * sizeof(fr_t))) return;
    if (gridDim.y  !=   1) return;
    if (gridDim.z  !=   1) return;
    if (blockDim.x != 256) return;  // k
    if (blockDim.y !=   1) return;
    if (blockDim.z !=   1) return;

    unsigned tid = threadIdx.x; // Thread number
    unsigned bid = blockIdx.x;  // Block number (row)

    // Accumulators and temporaries in registers or local
    // (thread-interleaved global) memory

    g1p_t a0, a1, t;    // a0: output column tid, a1: output column tid+256

    g1p_inf(a0);
    g1p_inf(a1);

    // Select this block's row in each per-row array.
    polynomial += 4096 * bid;
    hext_fft   +=  512 * bid;

    fr_t *fr = fr_tmp + 512 * bid;

    // MSM loop: one round per 512-point slice of xext_fft.

    for (int i=0; i<16; i++) {

        // Copy from the polynomial into half of the coefficient array.
        // Threads 1..255 fill slots 258..511 and 0; thread 0 zeroes slot 257.

        unsigned src = 16*tid + 15 - i;
        unsigned dst = (tid+257) % 512;

        if (tid > 0)
            fr_cpy(fr[dst], polynomial[src]);
        else
            fr_zero(fr[dst]);

        __syncthreads();

        // Zero the other half of coefficients (slots 1..256) before FFT.
        // These slots are disjoint from the ones written above.

        fr_zero(fr[tid+1]);

        // Compute the FFT in place.
        // The barrier before it ensures all 512 input slots are written.
        // The barrier after it makes the transformed values visible to
        // every thread of the block.
        // NOTE(review): fr_fft is presumably a block-collective routine that
        // expects all 256 threads; confirm in fr_fft.cu.

        __syncthreads();
        fr_fft(fr, fr);
        __syncthreads();

        // Multiply and accumulate:
        //   a0 += fr[tid]     * xext_fft[512*i + tid]
        //   a1 += fr[tid+256] * xext_fft[512*i + tid + 256]

        g1p_cpy(t, xext_fft[512*i + tid + 0]);
        g1p_mul(t, fr[tid]);
        __syncthreads();
        g1p_add(a0, t);

        g1p_cpy(t, xext_fft[512*i + tid + 256]);
        g1p_mul(t, fr[tid+256]);
        // After this barrier every thread has finished reading fr for this
        // round, so the next round may safely overwrite it.
        __syncthreads();
        g1p_add(a1, t);
    }

    // Store accumulators

    g1p_cpy(hext_fft[tid+  0], a0);
    g1p_cpy(hext_fft[tid+256], a1);
}
93 
94 // vim: ts=4 et sw=4 si
__managed__ g1p_t xext_fft[16][512]
__managed__ g1p_t hext_fft[512 *512]
__managed__ fr_t polynomial[512 *4096]
__global__ void fk20_poly2hext_fft(g1p_t *hext_fft, const fr_t *polynomial, const g1p_t xext_fft[8192])
polynomial + xext_fft -> hext_fft
__device__ __host__ void fr_zero(fr_t &z)
Sets the value of z to zero.
Definition: fr.cu:15
uint64_t fr_t[4]
Subgroup element stored as a 256-bit array (a 4-element little-endian array of uint64_t)....
Definition: fr.cuh:24
__device__ __host__ void fr_cpy(fr_t &z, const fr_t &x)
Copy from x into z.
Definition: fr_cpy.cu:14
__device__ void fr_fft(fr_t *output, const fr_t *input)
FFT over Fr.
Definition: fr_fft.cu:26
__device__ __host__ void g1p_inf(g1p_t &p)
Set p to the point-at-infinity (0,1,0)
Definition: g1p.cu:93
__device__ void g1p_add(g1p_t &p, const g1p_t &q)
Computes the sum of the points p and q using projective coordinates, and stores the result in p.
Definition: g1p_add.cu:29
__device__ void g1p_mul(g1p_t &p, const fr_t &x)
p ← x·p: point multiplication by the scalar x, in projective coordinates. The result is stored back into p.
Definition: g1p_mul.cu:19
__device__ __host__ void g1p_cpy(g1p_t &p, const g1p_t &q)
Copy from q into p.
Definition: g1p.cu:67
G1 point in projective coordinates.
Definition: g1.cuh:27