FK20 CUDA — g1p_fft.cu: FFT and inverse FFT of size 512 over G1 points (documentation page for this file).
1 // bls12_381: Arithmetic for BLS12-381
2 // Copyright 2022-2023 Dag Arne Osvik
3 // Copyright 2022-2023 Luan Cardoso dos Santos
4 
5 #include <stdio.h>
6 
7 #include "g1.cuh"
8 #include "fk20.cuh"
9 
10 // Workspace in shared memory
11 
12 extern __shared__ g1p_t g1p_tmp[];
13 
__device__ void g1p_fft(g1p_t *output, const g1p_t *input) {

    // FFT of size 512 over G1 points in projective coordinates,
    // decimation in time: bit-reversed load, natural-order store.
    //
    // Launch requirements: 1-D grid, blockDim == (256,1,1); each thread
    // processes two of the 512 points. The shared-memory workspace
    // g1p_tmp must hold 512 g1p_t elements (indices 0..511 are written
    // below). Each block transforms its own 512-element slice of
    // input/output, which may overlap since data passes through shared
    // memory.
    //
    // Any unsupported launch configuration makes this a silent no-op.
    if (gridDim.y != 1) return;
    if (gridDim.z != 1) return;
    if (blockDim.x != 256) return;
    if (blockDim.y != 1) return;
    if (blockDim.z != 1) return;

    unsigned tid = threadIdx.x; // Thread number
    unsigned bid = blockIdx.x; // Block number

    // l/r: left and right butterfly indices; w: index into fr_roots;
    // src/dst: copy indices.
    unsigned l, r, w, src, dst;

    // Adjust IO pointers to point at each thread block's data

    input += 512*bid;
    output += 512*bid;

    // Copy inputs to workspace in bit-reversed order.

    src = tid;
    // dst = 9 last bits of src reversed. brev.b32 reverses all 32 bits,
    // so src is pre-shifted left by 32-9 to reverse only the low 9 bits.
    asm volatile ("\n\tbrev.b32 %0, %1;" : "=r"(dst) : "r"(src << (32-9)));

    g1p_cpy(g1p_tmp[dst], input[src]);

    // Second point for this thread: setting bit 8 of src sets bit 0 of
    // its 9-bit reversal, so no second brev is needed.
    src |= 256;
    dst |= 1;

    g1p_cpy(g1p_tmp[dst], input[src]);

    __syncthreads();

    // Butterfly stage 1 of 9 (span 1). The twiddle index is always 0;
    // the multiplication is elided (fr_roots[0] is presumably the
    // identity — consistent with the w==0 skip in stage 2 below).

    w = 0;
    l = 2 * tid;
    r = l | 1;

    //g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]); // (l, r) <- (l+r, l-r)

    __syncthreads();

    // Stage 2 (span 2): w is either 0 or 128; skip the multiplication
    // when w == 0.

    w = (tid & 1) << 7;
    l = tid + (tid & -2U); // insert a 0 at the span bit position
    r = l | 2;

    if (w) g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Stage 3 (span 4).

    w = (tid & 3) << 6;
    l = tid + (tid & -4U);
    r = l | 4;

    g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Stage 4 (span 8).

    w = (tid & 7) << 5;
    l = tid + (tid & -8U);
    r = l | 8;

    g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Stage 5 (span 16).

    w = (tid & 15) << 4;
    l = tid + (tid & -16U);
    r = l | 16;

    g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Stage 6 (span 32).

    w = (tid & 31) << 3;
    l = tid + (tid & -32U);
    r = l | 32;

    g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Stage 7 (span 64).

    w = (tid & 63) << 2;
    l = tid + (tid & -64U);
    r = l | 64;

    g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Stage 8 (span 128).

    w = (tid & 127) << 1;
    l = tid + (tid & -128U);
    r = l | 128;

    g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Stage 9 (span 256).

    w = (tid & 255) << 0;
    l = tid + (tid & -256U);
    r = l | 256;

    g1p_mul(g1p_tmp[r], fr_roots[w]);
    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);

    __syncthreads();

    // Copy results to output, no shuffle

    src = tid;
    dst = src;

    g1p_cpy(output[dst], g1p_tmp[src]);

    src += 256;
    dst += 256;

    g1p_cpy(output[dst], g1p_tmp[src]);
}
167 
__device__ void g1p_ift(g1p_t *output, const g1p_t *input) {

    // Inverse FFT of size 512 over G1 points in projective coordinates —
    // the mirror of g1p_fft: natural-order load, bit-reversed store,
    // stages run in reverse span order, addsub before the twiddle
    // multiply, inverse twiddles indexed as fr_roots[512-w], and a
    // folded-in final scaling by 2^-9 == 1/512 (stage with span 2).
    //
    // Launch requirements: same as g1p_fft — 1-D grid, blockDim ==
    // (256,1,1), shared workspace g1p_tmp of 512 g1p_t. Each block
    // transforms its own 512-element slice; input/output may overlap.
    //
    // Any unsupported launch configuration makes this a silent no-op.
    if (gridDim.y != 1) return;
    if (gridDim.z != 1) return;
    if (blockDim.x != 256) return;
    if (blockDim.y != 1) return;
    if (blockDim.z != 1) return;

    unsigned tid = threadIdx.x; // Thread number
    unsigned bid = blockIdx.x; // Block number

    // l/r: left and right butterfly indices; w: index into fr_roots;
    // src/dst: copy indices.
    unsigned l, r, w, src, dst;

    // Adjust IO pointers to point at each thread block's data

    input += 512*bid;
    output += 512*bid;

    // Copy inputs to workspace, no shuffle

    src = tid;
    dst = src;

    g1p_cpy(g1p_tmp[dst], input[src]);

    src += 256;
    dst += 256;

    g1p_cpy(g1p_tmp[dst], input[src]);

    __syncthreads();

    // Butterfly stage 1 of 9 (span 256): inverse twiddle fr_roots[512-w].

    w = (tid & 255) << 0;
    l = tid + (tid & -256U); // insert a 0 at the span bit position
    r = l | 256;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]); // (l, r) <- (l+r, l-r)
    g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Stage 2 (span 128).

    w = (tid & 127) << 1;
    l = tid + (tid & -128U);
    r = l | 128;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Stage 3 (span 64).

    w = (tid & 63) << 2;
    l = tid + (tid & -64U);
    r = l | 64;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Stage 4 (span 32).

    w = (tid & 31) << 3;
    l = tid + (tid & -32U);
    r = l | 32;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Stage 5 (span 16).

    w = (tid & 15) << 4;
    l = tid + (tid & -16U);
    r = l | 16;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Stage 6 (span 8).

    w = (tid & 7) << 5;
    l = tid + (tid & -8U);
    r = l | 8;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Stage 7 (span 4).

    w = (tid & 3) << 6;
    l = tid + (tid & -4U);
    r = l | 4;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Stage 8 (span 2). The inverse twiddle and the final 1/512 scaling
    // are folded into special table entries: fr_roots[513] == 2^-9 and
    // fr_roots[514] == 2^-9/fr_roots[128] (per the inline comments
    // below), so both lanes get scaled here and the last stage needs no
    // multiplication at all.

    w = (tid & 1) << 0;
    l = tid + (tid & -2U);
    r = l | 2;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    g1p_mul(g1p_tmp[l], fr_roots[513]); // 2**-9
    g1p_mul(g1p_tmp[r], fr_roots[513+w]); // w ? 2**-9/fr_roots[128] : 2**-9

    __syncthreads();

    // Stage 9 (span 1): twiddle index 0, multiplication elided (and the
    // scaling was already applied in stage 8).

    w = 0;
    l = 2 * tid;
    r = l | 1;

    g1p_addsub(g1p_tmp[l], g1p_tmp[r]);
    //g1p_mul(g1p_tmp[r], fr_roots[512-w]);

    __syncthreads();

    // Copy results to output in bit-reversed order.

    dst = tid;
    // src = 9 last bits of dst reversed. brev.b32 reverses all 32 bits,
    // so dst is pre-shifted left by 32-9 to reverse only the low 9 bits.
    asm volatile ("\n\tbrev.b32 %0, %1;" : "=r"(src) : "r"(dst << (32-9)));

    g1p_cpy(output[dst], g1p_tmp[src]);

    // Second point: setting bit 8 of dst sets bit 0 of its 9-bit
    // reversal, so no second brev is needed.
    dst |= 256;
    src |= 1;

    g1p_cpy(output[dst], g1p_tmp[src]);
}
322 
323 // Kernel wrappers for device-side FFT functions
324 
// Kernel wrapper for g1p_fft: 512-point FFT over arrays of g1p_t.
// Launch with blockDim == (256,1,1), a 1-D grid of one block per
// 512-element slice, and dynamic shared memory sized for 512 g1p_t
// (the g1p_tmp workspace).
__global__ void g1p_fft_wrapper(g1p_t *output, const g1p_t *input) {
    g1p_fft(output, input);
}
337 
// Kernel wrapper for g1p_ift: 512-point inverse FFT over arrays of
// g1p_t. Launch with blockDim == (256,1,1), a 1-D grid of one block
// per 512-element slice, and dynamic shared memory sized for 512
// g1p_t (the g1p_tmp workspace).
__global__ void g1p_ift_wrapper(g1p_t *output, const g1p_t *input) {
    g1p_ift(output, input);
}
350 
351 // vim: ts=4 et sw=4 si
__constant__ fr_t fr_roots[515]
Table for the precomputed root-of-unity values.
Definition: fr_roots.cu:17
__device__ void g1p_addsub(g1p_t &p, g1p_t &q)
Stores the sum and difference of p and q back into p and q: projective p, q ← p+q, p−q.
Definition: g1p_addsub.cu:18
__device__ void g1p_mul(g1p_t &p, const fr_t &x)
p ← x·p Point multiplication by scalar, in projective coordinates. The result is stored back into p.
Definition: g1p_mul.cu:19
__device__ __host__ void g1p_cpy(g1p_t &p, const g1p_t &q)
Copy from q into p.
Definition: g1p.cu:67
__device__ void g1p_ift(g1p_t *output, const g1p_t *input)
Inverse FFT of size 512 over G1 with projective coordinates. Input and output arrays may overlap....
Definition: g1p_fft.cu:178
__shared__ g1p_t g1p_tmp[]
__global__ void g1p_fft_wrapper(g1p_t *output, const g1p_t *input)
wrapper for g1p_fft: FFT for arrays of g1p_t with length 512
Definition: g1p_fft.cu:336
__device__ void g1p_fft(g1p_t *output, const g1p_t *input)
FFT of size 512 over G1 with projective coordinates. Input and output arrays may overlap....
Definition: g1p_fft.cu:24
__global__ void g1p_ift_wrapper(g1p_t *output, const g1p_t *input)
wrapper for g1p_ift: inverse FFT for arrays of g1p_t with length 512
Definition: g1p_fft.cu:349
G1 point in projective coordinates.
Definition: g1.cuh:27