#include <gtest/gtest.h>
#include <testHelpers.hpp>
#include <string>
using std::string;
using std::cout;
using std::endl;
using std::ostream_iterator;
using std::vector;
// Stateless typed-test fixture shared by every MatrixMultiply test below.
// The element type T comes from TestTypes via TYPED_TEST_CASE.
template<typename T>
struct MatrixMultiply : ::testing::Test {};

// Element types each MatrixMultiply typed test is instantiated with:
// float, af::cfloat, double and af::cdouble.
typedef ::testing::Types<float, af::cfloat, double, af::cdouble> TestTypes;
TYPED_TEST_CASE(MatrixMultiply, TestTypes);
// Runs the matrix-multiply reference checks in TestFile through the C API,
// where results are held as af_array handles (see `out` below).
//
// readTests<T,T,int> fills numDims, hData and tests from TestFile. Each
// tests[i] is compared element-wise against the i-th output. isBVector picks
// which of the two branches below is meant to produce those outputs.
//
// NOTE(review): as it appears here this body cannot compile or exercise any
// multiplication. `atdims`, `btdims`, `f` and `elems` are not declared in this
// function (nor anywhere else visible in this file). Both isBVector branches
// are empty. h_out is never filled from out[i]. The final release loop is
// empty. The declarations and af_* calls seem to be missing — restore them
// before relying on this test.
template<typename T, bool isBVector>
void MatMulCheck(string TestFile)
{
    // Bail out when noDoubleTests<T>() is true (presumably: T is a
    // double-precision type the current device cannot handle — confirm).
    if (noDoubleTests<T>()) return;
    using std::vector;
    vector<af::dim4> numDims;
    vector<vector<T> > hData;
    vector<vector<T> > tests;
    readTests<T,T,int>(TestFile, numDims, hData, tests);
    // Tail of a swap of dimensions 0 and 1 of atdims: `f` would have to hold
    // the old atdims[0]. NOTE(review): atdims and f are undeclared here.
    {
        atdims[0]   =    atdims[1];
        atdims[1]   =    f;
    }
    // Same 0<->1 dimension swap for btdims. NOTE(review): btdims and f are
    // undeclared here.
    {
        btdims[0] = btdims[1];
        btdims[1] = f;
    }
    // One output handle per expected result, all initialised to 0.
    vector<af_array> out(tests.size(), 0);
    // NOTE(review): both branches are empty, so every out[i] stays 0.
    if(isBVector) {
    }
    else {
    }
    for(size_t i = 0; i < tests.size(); i++) {
        // NOTE(review): `elems` is undeclared, and h_out is never filled from
        // out[i], so it only ever holds value-initialised T.
        vector<T> h_out(elems);
        // This std::equal overload reads h_out.size() elements of tests[i]
        // without checking its length; a shorter tests[i] is read past its end.
        if( false == equal(h_out.begin(), h_out.end(), tests[i].begin()) ) {
            cout << "Failed test " << i << "\nCalculated: " << endl;
            copy(h_out.begin(), h_out.end(), ostream_iterator<T>(cout, 
", "));
 
            cout << "Expected: " << endl;
            copy(tests[i].begin(), tests[i].
end(), ostream_iterator<T>(cout, 
", "));
 
            // FAIL() records a fatal failure and returns from this function,
            // so later tests[i] entries are not checked.
            FAIL();
        }
    }
    // NOTE(review): empty loop — presumably meant to release each out[i].
    for (size_t i = 0; i <  out.size(); i++) {
    }
}
// C-API matmul tests: one typed test per reference data file under
// TEST_DIR/blas. MatMulCheck's second template argument (isBVector) is true
// only for the *Vector data sets.
TYPED_TEST(MatrixMultiply, Square)
{
    const string testFile = string(TEST_DIR) + "/blas/Basic.test";
    MatMulCheck<TypeParam, false>(testFile);
}

TYPED_TEST(MatrixMultiply, NonSquare)
{
    const string testFile = string(TEST_DIR) + "/blas/NonSquare.test";
    MatMulCheck<TypeParam, false>(testFile);
}

TYPED_TEST(MatrixMultiply, SquareVector)
{
    const string testFile = string(TEST_DIR) + "/blas/SquareVector.test";
    MatMulCheck<TypeParam, true>(testFile);
}

TYPED_TEST(MatrixMultiply, RectangleVector)
{
    const string testFile = string(TEST_DIR) + "/blas/RectangleVector.test";
    MatMulCheck<TypeParam, true>(testFile);
}
// C++-API counterpart of MatMulCheck. Results are af::array objects whose data
// is read back into a host vector via out[i].host(...) before being compared
// with the expected values in `tests`.
//
// NOTE(review): as in MatMulCheck, `atdims`, `btdims` and `f` are not declared
// in this function (nor anywhere else visible in this file). Both isBVector
// branches are empty, so `out` only holds default-constructed af::array
// objects. If those report 0 elements (presumably they do — confirm), then
// h_out is empty and `&h_out.front()` below is undefined behaviour. The
// comparison over an empty range would also pass vacuously. Restore the
// missing setup and matmul calls.
template<typename T, bool isBVector>
void cppMatMulCheck(string TestFile)
{
    // Bail out when noDoubleTests<T>() is true (presumably: T is a
    // double-precision type the current device cannot handle — confirm).
    if (noDoubleTests<T>()) return;
    using std::vector;
    vector<af::dim4> numDims;
    vector<vector<T> > hData;
    vector<vector<T> > tests;
    readTests<T,T,int>(TestFile, numDims, hData, tests);
    // Tail of a swap of dimensions 0 and 1 of atdims: `f` would have to hold
    // the old atdims[0]. NOTE(review): atdims and f are undeclared here.
    {
        atdims[0]   =    atdims[1];
        atdims[1]   =    f;
    }
    // Same 0<->1 dimension swap for btdims. NOTE(review): btdims and f are
    // undeclared here.
    {
        btdims[0] = btdims[1];
        btdims[1] = f;
    }
    // One default-constructed result array per expected result.
    vector<af::array> out(tests.size());
    // NOTE(review): both branches are empty, so no result is ever computed.
    if(isBVector) {
    }
    else {
    }
    for(size_t i = 0; i < tests.size(); i++) {
        // Size a host buffer to the i-th result and copy its data into it.
        dim_t elems = out[i].elements();
 
        vector<T> h_out(elems);
        out[i].host((void*)&h_out.front());
        // This std::equal overload reads h_out.size() elements of tests[i]
        // without checking its length; a shorter tests[i] is read past its end.
        if (false == equal(h_out.begin(), h_out.end(), tests[i].begin())) {
            cout << "Failed test " << i << "\nCalculated: " << endl;
            copy(h_out.begin(), h_out.end(), ostream_iterator<T>(cout, 
", "));
 
            cout << "Expected: " << endl;
            copy(tests[i].begin(), tests[i].
end(), ostream_iterator<T>(cout, 
", "));
 
            // FAIL() records a fatal failure and returns from this function,
            // so later tests[i] entries are not checked.
            FAIL();
        }
    }
}
// C++-API matmul tests, mirroring the C-API ones: one typed test per
// reference data file under TEST_DIR/blas, with isBVector set only for the
// *Vector data sets.
TYPED_TEST(MatrixMultiply, Square_CPP)
{
    const string testFile = string(TEST_DIR) + "/blas/Basic.test";
    cppMatMulCheck<TypeParam, false>(testFile);
}

TYPED_TEST(MatrixMultiply, NonSquare_CPP)
{
    const string testFile = string(TEST_DIR) + "/blas/NonSquare.test";
    cppMatMulCheck<TypeParam, false>(testFile);
}

TYPED_TEST(MatrixMultiply, SquareVector_CPP)
{
    const string testFile = string(TEST_DIR) + "/blas/SquareVector.test";
    cppMatMulCheck<TypeParam, true>(testFile);
}

TYPED_TEST(MatrixMultiply, RectangleVector_CPP)
{
    const string testFile = string(TEST_DIR) + "/blas/RectangleVector.test";
    cppMatMulCheck<TypeParam, true>(testFile);
}
// DEVICE_ITERATE(func): evaluate `func` once per device, or just once.
//
// If the environment variable AF_MULTI_GPU_TESTS is set and its first
// character is '0', `func` is evaluated a single time without changing the
// active device. Otherwise af::setDevice(d) is called and `func` evaluated
// for every d in [0, af::getDeviceCount()). After the loop, the device that
// was active beforehand is re-selected. It is not re-selected if `func`
// throws, since nothing here catches.
//
// `func` is pasted verbatim, so wrap it in an extra pair of parentheses when
// it contains a top-level comma (e.g. a template argument list).
//
// The expansion is a single `do { ... } while (0)` statement WITHOUT a
// trailing semicolon; the call site supplies it. `DEVICE_ITERATE(f());` is
// therefore exactly one statement and is safe as an unbraced if/else body.
// Internal locals carry a DEVICE_ITERATE_ prefix so they cannot shadow names
// (such as `i`) that `func` itself refers to.
#define DEVICE_ITERATE(func) do {                                           \
    const char* DEVICE_ITERATE_env = getenv("AF_MULTI_GPU_TESTS");          \
    if (DEVICE_ITERATE_env && DEVICE_ITERATE_env[0] == '0') {               \
        func;                                                               \
    } else {                                                                \
        int DEVICE_ITERATE_oldDevice = af::getDevice();                     \
        for (int DEVICE_ITERATE_dev = 0;                                    \
             DEVICE_ITERATE_dev < af::getDeviceCount();                     \
             DEVICE_ITERATE_dev++) {                                        \
            af::setDevice(DEVICE_ITERATE_dev);                              \
            func;                                                           \
        }                                                                   \
        af::setDevice(DEVICE_ITERATE_oldDevice);                            \
    }                                                                       \
} while (0)
// C++-API checks repeated through DEVICE_ITERATE: once per device, or only on
// the current device when AF_MULTI_GPU_TESTS starts with '0'. The doubled
// parentheses keep the comma in the template argument list from being read
// as a macro-argument separator.
TYPED_TEST(MatrixMultiply, MultiGPUSquare_CPP)
{
    DEVICE_ITERATE((cppMatMulCheck<TypeParam, false>(string(TEST_DIR) + "/blas/Basic.test")));
}

TYPED_TEST(MatrixMultiply, MultiGPUNonSquare_CPP)
{
    DEVICE_ITERATE((cppMatMulCheck<TypeParam, false>(string(TEST_DIR) + "/blas/NonSquare.test")));
}

TYPED_TEST(MatrixMultiply, MultiGPUSquareVector_CPP)
{
    DEVICE_ITERATE((cppMatMulCheck<TypeParam, true>(string(TEST_DIR) + "/blas/SquareVector.test")));
}

TYPED_TEST(MatrixMultiply, MultiGPURectangleVector_CPP)
{
    DEVICE_ITERATE((cppMatMulCheck<TypeParam, true>(string(TEST_DIR) + "/blas/RectangleVector.test")));
}

// DEVICE_ITERATE is only meant for the tests above; keep it from leaking.
#undef DEVICE_ITERATE