/*******************************************************************************
* Copyright 2018-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
 *
 *  Content:
 *            oneapi::mkl::vm::add example program text (SYCL interface)
 *
 *******************************************************************************/

#include <CL/sycl.hpp>
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl/types.hpp"
#include "oneapi/mkl/vm.hpp"

#include "common_for_examples.hpp"

#include "vml_common.hpp"

// Number of elements handed to every vm::add call in this example; the
// reference data table inside vAddAccuracyLiteTest holds exactly this many
// entries per precision.
static constexpr int VLEN{4};

// Upper ulp-error bounds, keyed by accuracy mode (HA / LA / EP) and precision
// suffix (S = float, D = double, C = std::complex<float>,
// Z = std::complex<double>, matching the vsAdd / vdAdd / vcAdd / vzAdd naming
// used in the reference table). Passed to check_result() — presumably the
// pass/fail thresholds; see vml_common.hpp to confirm.
static ulp_table_type ulp_table = {
    // single precision real
    {MAX_HA_ULP_S, 4.5},
    {MAX_LA_ULP_S, 5.0},
    {MAX_EP_ULP_S, 5.0E3},
    // double precision real
    {MAX_HA_ULP_D, 2.0},
    {MAX_LA_ULP_D, 5.0},
    {MAX_EP_ULP_D, 7.0E7},
    // single precision complex
    {MAX_HA_ULP_C, 2.0},
    {MAX_LA_ULP_C, 5.0},
    {MAX_EP_ULP_C, 5.0E3},
    // double precision complex
    {MAX_HA_ULP_Z, 2.0},
    {MAX_LA_ULP_Z, 5.0},
    {MAX_EP_ULP_Z, 7.0E7},
};

//!
//! @brief Accuracy test for oneapi::mkl::vm::add.
//!
//! Loads VLEN (argument, argument, expected result) triples from the reference
//! table below, then for every VM accuracy mode (accuracy_mode[0] ..
//! accuracy_mode[ACCURACY_NUM - 1]) runs vm::add twice:
//!   - the contiguous API over all VLEN elements, and
//!   - the strided (oneapi::mkl::slice) API over the odd-indexed elements only,
//! and compares both outputs against the references with check_result() using
//! the file-level ulp_table.
//!
//! @tparam A   argument element type
//! @tparam R   result element type
//! @param  dev SYCL device to run the test on
//! @return true if every check passed; false if any check failed or a SYCL
//!         exception was thrown by main_queue.wait_and_throw()
//!
template <typename A, typename R>
bool vAddAccuracyLiteTest(const sycl::device &dev) {

    int argtype = ARG2_RES1;
    // *************************************************************
    // Data table declaration
    // *************************************************************
data_3_t data
{
    .i = 0
,

    .data_f32 = std::vector<data_3_f32_t>

    {

{ UINT32_C(0x40D9B85C), UINT32_C(0xC007309A), UINT32_C(0x4096200F) }, //  0: vsAdd ( 6.80375481     , -2.1123414      ) = ( 4.6914134       );
{ UINT32_C(0x40B52EFA), UINT32_C(0x40BF006A), UINT32_C(0x413A17B2) }, //  1: vsAdd ( 5.66198444     , 5.96880054      ) = ( 11.630785       );
{ UINT32_C(0x4103BA28), UINT32_C(0xC0C1912F), UINT32_C(0x400BC642) }, //  2: vsAdd ( 8.2329483      , -6.04897261     ) = ( 2.1839757       );
{ UINT32_C(0xC052EA36), UINT32_C(0x40ABAABC), UINT32_C(0x40046B42) }, //  3: vsAdd ( -3.2955451     , 5.3645916       ) = ( 2.0690465       );
 }

,
    .data_f64 = std::vector<data_3_f64_t>

    {

{ UINT64_C(0x401B370B60E66E18), UINT64_C(0xC000E6134801CC26), UINT64_C(0x4012C401BCE58805) }, //  0: vdAdd ( 6.80375434309419092      , -2.11234146361813924      ) = ( 4.69141287947605168       );
{ UINT64_C(0x4016A5DF421D4BBE), UINT64_C(0x4017E00D485FC01A), UINT64_C(0x402742F6453E85EC) }, //  1: vdAdd ( 5.66198447517211711      , 5.96880066952146571       ) = ( 11.6307851446935828       );
{ UINT64_C(0x40207744D998EE8A), UINT64_C(0xC0183225E080644C), UINT64_C(0x400178C7A562F190) }, //  2: vdAdd ( 8.23294715873568705      , -6.04897261413232101      ) = ( 2.18397454460336604       );
{ UINT64_C(0xC00A5D46A314BA8E), UINT64_C(0x4015755793FAEAB0), UINT64_C(0x40008D6884E11AD2) }, //  3: vdAdd ( -3.2955448857022196      , 5.36459189623808186       ) = ( 2.06904701053586226       );
 }

,

    .data_c32 = std::vector <data_3_c32_t>

    {

{ { UINT32_C(0xC007309A), UINT32_C(0x40D9B85C) }, { UINT32_C(0x40BF006A), UINT32_C(0x40B52EFA) }, { UINT32_C(0x4076D03A), UINT32_C(0x414773AB) } }, //  0: vcAdd ( -2.1123414      + i * 6.80375481     , 5.96880054      + i * 5.66198444      ) = ( 3.85645914      + i * 12.4657393      );
{ { UINT32_C(0xC0C1912F), UINT32_C(0x4103BA28) }, { UINT32_C(0x40ABAABC), UINT32_C(0xC052EA36) }, { UINT32_C(0xBF2F3398), UINT32_C(0x409DFF35) } }, //  1: vcAdd ( -6.04897261     + i * 8.2329483      , 5.3645916       + i * -3.2955451      ) = ( -0.684381008    + i * 4.9374032       );
{ { UINT32_C(0x3F8A29C0), UINT32_C(0xC08E3964) }, { UINT32_C(0x4024F46C), UINT32_C(0xBEE77440) }, { UINT32_C(0x406A094C), UINT32_C(0xC09CB0A8) } }, //  2: vcAdd ( 1.07939911      + i * -4.44450569    , 2.57741833      + i * -0.452058792    ) = ( 3.65681744      + i * -4.89656448     );
{ { UINT32_C(0x3E8939C0), UINT32_C(0xC02D136C) }, { UINT32_C(0x41052EB4), UINT32_C(0x4110B6A8) }, { UINT32_C(0x41097882), UINT32_C(0x40CAE39A) } }, //  3: vcAdd ( 0.268018723     + i * -2.70431042    , 8.32390213      + i * 9.04459381      ) = ( 8.59192085      + i * 6.34028339      );
 }

,
    .data_c64 = std::vector <data_3_c64_t>

    {

{ { UINT64_C(0xC000E6134801CC26), UINT64_C(0x401B370B60E66E18) }, { UINT64_C(0x4017E00D485FC01A), UINT64_C(0x4016A5DF421D4BBE) }, { UINT64_C(0x400EDA0748BDB40E), UINT64_C(0x4028EE755181DCEB) } }, //  0: vzAdd ( -2.11234146361813924      + i * 6.80375434309419092      , 5.96880066952146571       + i * 5.66198447517211711       ) = ( 3.85645920590332647       + i * 12.465738818266308        );
{ { UINT64_C(0xC0183225E080644C), UINT64_C(0x40207744D998EE8A) }, { UINT64_C(0x4015755793FAEAB0), UINT64_C(0xC00A5D46A314BA8E) }, { UINT64_C(0xBFE5E672642BCCE0), UINT64_C(0x4013BFE661A77FCD) } }, //  1: vzAdd ( -6.04897261413232101      + i * 8.23294715873568705      , 5.36459189623808186       + i * -3.2955448857022196       ) = ( -0.684380717894239154     + i * 4.93740227303346746       );
{ { UINT64_C(0x3FF1453801E28A70), UINT64_C(0xC011C72C86338E59) }, { UINT64_C(0x40049E8D96893D1C), UINT64_C(0xBFDCEE88B739DD20) }, { UINT64_C(0x400D4129977A8254), UINT64_C(0xC013961511A72C2B) } }, //  2: vzAdd ( 1.07939911590861115       + i * -4.44450578393624429     , 2.57741849523848821       + i * -0.452058962756796134     ) = ( 3.65681761114709936       + i * -4.89656474669304043      );
{ { UINT64_C(0x3FD12735D3224E60), UINT64_C(0xC005A26D910B44DC) }, { UINT64_C(0x4020A5D666294BAC), UINT64_C(0x402216D5173C2DAA) }, { UINT64_C(0x40212F1014C25E1F), UINT64_C(0x40195C7365F2B8E6) } }, //  3: vzAdd ( 0.268018203912310682      + i * -2.70431054416313366     , 8.32390136007401082       + i * 9.04459450349425609       ) = ( 8.59191956398632151       + i * 6.34028395933112243       );
 }

};

    // Report asynchronous exceptions delivered to the queue's handler
    auto exception_handler = [](sycl::exception_list exceptions) {
        for (std::exception_ptr const &e : exceptions) {
            try {
                std::rethrow_exception(e);
            } catch (sycl::exception const &ex) {
                std::cout << "Caught asynchronous SYCL exception:\n"
                          << ex.what() << std::endl;
            }
        } // for (std::exception_ptr const& e : exceptions)
    };

    // Create execution queue with asynchronous error handling
    sycl::queue main_queue(dev, exception_handler);

    // *************************************************************
    // Variable declarations
    // *************************************************************
    // Scalar staging values, filled from the reference table
    A arg1{};
    A arg2{};
    R ref1{};

    // Input arguments
    A *varg1 = own_malloc<A>(main_queue, VLEN * sizeof(A));
    A *varg2 = own_malloc<A>(main_queue, VLEN * sizeof(A));

    // Output results and their references
    R *vres1 = own_malloc<R>(main_queue, VLEN * sizeof(R));
    R *vref1 = own_malloc<R>(main_queue, VLEN * sizeof(R));

    // Output strided results and their references
    R *vresi1 = own_malloc<R>(main_queue, VLEN * sizeof(R));
    R *vrefi1 = own_malloc<R>(main_queue, VLEN * sizeof(R));

    // Frees every buffer allocated above. Called on every return path so the
    // error path below does not leak the own_malloc allocations.
    auto release_buffers = [&]() {
        own_free<A>(main_queue, varg1);
        own_free<A>(main_queue, varg2);
        own_free<R>(main_queue, vres1);
        own_free<R>(main_queue, vref1);
        own_free<R>(main_queue, vresi1);
        own_free<R>(main_queue, vrefi1);
    };

    // Number of errors
    int errs = 0;

    // *************************************************************
    // Vector input data initialization
    // *************************************************************
    for (int i = 0; i < VLEN; ++i) {
        // Getting values from reference data table
        data.get_values(arg1, arg2, ref1);

        // Pushing values into vectors
        varg1[i] = arg1;
        varg2[i] = arg2;
        vref1[i] = ref1;

        // The strided call writes only odd indices; even ones must keep the
        // R(777) fill value written before each run.
        vrefi1[i] = (i & 1) ? ref1 : R(777);
    } // for (int i = 0; i < VLEN; ++i)

    // *************************************************************
    // Loop by all 3 accuracy modes of VM: HA, LA, EP:
    // set computation mode, run VM and check results
    // *************************************************************
    for (int acc = 0; acc < ACCURACY_NUM; ++acc) {
        std::vector<sycl::event> no_deps;

        // Clear result vectors
        for (int i = 0; i < VLEN; ++i) {
            vres1[i] = R(777);
            vresi1[i] = R(777);
        }

        // Run VM function on all elements
        oneapi::mkl::vm::add(main_queue, VLEN, varg1, varg2, vres1, no_deps,
                             accuracy_mode[acc]);

        // Strided VM call for each odd-indexed element
        oneapi::mkl::vm::add(main_queue, varg1,
                             oneapi::mkl::slice(1, VLEN / 2, 2), varg2,
                             oneapi::mkl::slice(1, VLEN / 2, 2), vresi1,
                             oneapi::mkl::slice(1, VLEN / 2, 2), no_deps,
                             accuracy_mode[acc]);

        // Catch sycl exceptions
        try {
            main_queue.wait_and_throw();
        } catch (sycl::exception const &e) {
            std::cerr << "SYCL exception during Accuracy Test\n"
                      << e.what() << " code: " << e.code().value() << std::endl;
            // The queue has been waited on, so the buffers can be released.
            release_buffers();
            return false;
        }

        // *************************************************************
        // Compute ulp between computed and expected (reference)
        // values and check
        // *************************************************************
        for (int i = 0; i < VLEN; ++i) {
            // Check simple indexing function
            errs +=
                check_result<A, R>("", "add", ",simple", i, argtype, acc,
                                   varg1[i], varg2[i], vres1[i], vres1[i],
                                   vref1[i], vref1[i], ulp_table, false, false);

            // Check strided indexing function
            errs += check_result<A, R>("", "add", ",strided", i, argtype, acc,
                                       varg1[i], varg2[i], vresi1[i], vresi1[i],
                                       vrefi1[i], vrefi1[i], ulp_table, false,
                                       false);
        } // for (int i = 0; i < VLEN; ++i)
    }     // for (int acc = 0; acc < ACCURACY_NUM; ++acc)

    std::cout << "\tResult: " << ((errs == 0) ? "PASS" : "FAIL") << std::endl;

    release_buffers();
    return (errs == 0);
} // template <typename A, typename R> bool vAddAccuracyLiteTest(const
  // sycl::device &dev)
//
// Description of example setup, APIs used and supported floating-point type
// precisions
//
void print_example_banner() {
    // Horizontal rule framing the banner (split into two literals only to
    // keep the source line short).
    const char *const rule =
        "#############################################################"
        "###########";

    // Every banner line in display order; each is printed followed by
    // std::endl, including the leading and trailing empty lines.
    const char *const banner_lines[] = {
        "",
        rule,
        "# General VM "
        "add"
        " Function Example: ",
        "# ",
        "# Using apis:",
        "#   vm::"
        "add",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        rule,
        "",
    };

    for (const char *text : banner_lines) {
        std::cout << text << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_host -- only runs SYCL HOST device
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: HOST, CPU and GPU devices
//
//  For each selected device, the example runs the accuracy test with every
//  supported data type
//
int main(int argc, char **argv) {
    // Number of (device, data type) test runs that failed
    int errs = 0;

    // Print standard banner for VM examples
    print_example_banner();

    // List of available devices
    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    // Loop by all available devices
    for (auto dev_type : list_of_devices) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, dev_type);

        // Run tests if the device is available
        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[dev_type]
                      << ".\n";

            // vAddAccuracyLiteTest returns true on success, so each false
            // result (not each true one) is counted as an error.
            std::cout << "\tRunning with single precision real data type:"
                      << std::endl;
            errs += !vAddAccuracyLiteTest<float, float>(my_dev);

            std::cout << "\tRunning with double precision real data type:"
                      << std::endl;
            errs += !vAddAccuracyLiteTest<double, double>(my_dev);

            std::cout << "\tRunning with single precision complex data type:"
                      << std::endl;
            errs +=
                !vAddAccuracyLiteTest<std::complex<float>, std::complex<float>>(
                    my_dev);

            std::cout << "\tRunning with double precision complex data type:"
                      << std::endl;
            errs += !vAddAccuracyLiteTest<std::complex<double>,
                                          std::complex<double>>(my_dev);

        } else {

            std::cout << "No " << sycl_device_names[dev_type]
                      << " devices found; skipping "
                      << sycl_device_names[dev_type] << " tests.\n";
        }
    }

    // Surface failures to the caller (e.g. a test harness) via exit status.
    return (errs == 0) ? EXIT_SUCCESS : EXIT_FAILURE;
}
