/*******************************************************************************
* Copyright 2014-2020 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

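/*!
 @file ComputeSPMV.cpp

 HPCG routines for the sparse matrix-vector product (ComputeSPMV) and the
 fused sparse matrix-vector product plus dot product (ComputeSPMV_DOT),
 submitted to a SYCL queue.
 */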

#include "ComputeSPMV.hpp"
#include "ComputeSPMV_ref.hpp"
#include "ComputeDotProduct.hpp"
#include "UsmUtil.hpp"
#ifndef HPCG_NO_MPI
#include "ExchangeHalo.hpp"

#include <mpi.h>
#include "Geometry.hpp"
#include <cstdlib>
#endif

#ifdef HPCG_USE_CUSTOM_KERNELS
#include "CustomKernels.hpp"
#endif

#include "VeryBasicProfiler.hpp"

// Optional per-phase profiling.
//
// With BASIC_PROFILING defined, the macros record begin/end marks on
// optData->profiler, so a local named `optData` must be in scope at every use.
// END_PROFILE_WAIT first blocks on the given event, which serializes the
// asynchronous pipeline while profiling.
//
// Without BASIC_PROFILING every macro is a no-op. Neither `n` nor `event` is
// evaluated, and nothing is waited on.
//
// Each variant expands to exactly one statement (do { ... } while (0)), so it
// stays correct after an unbraced if/else. It is terminated by the caller's `;`.
#ifdef BASIC_PROFILING
#define BEGIN_PROFILE(n) do { optData->profiler->begin(n); } while (0)
#define END_PROFILE(n) do { optData->profiler->end(n); } while (0)
#define END_PROFILE_WAIT(n, event) do { (event).wait(); optData->profiler->end(n); } while (0)
#else
#define BEGIN_PROFILE(n) do { } while (0)
#define END_PROFILE(n) do { } while (0)
#define END_PROFILE_WAIT(n, event) do { } while (0)
#endif

using namespace sycl;

/*!
  Routine to compute the sparse matrix-vector product y = Ax.

  The halo of x is exchanged inside this routine (MPI builds only), so callers
  do not need to call ExchangeHalo first. The product itself is computed with
  one of three paths:
    - custom::SpGEMV, when HPCG_USE_CUSTOM_KERNELS is defined;
    - otherwise oneMKL sparse::gemv:
        - optData->csrA is applied to x, depending only on `deps`;
        - when more than one rank is present, optData->csrB is also applied,
          after the halo exchange, into optData->dtmp;
        - dtmp is then scattered into y through optData->bmap;
    - the host reference ComputeSPMV_ref, when HPCG_LOCAL_LONG_LONG is defined.

  @param[in]    A          the known system matrix
  @param[inout] x          the known vector (passed to ExchangeHalo in MPI builds)
  @param[out]   y          on exit (once the returned event completes) contains Ax
  @param[in]    main_queue SYCL queue that all work is submitted to
  @param[inout] ierr       incremented by one when an exception is caught, or
                           when the reference routine reports failure
  @param[in]    deps       events that must complete before x may be read

  @return the event of the last submitted operation. The reference path runs
          synchronously on the host and returns a default-constructed event.

  @see ComputeSPMV_ref
*/

sycl::event ComputeSPMV( const SparseMatrix & A, Vector & x, Vector & y, sycl::queue & main_queue,
                         int& ierr, const std::vector<sycl::event> & deps)
{
#ifdef HPCG_LOCAL_LONG_LONG
    // The reference kernel reads x.values on the host, so the producers of x
    // must have finished first. Its status is folded into ierr instead of
    // being discarded. (This path previously fell through to `return last_ev;`
    // without `last_ev` being declared.)
    (void)main_queue;
    sycl::event::wait(deps);
    if ( ComputeSPMV_ref(A, x, y) != 0 ) ierr += 1;
    return sycl::event();
#else
    struct optData *optData = (struct optData *)A.optimizationData;

#ifdef HPCG_USE_CUSTOM_KERNELS
    custom::sparseMatrix *sparseM = (custom::sparseMatrix *)optData->esbM;
#else
    oneapi::mkl::sparse::matrix_handle_t hMatrixA = (oneapi::mkl::sparse::matrix_handle_t)optData->csrA;
    oneapi::mkl::sparse::matrix_handle_t hMatrixB = (oneapi::mkl::sparse::matrix_handle_t)optData->csrB;
#endif

    // Only operations that need off-rank values of x wait on halo_deps.
    // Without MPI, halo_deps is simply the incoming deps.
    BEGIN_PROFILE("SPMV:halo");
#ifndef HPCG_NO_MPI
    sycl::event halo_ev = ExchangeHalo(A, x, main_queue, deps);
    std::vector<sycl::event> halo_deps({halo_ev});
    END_PROFILE_WAIT("SPMV:halo", halo_ev);
#else
    const std::vector<sycl::event>& halo_deps = deps;
    END_PROFILE("SPMV:halo");
#endif

    sycl::event last_ev;
    try {

#ifdef HPCG_USE_CUSTOM_KERNELS
        BEGIN_PROFILE("SPMV:gemv");
        last_ev = custom::SpGEMV(main_queue, sparseM, x.values, y.values, halo_deps);
        END_PROFILE_WAIT("SPMV:gemv", last_ev);
#else
        // The csrA product depends only on deps, so it can overlap the halo exchange.
        BEGIN_PROFILE("SPMV:local_gemv");
        last_ev = oneapi::mkl::sparse::gemv( main_queue, oneapi::mkl::transpose::nontrans,
                    1.0, hMatrixA, x.values, 0.0, y.values, deps);
        END_PROFILE_WAIT("SPMV:local_gemv", last_ev);

        if ( A.geom->size > 1 )
        {
            BEGIN_PROFILE("SPMV:nonlocal_gemv");
            auto nonlocal_gemv_ev = oneapi::mkl::sparse::gemv( main_queue, oneapi::mkl::transpose::nontrans,
                    1.0, hMatrixB, x.values, 0.0, optData->dtmp, halo_deps);
            END_PROFILE_WAIT("SPMV:nonlocal_gemv", nonlocal_gemv_ev);

            // Read the raw pointers out of optData and y on the host. The
            // kernel then captures only plain pointers, instead of
            // dereferencing the host-side optData struct (and copying the
            // whole Vector) for every work-item.
            auto bmap = optData->bmap;
            auto dtmp = optData->dtmp;
            auto yv   = y.values;

            BEGIN_PROFILE("SPMV:yupdate");
            last_ev = main_queue.submit([&](handler &cgh) {
                cgh.depends_on({last_ev, nonlocal_gemv_ev});
                // NOTE(review): `+=` is race-free only if bmap has no duplicate
                // entries — confirm against where bmap is built.
                auto kernel = [=] (item<1> idx) {
                    local_int_t row = idx.get_id(0);
                    yv[bmap[row]] += dtmp[row];
                };
                cgh.parallel_for<class ComputeSPMV_y_update>(range<1>(optData->nrow_b), kernel);
            });
            END_PROFILE_WAIT("SPMV:yupdate", last_ev);
        }
#endif

    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception in SPMV:\n" << e.what() << std::endl;
        ierr += 1;
        return last_ev;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception in SPMV:\n" << e.what() << std::endl;
        ierr += 1;
        return last_ev;
    }

    return last_ev;
#endif
}

/*!
  Routine to compute y = Ax fused with the dot product of x and the new y.

  The halo of x is exchanged inside this routine (MPI builds only). The work is
  then done with one of two paths:
    - custom::SpGEMV_DOT, when HPCG_USE_CUSTOM_KERNELS is defined;
    - otherwise a oneMKL-based sequence:
        - optData->csrA is applied into y;
        - ComputeDotProductLocal computes the dot product of x with that
          partial y into *pAp;
        - when more than one rank is present, optData->csrB is applied into
          optData->dtmp;
        - one kernel then scatters dtmp into y through optData->bmap and adds
          sum_i dtmp[i] * x[bmap[i]] to *pAp.

  @param[in]    A          the known system matrix
  @param[inout] x          the known vector (passed to ExchangeHalo in MPI builds)
  @param[out]   y          on exit (once the returned event completes) contains Ax
  @param[inout] pAp        USM pointer to the scalar that receives the dot product
                           (it is the target of a sycl::reduction)
  @param[in]    main_queue SYCL queue that all work is submitted to
  @param[inout] ierr       incremented by one when an exception is caught
  @param[in]    deps       events that must complete before x may be read

  @return the event of the last submitted operation. On a caught exception the
          last successfully submitted event is returned, so waiting on it still
          covers work that is already in flight.
*/
sycl::event ComputeSPMV_DOT( const SparseMatrix & A, Vector & x, Vector & y, double * pAp,
                             sycl::queue & main_queue, int& ierr, const std::vector<sycl::event> & deps)
{
    struct optData *optData = (struct optData *)A.optimizationData;

#ifdef HPCG_USE_CUSTOM_KERNELS
    custom::sparseMatrix *sparseM = (custom::sparseMatrix *)optData->esbM;
#else
    oneapi::mkl::sparse::matrix_handle_t hMatrixA = (oneapi::mkl::sparse::matrix_handle_t)optData->csrA;
    oneapi::mkl::sparse::matrix_handle_t hMatrixB = (oneapi::mkl::sparse::matrix_handle_t)optData->csrB;
#endif

    // Only operations that need off-rank values of x wait on halo_deps.
    // Without MPI, halo_deps is simply the incoming deps.
    BEGIN_PROFILE("SPMV_DOT:halo");
#ifndef HPCG_NO_MPI
    sycl::event halo_ev = ExchangeHalo(A,x, main_queue, deps);
    std::vector<sycl::event> halo_deps({halo_ev});
    END_PROFILE_WAIT("SPMV_DOT:halo", halo_ev);
#else
    const std::vector<sycl::event>& halo_deps = deps;
    END_PROFILE("SPMV_DOT:halo");
#endif


    //gemv + dot product
    sycl::event last_ev;
    try {

#ifdef HPCG_USE_CUSTOM_KERNELS

        BEGIN_PROFILE("SPMV_DOT:gemv_dot");
        last_ev = custom::SpGEMV_DOT(main_queue, sparseM, x.values, y.values, pAp, halo_deps);
        END_PROFILE_WAIT("SPMV_DOT:gemv_dot", last_ev);

#else // HPCG_USE_CUSTOM_KERNELS
        // The csrA product depends only on deps, so it can overlap the halo exchange.
        BEGIN_PROFILE("SPMV_DOT:local_gemv");
        auto local_gemv_ev = oneapi::mkl::sparse::gemv( main_queue, oneapi::mkl::transpose::nontrans,
                1.0, hMatrixA, x.values, 0.0, y.values, deps);
        END_PROFILE_WAIT("SPMV_DOT:local_gemv", local_gemv_ev);

        BEGIN_PROFILE("SPMV_DOT:dot");
        last_ev = ComputeDotProductLocal(A.localNumberOfRows, x, y, pAp, main_queue, {local_gemv_ev});
        END_PROFILE_WAIT("SPMV_DOT:dot", last_ev);

        if ( A.geom->size > 1 ) {
            BEGIN_PROFILE("SPMV_DOT:nonlocal_gemv");
            auto nonlocal_gemv_ev = oneapi::mkl::sparse::gemv( main_queue, oneapi::mkl::transpose::nontrans,
                    1.0, hMatrixB, x.values, 0.0, optData->dtmp, halo_deps);
            END_PROFILE_WAIT("SPMV_DOT:nonlocal_gemv", nonlocal_gemv_ev);

            // Read the raw pointers out of optData, x and y on the host. The
            // kernel then captures only plain pointers, instead of
            // dereferencing the host-side optData struct (and copying whole
            // Vectors) for every work-item.
            auto bmap = optData->bmap;
            auto dtmp = optData->dtmp;
            auto xv   = x.values;
            auto yv   = y.values;

            BEGIN_PROFILE("SPMV_DOT:yupdate");
            last_ev = main_queue.submit([&](sycl::handler &cgh) {
                // Also wait for the dot product: it reads y, which this kernel updates.
                cgh.depends_on({nonlocal_gemv_ev, last_ev});
                // NOTE(review): `+=` on yv is race-free only if bmap has no
                // duplicate entries — confirm against where bmap is built.
                auto kernel = [=](sycl::item<1> idx, auto& pAp_correction) {
                    const local_int_t i = idx.get_id(0);
                    const local_int_t ind = bmap[i];
                    yv[ind] += dtmp[i];
                    pAp_correction += dtmp[i] * xv[ind];
                };
                // No initialize_to_identity property: the reduction result is
                // combined with the value already stored in *pAp, which adds
                // the off-rank correction to the dot product computed above.
                auto pAp_reduction = sycl::reduction(pAp, sycl::plus<>());
                cgh.parallel_for<class ComputeSPMV_DOT_y_update>(sycl::range<1>(optData->nrow_b), pAp_reduction, kernel);
            });
            END_PROFILE_WAIT("SPMV_DOT:yupdate", last_ev);
        }
#endif // HPCG_USE_CUSTOM_KERNELS
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception in SPMV_DOT:\n" << e.what() << std::endl;
        ierr += 1;
        return last_ev;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception in SPMV_DOT:\n" << e.what() << std::endl;
        ierr += 1;
        return last_ev;
    }

    return last_ev;
}
