1 #ifndef MHO_OpenCLBatchedMultidimensionalFastFourierTransform_HH__
2 #define MHO_OpenCLBatchedMultidimensionalFastFourierTransform_HH__
17 #define ENFORCE_CL_FINISH
26 "Array element type must be a complex floating point type.");
// Constructor member-initializer list (the constructor signature and body are not part of this listing).
// Initial state: fIsValid and fInitialized are false, fIsForward is true, the spatial dimensions are
// assumed equal, and both the host-to-device fill and the device-to-host read-out are enabled.
// The size counters and fNLocal start at zero, the build-flag string is empty, and the kernel pointer
// and every cl::Buffer pointer start as nullptr.
// NOTE(review): fMaxNWorkItems, fNGlobal and fPreferredWorkgroupMultiple do not appear in the visible
// initializer list. fMaxNWorkItems is later compared against n_global in ConstructOpenCLKernels, so
// confirm it is assigned before that comparison.
31 :
MHO_Operator(), fIsValid(false), fIsForward(true), fInitialized(false), fAllSpatialDimensionsAreEqual(true),
32 fFillFromHostData(true), fReadOutDataToHost(true), fMaxBufferSize(0), fTotalDataSize(0), fOpenCLFlags(
""),
33 fFFTKernel(nullptr), fSpatialDimensionBufferCL(nullptr), fTwiddleBufferCL(nullptr),
34 fConjugateTwiddleBufferCL(nullptr), fScaleBufferCL(nullptr), fCirculantBufferCL(nullptr), fDataBufferCL(nullptr),
35 fPermuationArrayCL(nullptr), fWorkspaceBufferCL(nullptr), fNLocal(0)
// Releases the heap-allocated cl::Buffer objects (presumably the destructor body; the enclosing
// signature is not part of this listing). Deleting a nullptr is a no-op, so buffers that were never
// created are safe here.
// NOTE(review): no `delete fDataBufferCL;` is visible. The listing jumps from line 47 to line 49, so it
// may simply be missing from this extract. Confirm the data buffer is released.
// NOTE(review): these are raw owning pointers with manual delete; std::unique_ptr members would make the
// clean-up automatic and prevent double frees if the object is ever copied.
43 delete fSpatialDimensionBufferCL;
44 delete fTwiddleBufferCL;
45 delete fConjugateTwiddleBufferCL;
46 delete fScaleBufferCL;
47 delete fCirculantBufferCL;
49 delete fPermuationArrayCL;
50 delete fWorkspaceBufferCL;
// Stores the pointer to the input array; the pointer is kept as-is and nothing else is updated here.
// NOTE(review): the `;` after the function body is a redundant empty declaration.
53 virtual void SetInput(XArgType* in) { fInput = in; };
// Stores the pointer to the output array; the pointer is kept as-is and nothing else is updated here.
// NOTE(review): the `;` after the function body is a redundant empty declaration.
55 virtual void SetOutput(XArgType* out) { fOutput = out; };
// Initialization sequence (presumably the body of Initialize(); the enclosing signature, braces and some
// branch conditions are not part of this listing).
// NOTE(review): the "initializing" / "dim mismatch" / "its valid" messages are written straight to
// std::cout and look like leftover debugging output. Confirm whether they should remain.
98 std::cout <<
"initializing" << std::endl;
// When the input and output arrays report the same dimensions, cache the input's dimensions.
101 if(DoInputOutputDimensionsMatch())
104 this->fInput->GetDimensions(fDimensionSize);
// Spatial dimension i is array dimension i + 1; fDimensionSize[0] is not copied here (the execution
// code passes fDimensionSize[0] to the kernel as n_multdim_ffts).
// NOTE(review): each size_t dimension is narrowed to unsigned int by this assignment.
105 for(
unsigned int i = 0; i < XArgType::rank::value - 1; i++)
107 fSpatialDim[i] = fDimensionSize[i + 1];
// Mismatch path: report it and mark the operator as not initialized.
112 std::cout <<
"dim mismatch" << std::endl;
114 fInitialized =
false;
// Build the host-side tables first, then the kernel and device buffers that depend on them.
// NOTE(review): the condition guarding this part is not visible in the listing.
119 std::cout <<
"its valid" << std::endl;
120 ConstructWorkspace();
121 ConstructOpenCLKernels();
// Execution sequence (Execute(), per the error message at the end; the signature, braces and several
// statements are not part of this listing). Work is only attempted when both state flags are set.
134 if(fIsValid && fInitialized)
// Kernel argument 0 receives fDimensionSize[0].
137 unsigned int n_multdim_ffts = fDimensionSize[0];
138 fFFTKernel->setArg(0, n_multdim_ffts);
// Kernel argument 2 is set to 1 on one branch and 0 on the other; the branch condition is not visible
// (presumably the transform direction, fIsForward — confirm).
141 fFFTKernel->setArg(2, 1);
145 fFFTKernel->setArg(2, 0);
// One pass per value of D, with D handed to the kernel as argument 1 further down.
// NOTE(review): D runs over rank::value entries, while ConstructWorkspace fills the per-dimension tables
// only for indices below rank::value - 1. Unless a guard exists in the missing lines, the last D would
// index an empty table in the per-dimension uploads below. Confirm the loop bound.
152 for(
unsigned int D = 0; D < XArgType::rank::value; D++)
// Accumulate the global work size and the number of 1-D transforms from the spatial dimensions.
// Listing lines 158-160 are missing, so a condition on i may apply to the two products.
155 unsigned int n_global = fDimensionSize[0];
156 unsigned int n_local_1d_transforms = 1;
157 for(
unsigned int i = 0; i < XArgType::rank::value - 1; i++)
161 n_global *= fSpatialDim[i];
162 n_local_1d_transforms *= fSpatialDim[i];
// nDummy is the distance from n_global up to the next multiple of fNLocal; it equals fNLocal when
// n_global is already a multiple. How it is applied is in lines not shown.
// NOTE(review): fNLocal is a divisor here and is initialized to 0 in the constructor. Confirm it is
// always non-zero by the time this runs.
167 unsigned int nDummy = fNLocal - (n_global % fNLocal);
168 if(nDummy == fNLocal)
174 cl::NDRange global(n_global);
175 cl::NDRange local(fNLocal);
177 fFFTKernel->setArg(1, D);
// ConstructOpenCLKernels uploads the index-0 tables once when all spatial dimensions are equal; the
// body of this branch is not visible in the listing.
179 if(fAllSpatialDimensionsAreEqual)
// ENFORCE_CL_FINISH is defined at the top of the file; the guarded statements are not shown here.
// The remaining visible lines are the argument tails of blocking (CL_TRUE) uploads of the tables for
// index D (presumably the not-all-equal branch — confirm).
// NOTE(review): each upload copies fMaxBufferSize elements, while ConstructWorkspace resizes the tables to
// M. Confirm M is never smaller than fMaxBufferSize, otherwise these reads run past the vector's end.
185 #ifdef ENFORCE_CL_FINISH
193 *fTwiddleBufferCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
CL_TYPE2), &(fTwiddle[D][0]));
194 #ifdef ENFORCE_CL_FINISH
198 0, fMaxBufferSize *
sizeof(
CL_TYPE2),
199 &(fConjugateTwiddle[D][0]));
200 #ifdef ENFORCE_CL_FINISH
204 *fScaleBufferCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
CL_TYPE2), &(fScale[D][0]));
205 #ifdef ENFORCE_CL_FINISH
209 *fCirculantBufferCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
CL_TYPE2), &(fCirculant[D][0]));
210 #ifdef ENFORCE_CL_FINISH
214 *fPermuationArrayCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
unsigned int), &(fPermuationArray[D][0]));
215 #ifdef ENFORCE_CL_FINISH
222 #ifdef ENFORCE_CL_FINISH
// Tail of the message emitted when the operator is not both valid and initialized.
235 <<
"MHO_OpenCLBatchedMultidimensionalFastFourierTransform::Execute: Not valid and initialized. Aborting."
// Builds the host-side state for the transform: total element count, largest buffer size, the
// kernel's -D build flags, and the per-dimension table vectors. Braces and several statements are
// missing from this listing.
244 void ConstructWorkspace()
// Total number of elements in the array, derived from the cached dimensions.
// NOTE(review): fTotalDataSize is an unsigned int. Confirm the element count cannot exceed its range.
247 fTotalDataSize = MHO_NDArrayMath::TotalArraySize< XArgType::rank::value >(fDimensionSize);
// Scan the spatial dimensions: clear the "all equal" flag if any differs from the first, and raise
// fMaxBufferSize to at least the largest spatial dimension.
251 fAllSpatialDimensionsAreEqual =
true;
252 unsigned int previous_dim = fSpatialDim[0];
253 for(
unsigned int i = 0; i < XArgType::rank::value - 1; i++)
255 if(previous_dim != fSpatialDim[i])
257 fAllSpatialDimensionsAreEqual =
false;
260 if(fSpatialDim[i] > fMaxBufferSize)
262 fMaxBufferSize = fSpatialDim[i];
// Tail of a second size test on a value derived from fSpatialDim[i]; the start of the expression is
// not visible in this listing.
268 fSpatialDim[i]) > fMaxBufferSize)
// Assemble the preprocessor definitions handed to the OpenCL compiler. FFT_NDIM is rank::value and
// FFT_BUFFERSIZE is fMaxBufferSize. Any further flags (listing lines 281-290) are not shown.
278 std::stringstream ss;
279 ss <<
" -D FFT_NDIM=" << XArgType::rank::value;
280 ss <<
" -D FFT_BUFFERSIZE=" << fMaxBufferSize;
291 fOpenCLFlags = ss.str();
// Per spatial dimension: size the table vectors to M and fill them. The definition of M and the names
// of the routines whose argument tails appear below are not visible in this listing.
// NOTE(review): no resize of fScale[i] is visible (listing line 307 is missing), yet fScale[i][0] is
// passed below. Confirm it is sized before use.
294 for(
unsigned int i = 0; i < XArgType::rank::value - 1; i++)
296 unsigned int N = fSpatialDim[i];
305 fTwiddle[i].resize(M);
306 fConjugateTwiddle[i].resize(M);
308 fCirculant[i].resize(M);
309 fPermuationArray[i].resize(M);
316 M, &(fConjugateTwiddle[i][0]));
323 N, M, &(fTwiddle[i][0]), &(fScale[i][0]), &(fCirculant[i][0]));
// NOTE(review): debugging print to std::cout. Confirm whether it should remain.
327 std::cout <<
"build workspace" << std::endl;
// Builds the FFT kernel, sizes the device buffers, binds the kernel arguments, and uploads the tables
// that do not change between dimensions. Braces and several statements (including the buffer
// allocations and the definition of the command queue Q) are missing from this listing.
333 void ConstructOpenCLKernels()
// NOTE(review): debugging print to std::cout. Confirm whether it should remain.
335 std::cout <<
"opencl kernels" << std::endl;
// Kernel source path: a prefix not shown in this listing followed by the .cl file name.
337 std::stringstream clFile;
339 <<
"/MHO_MultidimensionalFastFourierTransform_kernel.cl";
// Build options; their content is assembled in lines not shown (presumably from fOpenCLFlags — confirm).
342 std::stringstream options;
// Compile the kernel function named "MultidimensionalFastFourierTransform_Stage" from that file.
345 MHO_OpenCLKernelBuilder k_builder;
347 k_builder.BuildKernel(clFile.str(), std::string(
"MultidimensionalFastFourierTransform_Stage"), options.str());
// Query the preferred work-group size multiple and lower fNLocal to it when it is smaller.
// NOTE(review): the constructor sets fNLocal to 0, so it must be assigned in the lines missing before
// this comparison for the clamp to have any effect. Confirm.
355 fPreferredWorkgroupMultiple = fFFTKernel->getWorkGroupInfo< CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE >(
358 if(fPreferredWorkgroupMultiple < fNLocal)
360 fNLocal = fPreferredWorkgroupMultiple;
// NOTE(review): debugging print to std::cout. Confirm whether it should remain.
369 std::cout <<
"building buffers" << std::endl;
// Byte sizes of two of the device buffers: rank::value unsigned ints, and fMaxBufferSize unsigned ints.
// The cl::Buffer constructions these belong to are not visible.
372 XArgType::rank::value *
sizeof(
unsigned int));
396 fMaxBufferSize *
sizeof(
unsigned int));
// Repeat the work-size computation used at execution time for every D and keep the largest n_global
// in fMaxNWorkItems. Listing lines 406-408 are missing, so a condition on i may apply to the products.
400 for(
unsigned int D = 0; D < XArgType::rank::value; D++)
403 unsigned int n_global = fDimensionSize[0];
404 unsigned int n_local_1d_transforms = 1;
405 for(
unsigned int i = 0; i < XArgType::rank::value - 1; i++)
409 n_global *= fSpatialDim[i];
410 n_local_1d_transforms *= fSpatialDim[i];
// Distance from n_global to the next multiple of fNLocal (equal to fNLocal when already a multiple).
415 unsigned int nDummy = fNLocal - (n_global % fNLocal);
416 if(nDummy == fNLocal)
// NOTE(review): fMaxNWorkItems is not in the visible constructor initializer list. Confirm it is
// zeroed before this running-maximum comparison.
422 if(fMaxNWorkItems < n_global)
424 fMaxNWorkItems = n_global;
// Bind the kernel arguments. 0: fDimensionSize[0]; 1 and 2: placeholders that the execution code
// overwrites (it sets argument 1 to D and argument 2 to 0 or 1); 3: spatial dimension buffer.
436 unsigned int n_multdim_ffts = fDimensionSize[0];
438 fFFTKernel->setArg(0, n_multdim_ffts);
441 fFFTKernel->setArg(1, 0);
444 fFFTKernel->setArg(2, 0);
447 fFFTKernel->setArg(3, *fSpatialDimensionBufferCL);
// Blocking upload of rank::value - 1 unsigned ints into the spatial dimension buffer; the host
// pointer argument is on a line not shown.
448 Q.enqueueWriteBuffer(*fSpatialDimensionBufferCL, CL_TRUE, 0, (XArgType::rank::value - 1) *
sizeof(
unsigned int),
450 #ifdef ENFORCE_CL_FINISH
// Arguments 4-8: twiddle, conjugate twiddle, scale, circulant and permutation buffers.
457 fFFTKernel->setArg(4, *fTwiddleBufferCL);
458 fFFTKernel->setArg(5, *fConjugateTwiddleBufferCL);
459 fFFTKernel->setArg(6, *fScaleBufferCL);
460 fFFTKernel->setArg(7, *fCirculantBufferCL);
461 fFFTKernel->setArg(8, *fPermuationArrayCL);
// When every spatial dimension has the same size, the index-0 tables are uploaded once here with
// blocking writes of fMaxBufferSize elements each.
// NOTE(review): ConstructWorkspace resizes these vectors to M. Confirm M is never smaller than
// fMaxBufferSize, otherwise the uploads read past the end of the host vectors.
463 if(fAllSpatialDimensionsAreEqual)
466 Q.enqueueWriteBuffer(*fTwiddleBufferCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
CL_TYPE2), &(fTwiddle[0][0]));
467 #ifdef ENFORCE_CL_FINISH
470 Q.enqueueWriteBuffer(*fConjugateTwiddleBufferCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
CL_TYPE2),
471 &(fConjugateTwiddle[0][0]));
472 #ifdef ENFORCE_CL_FINISH
475 Q.enqueueWriteBuffer(*fScaleBufferCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
CL_TYPE2), &(fScale[0][0]));
476 #ifdef ENFORCE_CL_FINISH
479 Q.enqueueWriteBuffer(*fCirculantBufferCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
CL_TYPE2), &(fCirculant[0][0]));
480 #ifdef ENFORCE_CL_FINISH
483 Q.enqueueWriteBuffer(*fPermuationArrayCL, CL_TRUE, 0, fMaxBufferSize *
sizeof(
unsigned int),
484 &(fPermuationArray[0][0]));
485 #ifdef ENFORCE_CL_FINISH
// Argument 9: the data buffer holding the array being transformed.
491 fFFTKernel->setArg(9, *fDataBufferCL);
// Copies the input array from host memory into the device data buffer, but only when
// fFillFromHostData is set. Braces and the definition of the command queue Q are missing from this
// listing.
496 void FillDataBuffer()
498 if(fFillFromHostData)
// The input's element storage is reinterpreted as CL_TYPE2 through a C-style cast.
// NOTE(review): this assumes the array's element type has the same layout as CL_TYPE2. Confirm, and
// prefer a named cast so the reinterpretation is explicit.
502 auto* ptr = (
CL_TYPE2*)(&((this->fInput->GetData())[0]));
// NOTE(review): debugging print to std::cout. Confirm whether it should remain.
503 std::cout <<
"total data size = " << fTotalDataSize << std::endl;
// Blocking (CL_TRUE) write of fTotalDataSize elements starting at offset 0.
504 Q.enqueueWriteBuffer(*fDataBufferCL, CL_TRUE, 0, fTotalDataSize *
sizeof(
CL_TYPE2), ptr);
505 #ifdef ENFORCE_CL_FINISH
// Copies the device data buffer back to host memory, but only when fReadOutDataToHost is set.
// Braces and the definition of the command queue Q are missing from this listing.
515 void ReadOutDataBuffer()
517 if(fReadOutDataToHost)
// NOTE(review): the destination pointer is taken from fInput, not fOutput, so the result overwrites the
// input array. In the visible code fOutput is only used by the dimension check. Confirm whether the
// transform is meant to be in-place or whether this should read into fOutput.
521 auto* ptr = (
CL_TYPE2*)(&((this->fInput->GetData())[0]));
// Blocking (CL_TRUE) read of fTotalDataSize elements starting at offset 0.
522 Q.enqueueReadBuffer(*fDataBufferCL, CL_TRUE, 0, fTotalDataSize *
sizeof(
CL_TYPE2), ptr);
523 #ifdef ENFORCE_CL_FINISH
// Fetches the dimensions of the input and output arrays and walks over all rank::value entries.
// The loop body and the return statements are missing from this listing (presumably a per-dimension
// equality test — confirm).
// NOTE(review): fInput and fOutput are dereferenced without a visible null check. Confirm both have been
// set before this is called.
532 virtual bool DoInputOutputDimensionsMatch()
534 size_t in[XArgType::rank::value];
535 size_t out[XArgType::rank::value];
537 this->fInput->GetDimensions(in);
538 this->fOutput->GetDimensions(out);
540 for(
unsigned int i = 0; i < XArgType::rank::value; i++)
// Data members (the declarations of fIsValid, fIsForward, fInitialized, fInput and fOutput are not part
// of this listing).

// True when every spatial dimension has the same size (recomputed in ConstructWorkspace).
555 bool fAllSpatialDimensionsAreEqual;
// When true, FillDataBuffer uploads the input array to the device.
556 bool fFillFromHostData;
// When true, ReadOutDataBuffer downloads the device data buffer to the host.
557 bool fReadOutDataToHost;
// Dimensions of the input array as returned by GetDimensions.
558 size_t fDimensionSize[XArgType::rank::value];
// fSpatialDim[i] holds fDimensionSize[i + 1].
559 unsigned int fSpatialDim[XArgType::rank::value - 1];
// Largest n_global found while building the kernel.
560 unsigned int fMaxNWorkItems;
// Element count used for the table uploads and for FFT_BUFFERSIZE; at least the largest spatial dimension.
562 unsigned int fMaxBufferSize;
// Total number of elements in the array.
563 unsigned int fTotalDataSize;
// Host-side per-dimension tables. The arrays have rank::value slots; ConstructWorkspace fills the
// first rank::value - 1 of them in the visible code.
// NOTE(review): the elements are std::complex<double> and are uploaded as CL_TYPE2. Confirm the two types
// have the same size and layout.
567 std::vector< std::complex< double > > fTwiddle[XArgType::rank::value];
568 std::vector< std::complex< double > > fConjugateTwiddle[XArgType::rank::value];
569 std::vector< std::complex< double > > fScale[XArgType::rank::value];
570 std::vector< std::complex< double > > fCirculant[XArgType::rank::value];
571 std::vector< unsigned int > fPermuationArray[XArgType::rank::value];
// Preprocessor definitions (" -D FFT_NDIM=...", " -D FFT_BUFFERSIZE=...") assembled in ConstructWorkspace.
575 std::string fOpenCLFlags;
// The compiled FFT kernel.
577 mutable cl::Kernel* fFFTKernel;
// Device buffers, owned through raw pointers. Kernel argument 3:
580 cl::Buffer* fSpatialDimensionBufferCL;
// Kernel argument 4:
583 cl::Buffer* fTwiddleBufferCL;
// Kernel argument 5:
586 cl::Buffer* fConjugateTwiddleBufferCL;
// Kernel argument 6:
589 cl::Buffer* fScaleBufferCL;
// Kernel argument 7:
592 cl::Buffer* fCirculantBufferCL;
// Kernel argument 9; written by FillDataBuffer and read by ReadOutDataBuffer:
595 cl::Buffer* fDataBufferCL;
// Kernel argument 8:
598 cl::Buffer* fPermuationArrayCL;
// Not bound to a kernel argument in the visible lines.
601 cl::Buffer* fWorkspaceBufferCL;
// Local work-group size passed to cl::NDRange.
603 unsigned int fNLocal;
// Not referenced in the visible lines.
604 unsigned int fNGlobal;
// Value of CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE queried from the kernel.
605 unsigned int fPreferredWorkgroupMultiple;
#define CL_TYPE2
Definition: MHO_OpenCLInterface.hh:54
static void ComputeBitReversedIndicesBaseTwo(unsigned int N, unsigned int *index_arr)
Computes bit-reversed indices using Buneman algorithm for input N, must have N = 2^P,...
Definition: MHO_BitReversalPermutation.cc:119
static bool IsPowerOfTwo(unsigned int N)
Checks if an unsigned integer is a power of two.
Definition: MHO_BitReversalPermutation.cc:10
std::string GetKernelPath() const
Definition: MHO_OpenCLInterface.hh:133
cl::Device GetDevice() const
Definition: MHO_OpenCLInterface.hh:119
cl::CommandQueue & GetQueue(int i=-1) const
Definition: MHO_OpenCLInterface.cc:125
static MHO_OpenCLInterface * GetInstance()
Definition: MHO_OpenCLInterface.cc:32
Class MHO_Operator.
Definition: MHO_Operator.hh:21
Definition: MHO_ChannelLabeler.hh:17
Definition: MHO_Meta.hh:341