/*******************************************************************************
* Copyright 2021-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates usage of the DPC++ device API for RNG
*       discrete distributions with the oneapi::mkl::rng::device::philox4x32x10
*       and oneapi::mkl::rng::device::mcg59 random number generators
*
*       Discrete distributions list:
*           oneapi::mkl::rng::device::poisson
*           oneapi::mkl::rng::device::bernoulli
*           oneapi::mkl::rng::device::bits
*
*       The supported integer data types for random numbers are:
*           int32_t
*           uint32_t
*       The bits distribution is additionally shown with the mcg59 engine
*       and uint64_t output.
*
*******************************************************************************/

// stl includes
#include <cstdint>
#include <iostream>
#include <list>
#include <vector>

#include <CL/sycl.hpp>
#include "oneapi/mkl/rng/device.hpp"

// local includes
#include "common_for_rng_examples.hpp"

// example parameters defines
#define SEED    777   // seed shared by every engine instance in the kernels
#define N       1000  // number of random values generated by each example
#define N_PRINT 10    // count argument handed to print_output()

// Vector width of the device engines (template argument of philox4x32x10 /
// mcg59) and therefore the number of values each work-item stores per call.
const int vec_size = 4;

// Every kernel launches N / vec_size work-items that each store vec_size
// values, so N must be a multiple of vec_size; otherwise the tail of the
// output buffer would silently stay unwritten.
static_assert(N % vec_size == 0, "N must be a multiple of vec_size");

template <typename IntType>
bool run_poisson_example(sycl::queue& queue) {
    // Generates N Poisson-distributed values (lambda = 2.0) on the device with
    // the philox4x32x10 engine, prints the first N_PRINT of them and returns
    // the result of statistics<>::check on the sample. A SYCL exception thrown
    // while generating also yields true; main() treats true as a failure.
    namespace rng = oneapi::mkl::rng::device;

    const double lambda = 2.0;
    std::vector<IntType> samples(N);

    {
        // Destroying this buffer at the end of the scope writes the device
        // results back into `samples` before validation reads them.
        sycl::buffer<IntType, 1> samples_buf(samples.data(), samples.size());

        try {
            queue
                .submit([&](sycl::handler& cgh) {
                    auto out = samples_buf.template get_access<sycl::access::mode::write>(cgh);
                    cgh.parallel_for(sycl::range<1>(N / vec_size), [=](sycl::item<1> it) {
                        const size_t idx = it.get_id(0);
                        // An offset of idx * vec_size per work-item is sufficient for
                        // lambda < 27 and lambda > 1000. For 27 <= lambda < 1000 an
                        // acceptance-rejection method is used, so larger skips are
                        // required to keep work-items from consuming the same variates
                        // (see parallel_generation_techniques.cpp for such techniques).
                        rng::philox4x32x10<vec_size> engine(SEED, idx * vec_size);
                        rng::poisson<IntType> distribution(lambda);

                        auto generated = rng::generate(distribution, engine);
                        generated.store(idx, out);
                    });
                })
                .wait_and_throw();
        }
        catch (sycl::exception const& e) {
            std::cout << "\t\tSYCL exception\n" << e.what() << std::endl;
            return true;
        }

        std::cout << "\t\tOutput of generator:" << std::endl;
        std::cout << "\n\t\tOutput of generator with poisson distribution:" << std::endl;
        print_output(samples_buf, N_PRINT);
    }

    return statistics<IntType, rng::poisson<IntType>>{}.check(
        samples, rng::poisson<IntType>{ lambda });
}

template <typename IntType>
bool run_bernoulli_example(sycl::queue& queue) {
    // Generates N Bernoulli-distributed values (p = 0.3) on the device with the
    // philox4x32x10 engine, prints the first N_PRINT of them and returns the
    // result of statistics<>::check. A SYCL exception also yields true, which
    // main() treats as a failure.
    const float p = 0.3f;
    std::vector<IntType> outcomes(N);

    {
        sycl::buffer<IntType, 1> outcomes_buf(outcomes.data(), outcomes.size());

        // Command group: one work-item per vec_size outputs, each work-item
        // owning a disjoint offset (id * vec_size) into the engine's sequence.
        const auto fill_outcomes = [&](sycl::handler& cgh) {
            auto dst = outcomes_buf.template get_access<sycl::access::mode::write>(cgh);
            cgh.parallel_for(sycl::range<1>(N / vec_size), [=](sycl::item<1> wi) {
                const size_t slot = wi.get_id(0);
                oneapi::mkl::rng::device::philox4x32x10<vec_size> engine(SEED, slot * vec_size);
                oneapi::mkl::rng::device::bernoulli<IntType> distr(p);

                auto batch = oneapi::mkl::rng::device::generate(distr, engine);
                batch.store(slot, dst);
            });
        };

        try {
            sycl::event finished = queue.submit(fill_outcomes);
            finished.wait_and_throw();
        }
        catch (sycl::exception const& e) {
            std::cout << "\t\tSYCL exception\n" << e.what() << std::endl;
            return true;
        }

        std::cout << "\t\tOutput of generator:" << std::endl;
        std::cout << "\n\t\tOutput of generator with bernoulli distribution:" << std::endl;
        print_output(outcomes_buf, N_PRINT);
    } // outcomes_buf is destroyed here, writing results back to `outcomes`

    return statistics<IntType, oneapi::mkl::rng::device::bernoulli<IntType>>{}.check(
        outcomes, oneapi::mkl::rng::device::bernoulli<IntType>{ p });
}

template <typename IntType>
bool run_bits_example(sycl::queue& queue) {
    // Fills N values with raw engine output (bits distribution) from the
    // philox4x32x10 engine and prints the first N_PRINT of them. Raw bits are
    // not statistically checked, so the only failure reported (true) is a SYCL
    // exception during generation.
    const size_t work_items = N / vec_size;  // each work-item stores vec_size values
    std::vector<IntType> raw(N);

    {
        sycl::buffer<IntType, 1> raw_buf(raw.data(), raw.size());

        try {
            sycl::event kernel_done = queue.submit([&](sycl::handler& cgh) {
                auto acc = raw_buf.template get_access<sycl::access::mode::write>(cgh);
                cgh.parallel_for(sycl::range<1>(work_items), [=](sycl::item<1> wi) {
                    const size_t gid = wi.get_id(0);
                    oneapi::mkl::rng::device::philox4x32x10<vec_size> engine(SEED, gid * vec_size);
                    oneapi::mkl::rng::device::bits<IntType> distr;

                    auto chunk = oneapi::mkl::rng::device::generate(distr, engine);
                    chunk.store(gid, acc);
                });
            });
            kernel_done.wait_and_throw();
        }
        catch (sycl::exception const& e) {
            std::cout << "\t\tSYCL exception\n" << e.what() << std::endl;
            return true;
        }

        std::cout << "\t\tOutput of generator:" << std::endl;
        std::cout << "\n\t\tOutput of generator with bits distribution:" << std::endl;
        print_output(raw_buf, N_PRINT);
    }

    // No validation for raw bits: report success.
    return false;
}

template <typename IntType>
bool run_bits_mcg59_example(sycl::queue& queue) {
    // prepare array for random numbers
    std::vector<IntType> r(N);

    // submit a kernel to generate on device
    {
        sycl::buffer<IntType, 1> r_buf(r.data(), r.size());

        try {
            auto event = queue.submit([&](sycl::handler& cgh) {
                auto r_acc = r_buf.template get_access<sycl::access::mode::write>(cgh);
                cgh.parallel_for(sycl::range<1>(N / vec_size), [=](sycl::item<1> item) {
                    size_t id = item.get_id(0);
                    oneapi::mkl::rng::device::mcg59<vec_size> engine(SEED, id * vec_size);
                    oneapi::mkl::rng::device::bits<IntType> distr;

                    auto res = oneapi::mkl::rng::device::generate(distr, engine);

                    res.store(id, r_acc);
                });
            });
            event.wait_and_throw();
        }
        catch (sycl::exception const& e) {
            std::cout << "\t\tSYCL exception\n" << e.what() << std::endl;
            return 1;
        }

        queue.wait();

        std::cout << "\t\tOutput of generator:" << std::endl;

        std::cout << "\n\t\tOutput of mcg59 with bits distribution:" << std::endl;

        // convert to std::uint32_t
        std::vector<std::uint32_t> r32(2*N);
        auto r_accessor = r_buf.template get_access<sycl::access::mode::read>();
        int j = 0;
        for (int i = 0; i < 2*N; i+=2)
        {
            std::uint64_t number = r_accessor[j];
            std::uint32_t* ptr = reinterpret_cast<std::uint32_t*>(&number);
            r32[i] = ptr[0];
            r32[i+1] = ptr[1];
            j++;
        }
        sycl::buffer<std::uint32_t, 1> r_buf32(r32.data(), r32.size());

        print_output(r_buf32, N_PRINT);
    } // buffer life-time ends

    // validation
    return false;
}

void print_example_banner() {
    // Prints the example's banner to std::cout, one entry per line; the empty
    // first and last entries produce the blank lines around the banner.
    static const char* const banner_lines[] = {
        "",
        "########################################################################",
        "# Generate random numbers with discrete rng distributions:",
        "# ",
        "# Using apis:",
        "#   oneapi::mkl::rng::device::poisson",
        "#   oneapi::mkl::rng::device::bernoulli",
        "#   oneapi::mkl::rng::device::bits",
        "# ",
        "# Supported integer types:",
        "#   int32_t uint32_t",
        "# ",
        "########################################################################",
        "",
    };
    for (const char* line : banner_lines) {
        std::cout << line << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_host -- only runs host implementation
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: host, cpu and gpu devices
//

int main(int argc, char** argv) {
    // catch asynchronous exceptions
    auto exception_handler = [](sycl::exception_list exceptions) {
        for (std::exception_ptr const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            }
            catch (sycl::exception const& e) {
                std::cout << "Caught asynchronous SYCL exception during generation:\n"
                          << e.what() << std::endl;
            }
        }
    };

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";
            sycl::queue queue(my_dev, exception_handler);
            std::cout << "\tRunning with uint32_t data type:" << std::endl;
            if (run_bits_example<std::uint32_t>(queue) ||
                run_bernoulli_example<std::uint32_t>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                if (run_poisson_example<std::uint32_t>(queue)) {
                    std::cout << "FAILED" << std::endl;
                    return 1;
                }
            }
            std::cout << "\tRunning with int32_t data type:" << std::endl;
            if (run_bernoulli_example<std::int32_t>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                if (run_poisson_example<std::int32_t>(queue)) {
                    std::cout << "FAILED" << std::endl;
                    return 1;
                }
            }
            std::cout << "\tRunning bits with mcg59:" << std::endl;
            if (run_bits_mcg59_example<std::uint64_t>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices is enabled.\n";
            std::cout << "FAILED" << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }
    std::cout << "PASSED" << std::endl;
    return 0;
}
