#include <gtest/gtest.h>
#include <string>
#include <vector>
#include <stdexcept>
#include <testHelpers.hpp>
using std::string;
using std::vector;
{
    vector<char>   in(100,1);
}
// Invalid-array case for fft2 built around a 100-element float buffer.
// NOTE(review): `in` is never wrapped in an array or passed to fft2, so this
// test currently performs no check.
TEST(fft2, Invalid_Array)
{
    // Skip when noDoubleTests<float>() says so.
    if (noDoubleTests<float>()) {
        return;
    }
    const vector<float> in(100, 1.0f);
}
// Invalid-array case for fft3 built around a 100-element float buffer.
// NOTE(review): `in` is never wrapped in an array or passed to fft3, so this
// test currently performs no check.
TEST(fft3, Invalid_Array)
{
    // Skip when noDoubleTests<float>() says so.
    if (noDoubleTests<float>()) {
        return;
    }
    const vector<float> in(100, 1.0f);
}
// Invalid-array case for ifft2 built around a 100-element float buffer.
// NOTE(review): `in` is never wrapped in an array or passed to ifft2, so this
// test currently performs no check.
TEST(ifft2, Invalid_Array)
{
    // Skip when noDoubleTests<float>() says so.
    if (noDoubleTests<float>()) {
        return;
    }
    const vector<float> in(100, 1.0f);
}
// Invalid-array case for ifft3 built around a 100-element float buffer.
// NOTE(review): `in` is never wrapped in an array or passed to ifft3, so this
// test currently performs no check.
TEST(ifft3, Invalid_Array)
{
    // Skip when noDoubleTests<float>() says so.
    if (noDoubleTests<float>()) {
        return;
    }
    const vector<float> in(100, 1.0f);
}
template<typename inType, typename outType, bool isInverse>
{
    if (noDoubleTests<inType>()) return;
    if (noDoubleTests<outType>()) return;
    vector<af::dim4>        numDims;
    vector<vector<inType> >       in;
    vector<vector<outType> >   tests;
    readTestsFromFile<inType, outType>(pTestFile, numDims, in, tests);
    if (isInverse){
        switch (dims.ndims()) {
            case 3 : ASSERT_EQ(
AF_SUCCESS, 
af_ifft3(&outArray, inArray, 1.0, pad0, pad1, pad2));  
break;
 
            default: throw std::runtime_error("This error shouldn't happen, pls check");
        }
    } else {
        switch(dims.ndims()) {
            case 3 : ASSERT_EQ(
AF_SUCCESS, 
af_fft3(&outArray, inArray, 1.0, pad0, pad1, pad2));   
break;
 
            default: throw std::runtime_error("This error shouldn't happen, pls check");
        }
    }
    size_t out_size = tests[0].size();
    outType *outData= new outType[out_size];
    vector<outType> goldBar(tests[0].begin(), tests[0].
end());
    size_t test_size = 0;
    switch(dims.ndims()) {
        case 1  : test_size = dims[0]/2+1;                       break;
        case 2  : test_size = dims[1] * (dims[0]/2+1);           break;
        case 3  : test_size = dims[2] * dims[1] * (dims[0]/2+1); break;
        default : test_size = dims[0]/2+1;                       break;
    }
    outType output_scale = (outType)(isInverse ? test_size : 1);
    for (size_t elIter=0; elIter<test_size; ++elIter) {
        bool isUnderTolerance = 
abs(goldBar[elIter]-outData[elIter])<0.001;
 
        ASSERT_EQ(true, isUnderTolerance)<<
            "Expected value="<<goldBar[elIter] <<"\t Actual Value="<<
            (output_scale*outData[elIter]) << " at: " << elIter<< std::endl;
    }
    
    delete[] outData;
}
// Registers gtest case `func.name`. The case runs
// fftTest<in_t, out_t, is_inverse>, forwarding the remaining macro arguments
// (test-file path, then optional output dimensions) verbatim.
#define INSTANTIATE_TEST(func, name, is_inverse, in_t, out_t, ...) \
    TEST(func, name) { fftTest<in_t, out_t, is_inverse>(__VA_ARGS__); }
// Forward real-to-complex transforms, 1-D to 3-D.
INSTANTIATE_TEST(fft,  R2C_Float,  false, float,  cfloat,  string(TEST_DIR"/signal/fft_r2c.test"));
INSTANTIATE_TEST(fft,  R2C_Double, false, double, cdouble, string(TEST_DIR"/signal/fft_r2c.test"));
INSTANTIATE_TEST(fft2, R2C_Float,  false, float,  cfloat,  string(TEST_DIR"/signal/fft2_r2c.test"));
INSTANTIATE_TEST(fft2, R2C_Double, false, double, cdouble, string(TEST_DIR"/signal/fft2_r2c.test"));
INSTANTIATE_TEST(fft3, R2C_Float,  false, float,  cfloat,  string(TEST_DIR"/signal/fft3_r2c.test"));
INSTANTIATE_TEST(fft3, R2C_Double, false, double, cdouble, string(TEST_DIR"/signal/fft3_r2c.test"));

// Forward complex-to-complex transforms, 1-D to 3-D.
INSTANTIATE_TEST(fft,  C2C_Float,  false, cfloat,  cfloat,  string(TEST_DIR"/signal/fft_c2c.test"));
INSTANTIATE_TEST(fft,  C2C_Double, false, cdouble, cdouble, string(TEST_DIR"/signal/fft_c2c.test"));
INSTANTIATE_TEST(fft2, C2C_Float,  false, cfloat,  cfloat,  string(TEST_DIR"/signal/fft2_c2c.test"));
INSTANTIATE_TEST(fft2, C2C_Double, false, cdouble, cdouble, string(TEST_DIR"/signal/fft2_c2c.test"));
INSTANTIATE_TEST(fft3, C2C_Float,  false, cfloat,  cfloat,  string(TEST_DIR"/signal/fft3_c2c.test"));
INSTANTIATE_TEST(fft3, C2C_Double, false, cdouble, cdouble, string(TEST_DIR"/signal/fft3_c2c.test"));

// 2-D forward transforms with explicit 16x16 output dimensions
// (truncation for R2C, padding for C2C, per the test-file names).
INSTANTIATE_TEST(fft2, R2C_Float_Trunc,  false, float,   cfloat,  string(TEST_DIR"/signal/fft2_r2c_trunc.test"), 16, 16);
INSTANTIATE_TEST(fft2, R2C_Double_Trunc, false, double,  cdouble, string(TEST_DIR"/signal/fft2_r2c_trunc.test"), 16, 16);
INSTANTIATE_TEST(fft2, C2C_Float_Pad,    false, cfloat,  cfloat,  string(TEST_DIR"/signal/fft2_c2c_pad.test"), 16, 16);
INSTANTIATE_TEST(fft2, C2C_Double_Pad,   false, cdouble, cdouble, string(TEST_DIR"/signal/fft2_c2c_pad.test"), 16, 16);

// Inverse complex-to-complex transforms, 1-D to 3-D.
INSTANTIATE_TEST(ifft,  C2C_Float,  true, cfloat,  cfloat,  string(TEST_DIR"/signal/ifft_c2c.test"));
INSTANTIATE_TEST(ifft,  C2C_Double, true, cdouble, cdouble, string(TEST_DIR"/signal/ifft_c2c.test"));
INSTANTIATE_TEST(ifft2, C2C_Float,  true, cfloat,  cfloat,  string(TEST_DIR"/signal/ifft2_c2c.test"));
INSTANTIATE_TEST(ifft2, C2C_Double, true, cdouble, cdouble, string(TEST_DIR"/signal/ifft2_c2c.test"));
INSTANTIATE_TEST(ifft3, C2C_Float,  true, cfloat,  cfloat,  string(TEST_DIR"/signal/ifft3_c2c.test"));
INSTANTIATE_TEST(ifft3, C2C_Double, true, cdouble, cdouble, string(TEST_DIR"/signal/ifft3_c2c.test"));
/**
 * Batched variant of the C-API FFT check. `rank` is the transform rank, and
 * dims[rank] is read as the number of batches. For each batch, the leading
 * test_size elements (offset by batchId * batch_stride) are compared with the
 * gold values read from pTestFile (absolute tolerance 0.001).
 *
 * NOTE(review): this body is truncated as shown and does not compile:
 *  - the `case` labels in both transform branches and in the test_size
 *    computation have no enclosing `switch (...) {` line;
 *  - `dims`, `inArray` and `outArray` are never declared here;
 *  - only rank-3 transforms are dispatched, although rank-1/2 batch cases are
 *    registered through INSTANTIATE_BATCH_TEST;
 *  - nothing copies the transform result into outData before it is read.
 */
template<typename inType, typename outType, int rank, bool isInverse>
void fftBatchTest(
string pTestFile, 
dim_t pad0=0, 
dim_t pad1=0, 
dim_t pad2=0)
 
{
    if (noDoubleTests<inType>()) return;
    if (noDoubleTests<outType>()) return;
    vector<af::dim4>        numDims;
    vector<vector<inType> >       in;
    vector<vector<outType> >   tests;
    readTestsFromFile<inType, outType>(pTestFile, numDims, in, tests);
    // NOTE(review): a `switch (...) {` header is missing before each group of
    // case labels below; the `}` after each `default:` closes that missing switch.
    if(isInverse) {
            case 3 : ASSERT_EQ(
AF_SUCCESS, 
af_ifft3(&outArray, inArray, 1.0, pad0, pad1, pad2));  
break;
 
            default: throw std::runtime_error("This error shouldn't happen, pls check");
        }
    } else {
            case 3 : ASSERT_EQ(
AF_SUCCESS, 
af_fft3(&outArray, inArray, 1.0, pad0, pad1, pad2));   
break;
 
            default: throw std::runtime_error("This error shouldn't happen, pls check");
        }
    }
    size_t out_size = tests[0].size();
    // NOTE(review): raw buffer, never filled from outArray in the code shown;
    // it is also leaked when a fatal ASSERT_EQ below returns early.
    outType *outData= new outType[out_size];
    vector<outType> goldBar(tests[0].begin(), tests[0].
end());
    size_t test_size = 0;
    // The dimension just past the transform rank holds the batch count.
    size_t batch_count = dims[
rank];
 
    // NOTE(review): the switch header selecting one of these cases is missing.
        case 1  : test_size = dims[0]/2+1;                       break;
        case 2  : test_size = dims[1] * (dims[0]/2+1);           break;
        case 3  : test_size = dims[2] * dims[1] * (dims[0]/2+1); break;
        default : test_size = dims[0]/2+1;                       break;
    }
    // Product of dims[0..rank-1]: the element count of one batch in `dims`.
    // NOTE(review): this is also used to offset into the output/gold buffers;
    // for R2C or padded/truncated transforms the per-batch output size may
    // differ. Confirm.
    size_t batch_stride = 1;
    for(
int i=0; i<
rank; ++i) batch_stride *= dims[i];
 
    outType output_scale = (outType)(isInverse ? test_size : 1);
    for(size_t batchId=0; batchId<batch_count; ++batchId) {
        size_t off = batchId*batch_stride;
        for (size_t elIter=0; elIter<test_size; ++elIter) {
            bool isUnderTolerance = 
abs(goldBar[elIter+off]-outData[elIter+off])<0.001;
 
            // NOTE(review): the check uses unscaled outData, but the message
            // prints output_scale*outData, so "Actual Value" is not the value
            // that was compared.
            ASSERT_EQ(true, isUnderTolerance)<<"Batch id = "<<batchId<<
                "; Expected value="<<goldBar[elIter+off] <<"\t Actual Value="<<
                (output_scale*outData[elIter+off]) << " at: " << elIter<< std::endl;
        }
    }
    
    delete[] outData;
}
// Registers gtest case `func.name_Batch` (the suffix is token-pasted onto
// `name`). The case runs fftBatchTest<in_t, out_t, rank, is_inverse>,
// forwarding the remaining macro arguments verbatim.
#define INSTANTIATE_BATCH_TEST(func, name, rank, is_inverse, in_t, out_t, ...) \
    TEST(func, name##_Batch) { fftBatchTest<in_t, out_t, rank, is_inverse>(__VA_ARGS__); }
// Batched forward real-to-complex transforms, rank 1 to 3.
INSTANTIATE_BATCH_TEST(fft,  R2C_Float, 1, false, float, cfloat, string(TEST_DIR"/signal/fft_r2c_batch.test"));
INSTANTIATE_BATCH_TEST(fft2, R2C_Float, 2, false, float, cfloat, string(TEST_DIR"/signal/fft2_r2c_batch.test"));
INSTANTIATE_BATCH_TEST(fft3, R2C_Float, 3, false, float, cfloat, string(TEST_DIR"/signal/fft3_r2c_batch.test"));

// Batched forward complex-to-complex transforms, rank 1 to 3.
INSTANTIATE_BATCH_TEST(fft,  C2C_Float, 1, false, cfloat, cfloat, string(TEST_DIR"/signal/fft_c2c_batch.test"));
INSTANTIATE_BATCH_TEST(fft2, C2C_Float, 2, false, cfloat, cfloat, string(TEST_DIR"/signal/fft2_c2c_batch.test"));
INSTANTIATE_BATCH_TEST(fft3, C2C_Float, 3, false, cfloat, cfloat, string(TEST_DIR"/signal/fft3_c2c_batch.test"));

// Batched inverse complex-to-complex transforms, rank 1 to 3.
INSTANTIATE_BATCH_TEST(ifft,  C2C_Float, 1, true, cfloat, cfloat, string(TEST_DIR"/signal/ifft_c2c_batch.test"));
INSTANTIATE_BATCH_TEST(ifft2, C2C_Float, 2, true, cfloat, cfloat, string(TEST_DIR"/signal/ifft2_c2c_batch.test"));
INSTANTIATE_BATCH_TEST(ifft3, C2C_Float, 3, true, cfloat, cfloat, string(TEST_DIR"/signal/ifft3_c2c_batch.test"));

// Batched 2-D forward transforms with explicit 16x16 output dimensions.
INSTANTIATE_BATCH_TEST(fft2, R2C_Float_Trunc,  2, false, float,   cfloat,  string(TEST_DIR"/signal/fft2_r2c_trunc_batch.test"), 16, 16);
INSTANTIATE_BATCH_TEST(fft2, R2C_Double_Trunc, 2, false, double,  cdouble, string(TEST_DIR"/signal/fft2_r2c_trunc_batch.test"), 16, 16);
INSTANTIATE_BATCH_TEST(fft2, C2C_Float_Pad,    2, false, cfloat,  cfloat,  string(TEST_DIR"/signal/fft2_c2c_pad_batch.test"), 16, 16);
INSTANTIATE_BATCH_TEST(fft2, C2C_Double_Pad,   2, false, cdouble, cdouble, string(TEST_DIR"/signal/fft2_c2c_pad_batch.test"), 16, 16);
/**
 * Runs one FFT/IFFT test case through the C++ API using data read from
 * pTestFile, and compares the leading test_size elements of the result with
 * the file's gold values (absolute tolerance 0.001).
 *
 * @tparam inType    element type of the input signal
 * @tparam outType   element type of the transform result (and gold values)
 * @tparam isInverse selects the inverse transform when true
 * @param  pTestFile test-data file parsed by readTestsFromFile
 * @param  pad0      first output-dimension argument (0 by default)
 * @param  pad1      second output-dimension argument (0 by default)
 * @param  pad2      third output-dimension argument (0 by default)
 */
template<typename inType, typename outType, bool isInverse>
void cppFFTTest(string pTestFile, dim_t pad0=0, dim_t pad1=0, dim_t pad2=0)
{
    if (noDoubleTests<inType>()) return;
    if (noDoubleTests<outType>()) return;
    vector<af::dim4>         numDims;
    vector<vector<inType> >  in;
    vector<vector<outType> > tests;
    readTestsFromFile<inType, outType>(pTestFile, numDims, in, tests);

    // NOTE(review): both branches are empty, and `dims`, `output` and
    // `outData` are used below without a visible declaration. The array
    // construction and the transform calls appear to be missing. Confirm
    // against history.
    if (isInverse){
    } else {
    }
    size_t out_size = tests[0].size();
    output.host((void*)outData);

    // Gold values are kept in the test's output type. This was hard-coded to
    // cfloat, which does not match outType for any non-cfloat instantiation.
    vector<outType> goldBar(tests[0].begin(), tests[0].end());

    size_t test_size = 0;
    switch(dims.ndims()) {
        case 1  : test_size = dims[0]/2+1;                       break;
        case 2  : test_size = dims[1] * (dims[0]/2+1);           break;
        case 3  : test_size = dims[2] * dims[1] * (dims[0]/2+1); break;
        default : test_size = dims[0]/2+1;                       break;
    }
    for (size_t elIter=0; elIter<test_size; ++elIter) {
        bool isUnderTolerance = abs(goldBar[elIter]-outData[elIter])<0.001;
        // Report the value that was actually checked. This previously printed
        // output_scale*outData, a scaling the comparison never applied.
        ASSERT_EQ(true, isUnderTolerance)<<
            "Expected value="<<goldBar[elIter] <<"\t Actual Value="<<
            outData[elIter] << " at: " << elIter<< std::endl;
    }

    // NOTE(review): a failing ASSERT_EQ returns early and skips this delete[].
    delete[] outData;
}
/**
 * C++-API transform check with the same comparison logic as cppFFTTest. The
 * name suggests it targets the DFT entry points, but no transform call is
 * present in the body as shown. Compares the leading test_size elements of
 * the result with the gold values read from pTestFile (absolute tolerance
 * 0.001).
 *
 * @tparam inType    element type of the input signal
 * @tparam outType   element type of the transform result (and gold values)
 * @tparam isInverse selects the inverse transform when true
 * @param  pTestFile test-data file parsed by readTestsFromFile
 * @param  pad0      first output-dimension argument (0 by default)
 * @param  pad1      second output-dimension argument (0 by default)
 * @param  pad2      third output-dimension argument (0 by default)
 */
template<typename inType, typename outType, bool isInverse>
void cppDFTTest(string pTestFile, dim_t pad0=0, dim_t pad1=0, dim_t pad2=0)
{
    if (noDoubleTests<inType>()) return;
    if (noDoubleTests<outType>()) return;
    vector<af::dim4>         numDims;
    vector<vector<inType> >  in;
    vector<vector<outType> > tests;
    readTestsFromFile<inType, outType>(pTestFile, numDims, in, tests);

    // NOTE(review): both branches are empty, and `dims`, `output` and
    // `outData` are used below without a visible declaration. The array
    // construction and the transform calls appear to be missing. Confirm
    // against history.
    if (isInverse){
    } else {
    }
    size_t out_size = tests[0].size();
    output.host((void*)outData);

    // Gold values are kept in the test's output type. This was hard-coded to
    // cfloat, which does not match outType for any non-cfloat instantiation.
    vector<outType> goldBar(tests[0].begin(), tests[0].end());

    size_t test_size = 0;
    switch(dims.ndims()) {
        case 1  : test_size = dims[0]/2+1;                       break;
        case 2  : test_size = dims[1] * (dims[0]/2+1);           break;
        case 3  : test_size = dims[2] * dims[1] * (dims[0]/2+1); break;
        default : test_size = dims[0]/2+1;                       break;
    }
    for (size_t elIter=0; elIter<test_size; ++elIter) {
        bool isUnderTolerance = abs(goldBar[elIter]-outData[elIter])<0.001;
        // Report the value that was actually checked. This previously printed
        // output_scale*outData, a scaling the comparison never applied.
        ASSERT_EQ(true, isUnderTolerance)<<
            "Expected value="<<goldBar[elIter] <<"\t Actual Value="<<
            outData[elIter] << " at: " << elIter<< std::endl;
    }

    // NOTE(review): a failing ASSERT_EQ returns early and skips this delete[].
    delete[] outData;
}
{
    cppFFTTest<cfloat, cfloat, false>(string(TEST_DIR"/signal/fft3_c2c.test"));
}
{
    cppFFTTest<cfloat, cfloat, true>(string(TEST_DIR"/signal/ifft3_c2c.test"));
}
{
    // NOTE(review): orphaned test body. Its TEST(...) header and the setup that
    // declares aDims, cDims, gold and out are missing, so as shown this is a
    // compound statement at namespace scope.
    //
    // Dense strides for the layout described by aDims (gold is indexed with these).
    af::dim4 aStrides(1, aDims[0], aDims[0]*aDims[1], aDims[0]*aDims[1]*aDims[2]);
 
    // Dense strides for the layout described by cDims (out is indexed with these).
    af::dim4 cStrides(1, cDims[0], cDims[0]*cDims[1], cDims[0]*cDims[1]*cDims[2]);
 
    // Walks the first three dimensions of aDims only; dimension 3 is not visited.
    for (int k=0; k<(int)aDims[2]; ++k) {
        int gkOff = k*aStrides[2];
        int okOff = k*cStrides[2];
        for (int j=0; j<(int)aDims[1]; ++j) {
            int gjOff = j*aStrides[1];
            int ojOff = j*cStrides[1];
            for (int i=0; i<(int)aDims[0]; ++i) {
                int giOff = i*aStrides[0];
                int oiOff = i*cStrides[0];
                int gi = gkOff + gjOff + giOff;
                int oi = okOff + ojOff + oiOff;
                // Reads out[2*oi], i.e. every other entry of `out`. Presumably
                // these are the real parts of interleaved complex values; confirm.
                bool isUnderTolerance = 
std::abs(gold[gi]-out[2*oi])<0.001;
 
                ASSERT_EQ(true, isUnderTolerance)<< "Expected value="<<
                    gold[gi] <<"\t Actual Value="<< out[2*oi] << " at: " <<gi<< std::endl;
            }
        }
    }
    // NOTE(review): a failing ASSERT_EQ returns early and skips these delete[]s.
    delete[] gold;
    delete[] out;
}
{
    cppDFTTest<cfloat, cfloat, false>(string(TEST_DIR"/signal/fft_c2c.test"));
}
{
    cppDFTTest<cfloat, cfloat, true>(string(TEST_DIR"/signal/ifft_c2c.test"));
}
// Forward 2-D complex-to-complex case, checked against fft2_c2c.test.
TEST(dft2, CPP)
{
    const string goldFile(TEST_DIR"/signal/fft2_c2c.test");
    cppDFTTest<cfloat, cfloat, false>(goldFile);
}
// Inverse 2-D complex-to-complex case, checked against ifft2_c2c.test.
TEST(idft2, CPP)
{
    const string goldFile(TEST_DIR"/signal/ifft2_c2c.test");
    cppDFTTest<cfloat, cfloat, true>(goldFile);
}
// Forward 3-D complex-to-complex case, checked against fft3_c2c.test.
TEST(dft3, CPP)
{
    const string goldFile(TEST_DIR"/signal/fft3_c2c.test");
    cppDFTTest<cfloat, cfloat, false>(goldFile);
}
// Inverse 3-D complex-to-complex case, checked against ifft3_c2c.test.
TEST(idft3, CPP)
{
    const string goldFile(TEST_DIR"/signal/ifft3_c2c.test");
    cppDFTTest<cfloat, cfloat, true>(goldFile);
}
{
    // NOTE(review): orphaned test body. Its TEST(...) header and the setup
    // declaring `a`, `h_b` and `h_B` are missing.
    // Compares h_b and h_B element by element over a.elements() entries.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(h_b[i], h_B[i]) << "at: " << i << std::endl;
    }
    // NOTE(review): a failing ASSERT_EQ returns early and skips these delete[]s.
    delete[] h_b;
    delete[] h_B;
}
{
    // NOTE(review): orphaned test body. Its TEST(...) header and the setup
    // declaring `a`, `h_b` and `h_B` are missing.
    // Compares h_b and h_B element by element over a.elements() entries.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(h_b[i], h_B[i]) << "at: " << i << std::endl;
    }
    // NOTE(review): a failing ASSERT_EQ returns early and skips these delete[]s.
    delete[] h_b;
    delete[] h_B;
}
{
    // NOTE(review): orphaned, truncated test body. The `}` on the next line
    // closes this brace immediately, which leaves the loop below at namespace
    // scope and the final `}` unmatched. The TEST(...) header and the setup
    // declaring `a`, `h_b` and `h_c` appear to be missing.
    }
    // Compares h_b and h_c element by element over a.elements() entries.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(h_b[i], h_c[i]) << "at: " << i << std::endl;
    }
    // NOTE(review): a failing ASSERT_EQ returns early and skips these delete[]s.
    delete[] h_b;
    delete[] h_c;
}
{
    // NOTE(review): orphaned, truncated test body. The `}` on the next line
    // closes this brace immediately, which leaves the loop below at namespace
    // scope and the final `}` unmatched. The TEST(...) header and the setup
    // declaring `a`, `h_b` and `h_c` appear to be missing.
    }
    // Compares h_b and h_c element by element over a.elements() entries.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(h_b[i], h_c[i]) << "at: " << i << std::endl;
    }
    // NOTE(review): a failing ASSERT_EQ returns early and skips these delete[]s.
    delete[] h_b;
    delete[] h_c;
}
{
    // NOTE(review): orphaned, truncated test body. The `}` on the next line
    // closes this brace immediately, which leaves the loop below at namespace
    // scope and the final `}` unmatched. The TEST(...) header and the setup
    // declaring `a`, `h_b` and `h_c` appear to be missing.
    }
    // Compares h_b and h_c element by element over a.elements() entries.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(h_b[i], h_c[i]) << "at: " << i << std::endl;
    }
    // NOTE(review): a failing ASSERT_EQ returns early and skips these delete[]s.
    delete[] h_b;
    delete[] h_c;
}
{
    // NOTE(review): orphaned test body; its TEST(...) header and the setup
    // declaring `a` and `b` are missing. ha and hb are sized here but never
    // filled from `a`/`b`, so the loop compares value-initialized elements
    // rather than array data; the host copies appear to be missing.
    std::vector<af::cfloat> ha(a.
elements());
    std::vector<af::cfloat> hb(b.
elements());
    // Indexes hb with a's element count; assumes b has at least as many elements.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(ha[i], hb[i]);
    }
}
{
    // NOTE(review): orphaned test body; its TEST(...) header and the setup
    // declaring `a` and `b` are missing. ha and hb are sized here but never
    // filled from `a`/`b`, so the loop compares value-initialized elements
    // rather than array data; the host copies appear to be missing.
    std::vector<af::cfloat> ha(a.
elements());
    std::vector<af::cfloat> hb(b.
elements());
    // Indexes hb with a's element count; assumes b has at least as many elements.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(ha[i], hb[i]);
    }
}
{
    // NOTE(review): orphaned test body; its TEST(...) header and the setup
    // declaring `a` and `b` are missing. ha and hb are sized here but never
    // filled from `a`/`b`, so the loop compares value-initialized elements
    // rather than array data; the host copies appear to be missing.
    std::vector<af::cfloat> ha(a.
elements());
    std::vector<af::cfloat> hb(b.
elements());
    // Indexes hb with a's element count; assumes b has at least as many elements.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(ha[i], hb[i]);
    }
}
{
    // NOTE(review): orphaned test body; its TEST(...) header and the setup
    // declaring `a` and `b` are missing. ha and hb are sized here but never
    // filled from `a`/`b`, so the loop compares value-initialized elements
    // rather than array data; the host copies appear to be missing.
    std::vector<af::cfloat> ha(a.
elements());
    std::vector<af::cfloat> hb(b.
elements());
    // Indexes hb with a's element count; assumes b has at least as many elements.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(ha[i], hb[i]);
    }
}
{
    // NOTE(review): orphaned test body; its TEST(...) header and the setup
    // declaring `a` and `b` are missing. ha and hb are sized here but never
    // filled from `a`/`b`, so the loop compares value-initialized elements
    // rather than array data; the host copies appear to be missing.
    std::vector<af::cfloat> ha(a.
elements());
    std::vector<af::cfloat> hb(b.
elements());
    // Indexes hb with a's element count; assumes b has at least as many elements.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(ha[i], hb[i]);
    }
}
{
    // NOTE(review): orphaned test body; its TEST(...) header and the setup
    // declaring `a` and `b` are missing. ha and hb are sized here but never
    // filled from `a`/`b`, so the loop compares value-initialized elements
    // rather than array data; the host copies appear to be missing.
    std::vector<af::cfloat> ha(a.
elements());
    std::vector<af::cfloat> hb(b.
elements());
    // Indexes hb with a's element count; assumes b has at least as many elements.
    for (
int i = 0; i < (int)a.
elements(); i++) {
 
        ASSERT_EQ(ha[i], hb[i]);
    }
}