FK20 CUDA
frtest_fft.cu
Go to the documentation of this file.
1 // bls12_381: Arithmetic for BLS12-381
2 // Copyright 2022-2023 Dag Arne Osvik
3 // Copyright 2022-2023 Luan Cardoso dos Santos
4 
5 #include "fr.cuh"
6 #include "frtest.cuh"
7 
8 #include "fr_fft_testvector.cu"
9 
// Unified-memory (__managed__) buffers shared between host code and the kernels in FrTestFFT.
__managed__ fr_t    fft[512];   // transform output buffer passed to fr_fft_wrapper / fr_ift_wrapper
__managed__ uint8_t cmp[512];   // per-element comparison results from fr_eq_wrapper; FrTestFFT treats 1 as a match
12 
// Print "Error <what>: <code> (<name>)" for a failed CUDA runtime call.
// Returns true iff err is not cudaSuccess.
static bool frTestFFT_cudaFailed(cudaError_t err, const char *what) {
    if (err == cudaSuccess)
        return false;
    printf("Error %s: %d (%s)\n", what, err, cudaGetErrorName(err));
    return true;
}

// Check the most recently launched kernel: cudaGetLastError() reports launch /
// configuration errors, cudaDeviceSynchronize() waits for completion and reports
// errors raised during execution. Returns true iff either call reported a failure.
static bool frTestFFT_launchFailed(const char *kernel) {
    if (frTestFFT_cudaFailed(cudaGetLastError(), kernel))
        return true;
    return frTestFFT_cudaFailed(cudaDeviceSynchronize(), kernel);
}

// Compare the 512 entries of fft[] against expected[] on the device using
// fr_eq_wrapper, then scan cmp[] on the host. Prints "<label> error <i>" for the
// first index i whose flag is not 1. Returns true iff every element matched and
// the comparison kernel ran without a CUDA error.
static bool frTestFFT_matches(fr_t *expected, const char *label) {
    // Clear comparison results so flags left from a previous comparison cannot pass.
    for (int i = 0; i < 512; i++)
        cmp[i] = 0;

    // 16 blocks * 32 threads = 512 threads (presumably one per element).
    fr_eq_wrapper<<<16, 32>>>(cmp, 512, fft, expected);
    if (frTestFFT_launchFailed("fr_eq_wrapper"))
        return false;

    for (int i = 0; i < 512; i++) {
        if (cmp[i] != 1) {
            printf("%s error %d\n", label, i);
            return false;
        }
    }
    return true;
}

/**
 * Known-answer tests for the Fr transforms:
 *  - fr_fft: fr_fft_wrapper applied to q must produce a;
 *  - fr_ift: fr_ift_wrapper applied to a must produce q.
 * Each sub-test is reported independently via PRINTPASS. A sub-test fails if
 * its kernel (or the comparison kernel) reports a CUDA error, or if any of the
 * 512 outputs differs from the expected vector.
 */
void FrTestFFT() {

    const size_t sharedmem = 512*4*8; // 512 residues, 4 words/residue, 8 bytes/word (16 KiB)
    bool pass;

    // Allow both transform kernels to use `sharedmem` bytes of dynamic shared memory.
    // cudaFuncSetAttribute is a host-side call; its return code is checked directly.
    frTestFFT_cudaFailed(cudaFuncSetAttribute(fr_fft_wrapper, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(sharedmem)), "cudaFuncSetAttribute");
    frTestFFT_cudaFailed(cudaFuncSetAttribute(fr_ift_wrapper, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(sharedmem)), "cudaFuncSetAttribute");

    // FFT test: forward transform of q must reproduce a.
    printf("=== RUN %s\n", "fr_fft");
    fr_fft_wrapper<<<1, 256, sharedmem>>>(fft, q);
    pass = !frTestFFT_launchFailed("fr_fft_wrapper") && frTestFFT_matches(a, "FFT");
    PRINTPASS(pass);

    // IFT test: inverse transform of a must reproduce q. `pass` is recomputed from
    // scratch so an FFT failure neither skips this check nor forces a false FAIL.
    printf("=== RUN %s\n", "fr_ift");
    fr_ift_wrapper<<<1, 256, sharedmem>>>(fft, a);
    pass = !frTestFFT_launchFailed("fr_ift_wrapper") && frTestFFT_matches(q, "IFT");
    PRINTPASS(pass);
}
90 
91 // vim: ts=4 et sw=4 si
uint64_t fr_t[4]
Subgroup element stored as a 256-bit array (a 4-element little-endian array of uint64_t)....
Definition: fr.cuh:24
__global__ void fr_fft_wrapper(fr_t *output, const fr_t *input)
wrapper for fr_fft: FFT for fr_t[512]
Definition: fr_fft.cu:316
__global__ void fr_ift_wrapper(fr_t *output, const fr_t *input)
wrapper for fr_ift: inverse FFT for fr_t[512]
Definition: fr_fft.cu:345
__managed__ uint8_t cmp[512]
Definition: frtest_fft.cu:11
void FrTestFFT()
Tests the FFT and inverse FFT over Fr using known-answer tests (KAT).
Definition: frtest_fft.cu:17
__managed__ fr_t fft[512]
Definition: frtest_fft.cu:10
#define PRINTPASS(pass)
Definition: test.h:25