FK20 CUDA
fr_fft.cu
Go to the documentation of this file.
1 // bls12_381: Arithmetic for BLS12-381
2 // Copyright 2022-2023 Dag Arne Osvik
3 // Copyright 2022-2023 Luan Cardoso dos Santos
4 
5 #include "fr.cuh"
6 #include "fk20.cuh"
7 
12 extern __shared__ fr_t fr_smem[];
13 
/**
 * Forward 512-point FFT over Fr, computed cooperatively by one thread block.
 *
 * Expected launch: 256 threads in x (the kernel wrappers enforce
 * blockDim.x == 256). Each thread moves two elements in and out and performs
 * one radix-2 butterfly per stage. Uses the dynamic shared array fr_smem as
 * workspace; it must hold 512 fr_t elements.
 *
 * Algorithm: decimation-in-time. The input is loaded in 9-bit bit-reversed
 * order, followed by 9 butterfly stages with distances 1, 2, 4, ..., 256.
 * The result is stored in natural order.
 * Butterfly: r *= fr_roots[w]; then fr_addsub(l, r) forms the sum/difference.
 * NOTE(review): fr_roots[k] is presumably w^k for a primitive 512th root of
 * unity w; confirm against fr_roots.cu.
 *
 * @param[out] output 512 results (global memory).
 * @param[in]  input  512 inputs (global memory).
 */
__device__ void fr_fft(fr_t *output, const fr_t *input) {

    const unsigned t = threadIdx.x;

    // rev = t with its low 9 bits reversed. Since t < 256, bit 0 of rev is 0,
    // so input[t] and input[t + 256] land at the adjacent slots rev and rev|1.
    unsigned rev;
    asm volatile ("\n\tbrev.b32 %0, %1;" : "=r"(rev) : "r"(t << (32-9)));

    fr_cpy(fr_smem[rev    ], input[t | 256 ^ 256]);
    fr_cpy(fr_smem[rev | 1], input[t | 256      ]);

    __syncthreads();

    // Distance 1: the twiddle is the identity, so no multiplication is needed.
    fr_addsub(fr_smem[2 * t], fr_smem[(2 * t) | 1]);

    __syncthreads();

    // Distance 2: even threads use twiddle index 0 (skipped); odd threads use index 128.
    {
        const unsigned l = t + (t & ~1u);
        const unsigned r = l | 2;

        if (t & 1)
            fr_mul(fr_smem[r], fr_roots[128]);
        fr_addsub(fr_smem[l], fr_smem[r]);
    }

    __syncthreads();

    // Distances 4 .. 256. Thread t handles pair (l, l + half), where l keeps the
    // low bits of t and doubles the high bits. The twiddle index is the position
    // inside the group, scaled to the 512-entry root table.
    #pragma unroll
    for (unsigned half = 4; half <= 256; half <<= 1) {
        const unsigned w = (t & (half - 1)) * (256 / half);
        const unsigned l = t + (t & ~(half - 1));
        const unsigned r = l | half;

        fr_mul(fr_smem[r], fr_roots[w]);
        fr_addsub(fr_smem[l], fr_smem[r]);

        __syncthreads();
    }

    // Store in natural order: no output permutation is needed.
    fr_cpy(output[t      ], fr_smem[t      ]);
    fr_cpy(output[t + 256], fr_smem[t + 256]);
}
157 
/**
 * Inverse 512-point FFT over Fr, computed cooperatively by one thread block.
 *
 * Expected launch: 256 threads in x (the kernel wrappers enforce
 * blockDim.x == 256). Each thread moves two elements in and out and performs
 * one radix-2 butterfly per stage. Uses the dynamic shared array fr_smem as
 * workspace; it must hold 512 fr_t elements.
 *
 * Algorithm: decimation-in-frequency, which mirrors fr_fft. The input is
 * loaded in natural order, followed by butterfly stages with distances
 * 256, 128, ..., 1. Each butterfly is fr_addsub(l, r) followed by
 * r *= fr_roots[512 - w]. The final 1/512 scaling is folded into the
 * distance-2 stage via fr_roots[513] (2^-9) and fr_roots[514]
 * (2^-9 / fr_roots[128]). The result is stored with the 9-bit bit-reversal undone.
 *
 * @param[out] output 512 results (global memory).
 * @param[in]  input  512 inputs (global memory).
 */
__device__ void fr_ift(fr_t *output, const fr_t *input) {

    const unsigned t = threadIdx.x;

    // Load in natural order.
    fr_cpy(fr_smem[t      ], input[t      ]);
    fr_cpy(fr_smem[t + 256], input[t + 256]);

    __syncthreads();

    // Distances 256 .. 4: butterfly, then apply the inverse twiddle to the difference.
    #pragma unroll
    for (unsigned half = 256; half >= 4; half >>= 1) {
        const unsigned w = (t & (half - 1)) * (256 / half);
        const unsigned l = t + (t & ~(half - 1));
        const unsigned r = l | half;

        fr_addsub(fr_smem[l], fr_smem[r]);
        fr_mul(fr_smem[r], fr_roots[512 - w]);

        __syncthreads();
    }

    // Distance 2: both outputs are also scaled by 2^-9. For odd threads the
    // twiddle and the scaling are combined in the single entry fr_roots[514].
    {
        const unsigned l = t + (t & ~1u);
        const unsigned r = l | 2;

        fr_addsub(fr_smem[l], fr_smem[r]);
        fr_mul(fr_smem[l], fr_roots[513]);           // 2**-9
        fr_mul(fr_smem[r], fr_roots[513 + (t & 1)]); // odd t: 2**-9/fr_roots[128]
    }

    __syncthreads();

    // Distance 1: the twiddle is the identity, so no multiplication is needed.
    fr_addsub(fr_smem[2 * t], fr_smem[(2 * t) | 1]);

    __syncthreads();

    // Store with the 9-bit bit-reversal undone: output[t] and output[t + 256]
    // come from the adjacent slots rev and rev|1 (bit 0 of rev is 0 since t < 256).
    unsigned rev;
    asm volatile ("\n\tbrev.b32 %0, %1;" : "=r"(rev) : "r"(t << (32-9)));

    fr_cpy(output[t      ], fr_smem[rev    ]);
    fr_cpy(output[t | 256], fr_smem[rev | 1]);
}
302 
303 // Kernel wrappers for device-side FFT functions
304 
/**
 * Kernel wrapper for fr_fft: one independent 512-point forward FFT per block.
 *
 * Launch contract:
 *   - 1-D grid (gridDim.y == gridDim.z == 1); block b handles
 *     input[512*b .. 512*b + 511] -> output[512*b .. 512*b + 511].
 *   - 256 threads per block, in x only.
 *   - Dynamic shared memory of 512*sizeof(fr_t) bytes for fr_smem.
 * If the grid or block shape is wrong, every thread returns immediately and
 * nothing is written. The checks depend only on gridDim/blockDim, so the whole
 * block takes the same path and fr_fft's __syncthreads() are never divergent.
 *
 * @param[out] output gridDim.x * 512 results.
 * @param[in]  input  gridDim.x * 512 inputs.
 */
__global__ void fr_fft_wrapper(fr_t *output, const fr_t *input) {

    if (gridDim.y  != 1)   return;
    if (gridDim.z  != 1)   return;
    if (blockDim.x != 256) return;
    if (blockDim.y != 1)   return;
    if (blockDim.z != 1)   return;

    // Per-block offset computed in 64 bits: 512 * blockIdx.x in 32-bit
    // unsigned arithmetic would wrap for blockIdx.x >= 2^23.
    const unsigned long long offset = 512ull * blockIdx.x;

    input  += offset;
    output += offset;

    fr_fft(output, input);
}
333 
/**
 * Kernel wrapper for fr_ift: one independent 512-point inverse FFT per block.
 *
 * Launch contract:
 *   - 1-D grid (gridDim.y == gridDim.z == 1); block b handles
 *     input[512*b .. 512*b + 511] -> output[512*b .. 512*b + 511].
 *   - 256 threads per block, in x only.
 *   - Dynamic shared memory of 512*sizeof(fr_t) bytes for fr_smem.
 * If the grid or block shape is wrong, every thread returns immediately and
 * nothing is written. The checks depend only on gridDim/blockDim, so the whole
 * block takes the same path and fr_ift's __syncthreads() are never divergent.
 *
 * @param[out] output gridDim.x * 512 results.
 * @param[in]  input  gridDim.x * 512 inputs.
 */
__global__ void fr_ift_wrapper(fr_t *output, const fr_t *input) {

    if (gridDim.y  != 1)   return;
    if (gridDim.z  != 1)   return;
    if (blockDim.x != 256) return;
    if (blockDim.y != 1)   return;
    if (blockDim.z != 1)   return;

    // Per-block offset computed in 64 bits: 512 * blockIdx.x in 32-bit
    // unsigned arithmetic would wrap for blockIdx.x >= 2^23.
    const unsigned long long offset = 512ull * blockIdx.x;

    input  += offset;
    output += offset;

    fr_ift(output, input);
}
362 
363 // vim: ts=4 et sw=4 si
uint64_t fr_t[4]
Subgroup element stored as a 256-bit array (a 4-element little-endian array of uint64_t)....
Definition: fr.cuh:24
__device__ void fr_addsub(fr_t &x, fr_t &y)
Computes the sum and the difference of the arguments, storing back into the arguments: (x,...
Definition: fr_addsub.cu:18
__constant__ fr_t fr_roots[515]
Table for the precomputed root-of-unity values.
Definition: fr_roots.cu:17
__device__ __host__ void fr_cpy(fr_t &z, const fr_t &x)
Copy from x into z.
Definition: fr_cpy.cu:14
__device__ void fr_mul(fr_t &z, const fr_t &x)
Multiplies the residues z and x modulo r, storing the result back into z.
Definition: fr_mul.cu:13
__global__ void fr_fft_wrapper(fr_t *output, const fr_t *input)
wrapper for fr_fft: FFT for fr_t[512]
Definition: fr_fft.cu:316
__global__ void fr_ift_wrapper(fr_t *output, const fr_t *input)
wrapper for fr_ift: inverse FFT for fr_t[512]
Definition: fr_fft.cu:345
__shared__ fr_t fr_smem[]
Workspace in shared memory. Must be 512*sizeof(fr_t) bytes.
__device__ void fr_ift(fr_t *output, const fr_t *input)
Inverse FFT for fr_t[512].
Definition: fr_fft.cu:170
__device__ void fr_fft(fr_t *output, const fr_t *input)
FFT over Fr.
Definition: fr_fft.cu:26