/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file GraphPermutation.cpp

 HPCG routine
 */

#include "PermuteGraph.hpp"
#include "PrefixSum.hpp"
#include "Helpers.hpp"
/*!

  Implements the graph permutation that improves the performance of the
  symmetric Gauss-Seidel (SYMGS) preconditioner; in particular, it exposes
  wider parallelism in the sparse triangular solves (TRSV) inside SYMGS.

  The rows of the local matrix A are colored and then reordered color by
  color, which improves the dependency DAG that SYMGS follows in the
  benchmark version of the preconditioned CG algorithm.

  @param[in] queue          sycl::queue & object to schedule tasks.
  @param[in] rank           process rank (used only for diagnostics)
  @param[in] nx, ny, nz     dimensions of the local grid; the coloring assumes
                            nrow == nx*ny*nz
  @param[in] nrow           number of rows in A matrix
  @param[in] nonzerosInRow  number of nonzeros in each row of A (size nrow)
  @param[in] mtxIndL        per-row local column indices of A
  @param[in] matrixValues   per-row values of A (not read by the coloring)
  @param[inout] perm        permutation array for PTAP matrix (size nrow)
  @param[inout] invperm     inverse permutation array for PTAP matrix (size nrow)
  @param[out] nColors       number of colors used
  @param[inout] xcolors_dev color offsets: on return, xcolors_dev[c] is the first
                            permuted row of color c (at least nColors+1 entries)
  @param[in] dependencies   vector of sycl::events to pass into work

  @return returns sycl::event to track actions (the work has already been
          waited on when PermuteGraph returns)

  @see ComputeSYMGS.cpp
  @see OptimizeProblem.cpp

*/
#include "stdio.h"
#include "math.h"
#include <complex>
#include <iostream>
#include <iomanip>
#include "hpcg.hpp"

namespace { // anonymous

#ifdef HPCG_LOCAL_LONG_LONG
typedef unsigned long long local_uint_t;
#else
typedef unsigned int local_uint_t;
#endif

// value for unset colors
#define UNSET -1


template <typename T>
using device_atomic_ref_relaxed = sycl::atomic_ref<T,
                                          sycl::memory_order::relaxed,
                                          sycl::memory_scope::device>;

//
// Debug helper that dumps coloring/permutation data and matrix structure to
// per-rank text files:
//   perm<nrow>_rank<rank>.txt : one "row, perm, color, weight" line per row,
//                               followed by one "row, invperm" line per row.
//   A<nrow>_rank<rank>.txt    : the column indices (mtxIndL) of up to the first
//                               320 rows, written only when nonzerosInRow is a
//                               shared USM allocation.
// mtxIndL is dereferenced on the host, so it and its row arrays must be
// host-accessible too (only nonzerosInRow is checked). matrixValues is not
// used. Device data is copied to the host synchronously after `dependencies`
// complete; the returned event is default-constructed.
//
sycl::event printGraph(sycl::queue &queue, int rank,
                       const local_int_t nrow,
                       char *nonzerosInRow,
                       local_int_t **mtxIndL, double **matrixValues,
                       local_int_t *perm_dev, local_int_t *invperm_dev, local_int_t *colors_dev, local_uint_t *weights_dev,
                       const std::vector<sycl::event> &dependencies)
{
    sycl::event last;

    // snprintf bounds the writes to the buffers; nrow is widened so the format
    // specifier matches whether local_int_t is int or long long.
    char fname_perm[80];
    char fname_A[80];
    snprintf(fname_perm, sizeof(fname_perm), "perm%04lld_rank%d.txt", static_cast<long long>(nrow), rank);
    snprintf(fname_A, sizeof(fname_A), "A%04lld_rank%d.txt", static_cast<long long>(nrow), rank);
    std::ofstream fout_perm(fname_perm);
    std::ofstream fout_A(fname_A);

    local_int_t *perm_host = sycl::malloc_host<local_int_t>(nrow, queue);
    local_int_t *colors_host = sycl::malloc_host<local_int_t>(nrow, queue);
    local_int_t *invperm_host = sycl::malloc_host<local_int_t>(nrow, queue);
    local_uint_t *weights_host = sycl::malloc_host<local_uint_t>(nrow, queue);

    queue.memcpy(perm_host, perm_dev, nrow*sizeof(local_int_t), dependencies).wait();
    queue.memcpy(colors_host, colors_dev, nrow*sizeof(local_int_t), dependencies).wait();
    queue.memcpy(weights_host, weights_dev, nrow*sizeof(local_uint_t), dependencies).wait();
    queue.memcpy(invperm_host, invperm_dev, nrow*sizeof(local_int_t), dependencies).wait();

    fout_perm << "row perm color weight: " << nrow << '\n';
    for (local_int_t row = 0; row < nrow; ++row) {
        // print this row's weight (the pointer itself was printed before)
        fout_perm << row << ", " << perm_host[row] << ", " << colors_host[row] << ", " << weights_host[row] << '\n';
    }

    fout_perm << "invperm " << nrow << '\n';
    for (local_int_t row = 0; row < nrow; ++row) {
        fout_perm << row << ", " << invperm_host[row] << '\n';
    }

    if (sycl::get_pointer_type(nonzerosInRow, queue.get_context()) == sycl::usm::alloc::shared) {
        fout_A << "HPCG matrix: " << nrow << '\n';
        // explicit template argument keeps this valid when local_int_t is long long
        const local_int_t max_rows = std::min<local_int_t>(320, nrow);
        for (local_int_t row = 0; row < max_rows; ++row) {
            fout_A << "row " << std::setw(8) << row << ":";
            for (int j = 0; j < nonzerosInRow[row]; ++j) {
                fout_A << " " << std::setw(8) << mtxIndL[row][j];
            }
            fout_A << '\n';
        }
    }
    else {
        std::cout << "rank: " << rank << ":  nonzerosInRow is not of Shared USM type" << std::endl;
    }

    fout_perm.close();
    fout_A.close();
    sycl::free(perm_host, queue);
    sycl::free(colors_host, queue);
    sycl::free(invperm_host, queue);
    sycl::free(weights_host, queue); // previously leaked

    return last;
}


//
// Debug helper that writes, for every local row, its permuted index, inverse
// permutation entry, color and weight to "perm<nrow>_rank<rank>.txt". Device
// arrays are copied to the host synchronously after `dependencies` complete;
// the returned event is default-constructed.
//
sycl::event printPerm(sycl::queue &queue, const int rank,
                       const local_int_t nrow, local_int_t *perm_dev, local_int_t *invperm_dev, local_int_t *colors_dev, local_uint_t *weights_dev,
                       const std::vector<sycl::event> &dependencies)
{
    sycl::event last;

    // snprintf bounds the write; nrow is widened so the format specifier
    // matches whether local_int_t is int or long long.
    char fname_perm[80];
    snprintf(fname_perm, sizeof(fname_perm), "perm%04lld_rank%d.txt", static_cast<long long>(nrow), rank);
    std::ofstream fout_perm(fname_perm);

    local_int_t *perm_host = sycl::malloc_host<local_int_t>(nrow, queue);
    local_int_t *invperm_host = sycl::malloc_host<local_int_t>(nrow, queue);
    local_int_t *colors_host = sycl::malloc_host<local_int_t>(nrow, queue);
    local_uint_t *weights_host = sycl::malloc_host<local_uint_t>(nrow, queue);

    queue.memcpy(perm_host, perm_dev, nrow*sizeof(local_int_t), dependencies).wait();
    queue.memcpy(invperm_host, invperm_dev, nrow*sizeof(local_int_t), dependencies).wait();
    queue.memcpy(colors_host, colors_dev, nrow*sizeof(local_int_t), dependencies).wait();
    queue.memcpy(weights_host, weights_dev, nrow*sizeof(local_uint_t), dependencies).wait();

    fout_perm << "row, perm[row], invperm[row], color[row], weight[row]" << '\n';

    for (local_int_t row = 0; row < nrow; ++row) {
        fout_perm << row << ", " << perm_host[row] << ", " << invperm_host[row] << ", " << colors_host[row] << ", "
                  << static_cast<local_int_t>(weights_host[row]) << '\n';
    }

    fout_perm.close();
    sycl::free(perm_host, queue);
    sycl::free(invperm_host, queue);
    sycl::free(weights_host, queue);
    sycl::free(colors_host, queue);

    return last;
}



//
// Convert a row coloring into a symmetric permutation that groups rows by
// color while keeping their original relative order as much as possible, to
// retain locality.
//
// Inputs, as prepared by the callers in this file:
//   colors_dev[row]    color of each row. Every value must lie in [0, nColors):
//                      it indexes `temp` (nColors * maxWG entries) below.
//   xcolors_dev[c+1]   number of rows with color c, with xcolors_dev[0] == 0;
//                      nColors + 1 entries are scanned.
// Outputs:
//   xcolors_dev[c]     after the scan, the first permuted index of color c.
//   perm[row]          permuted index of `row`.
//   invperm[perm[row]] == row.
// nColors is only read. The function waits for its device work and frees its
// scratch buffer before returning, so the returned event has completed.
//
// Rows of one color handled by different workgroups keep their relative order
// (workgroup k's rows come after those of workgroups 0..k-1). Within a single
// workgroup, slots are handed out by an atomic fetch_add, so that order is not
// guaranteed to follow row order.
//
// NOTE(review): the offsets above assume prefix_sum (PrefixSum.hpp) is an
// in-place inclusive scan, which is how both uses below rely on it — confirm.
//
sycl::event convert_colors_to_perm(sycl::queue &queue,
                                   const local_int_t nrow,
                                   local_int_t &nColors,
                                   local_int_t *colors_dev,
                                   local_int_t *xcolors_dev, // updated in this function
                                   local_int_t *perm,    // updated in this function
                                   local_int_t *invperm, // updated in this function
                                   const std::vector<sycl::event> dependencies)
{

    const local_int_t maxWG = 2048; // Max number of workgroups to launch

    // Temporary array to store counts for each color processed by each workgroup
    // Array layout as follows with N = nColors - 1, M = maxWG - 1
    // |color0:wg0, color0:wg1, ... color0:wgM, color1:wg0, ...color1:wgM, ... colorN:wg0,...colorN:wgM|
    // The ith color count for each workgroup is contiguous to readily allow prefix_sum calls.
    local_int_t *temp = (local_int_t *) sparse_malloc_device(sizeof(local_int_t) * nColors * maxWG, queue);
    auto ev_init_temp_zero = queue.memset(temp, 0, sizeof(local_int_t) * nColors * maxWG);

    sycl::event ev_prefix_sum_xcolor;

    // do prefix sum of xcolor which is pointer to where colors start.
    // This allows us to ensure that each color has its own swim lane for
    // filling in rows in the perm.
    ev_prefix_sum_xcolor = prefix_sum(queue, nColors+1, xcolors_dev, dependencies);

    // fill in perm based on coloring of graph,  try our best to
    // preserve as close to original ordering as is possible here now
    // to get better cache effects :) a "stable sort" implementation
    // could also work
    const local_int_t batch_size = 96; // work-items per workgroup (see nd_range below)

    // Each workgroup will handle batch_size * unroll rows.
    // unroll is computed to ensure the number of workgroups
    // needed to cover nrows is always <= maxWG
    local_int_t unroll = ceil_div(nrow, (maxWG * batch_size));

    // Each workgroup computes partial values for perm and stores the
    // color counts in temp
    auto evt = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_prefix_sum_xcolor);
        cgh.depends_on(ev_init_temp_zero);
        auto kernel = [=](sycl::nd_item<1> item) {
            // this workgroup owns rows [offset, lastRow)
            local_int_t wgID = item.get_group().get_group_id(0);
            local_int_t offset = wgID * unroll * batch_size;

            local_int_t lastRow = sycl::min((wgID + 1) * unroll * batch_size, nrow);

            local_int_t row = offset + item.get_local_id(0);
            while (row < lastRow) {
                local_int_t color = colors_dev[row];
                // claim the next slot of this color within this workgroup
                device_atomic_ref_relaxed<local_int_t> temp_ref(temp[color * maxWG + wgID]);
                local_int_t newrow = temp_ref.fetch_add(1);
                // partial position: color start + slot within this workgroup;
                // the counts of earlier workgroups are added in the second kernel
                perm[row] = xcolors_dev[color] + newrow;
                // workgroup 0 already has the final values so we can update invperm
                if (wgID == 0) invperm[perm[row]] = row;

                row += batch_size;
            }
        };
        cgh.parallel_for<class PermuteGraph_write_perm1>(sycl::nd_range<1>(maxWG * batch_size, batch_size), kernel);
    });

    // Compute nColor independent prefix_sum on temp
    // Now temp contains the total count for each color for each previous workgroups combined.
    std::vector<sycl::event> ev_prefix_sum_temps;
    for (local_int_t i = 0; i < nColors; i++)
        ev_prefix_sum_temps.push_back(prefix_sum(queue, maxWG, temp + i * maxWG, {evt}));

    evt = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_prefix_sum_temps);
        auto kernel = [=](sycl::nd_item<1> item) {
            // workgroup 0 doesn't need to be updated with values from temp, so start with
            // workgroup 1.
            local_int_t wgID = item.get_group().get_group_id(0) + 1;
            local_int_t offset = wgID * unroll * batch_size;

            local_int_t lastRow = sycl::min((wgID + 1) * unroll * batch_size, nrow);

            local_int_t row = offset + item.get_local_id(0);
            while (row < lastRow) {
                local_int_t color = colors_dev[row];
                // temp[color * maxWG + wgID - 1]: get the total count for color
                // of all workgroups from 0...wgID-1
                perm[row] += temp[color * maxWG + wgID - 1];
                invperm[perm[row]] = row;

                row += batch_size;
            }
        };
        // maxWG - 1 workgroups, covering original workgroups 1..maxWG-1
        cgh.parallel_for<class PermuteGraph_write_perm2>(sycl::nd_range<1>((maxWG - 1) * batch_size, batch_size), kernel);
    });

    evt.wait();
    sycl::free(temp, queue);

    return evt;
}




//
// Alternative to colorize_graph. It assigns one of 8 colors directly from the
// parity of each row's (x, y, z) grid coordinates. It then builds
// perm/invperm/xcolors_dev via convert_colors_to_perm.
//
// nColors is always 8; colors that do not occur in a tiny grid simply have
// zero rows. Not currently called by PermuteGraph.
//
sycl::event colorize_graph_direct(sycl::queue &queue, int rank,
                                  local_int_t nx, local_int_t ny, local_int_t nz,
                                  const local_int_t nrow, char *nonzerosInRow,
                                  local_int_t **mtxIndL, double **matrixValues,
                                  local_int_t *perm, local_int_t *invperm,
                                  local_int_t &nColors, local_int_t *xcolors_dev,
                                  const std::vector<sycl::event> dependencies)
{

    // color unset value
    const local_int_t unset = UNSET;

    // number of colors produced by the parity scheme below
    const local_int_t num_direct_colors = 8;

    // The kernel increments xcolors_dev[color+1] for every color in [0, 8).
    // convert_colors_to_perm also scans nColors+1 entries. So at least 9
    // counters must be zeroed even when nrow < 8; min(256, nrow)+1 alone left
    // the upper counters uninitialized for tiny grids.
    const local_int_t max_colors = std::max<local_int_t>(std::min<local_int_t>(256, nrow), num_direct_colors);

    // allocate colors and set data to UNSET to start with
    local_int_t *colors_dev = (local_int_t *)sparse_malloc_device(sizeof(local_int_t)*(nrow), queue);
    auto ev_fill_colors = queue.fill(colors_dev, unset, nrow, dependencies);

    // set the color counters to 0 to start with
    auto ev_fill_xcolors = queue.fill(xcolors_dev, static_cast<local_int_t>(0), max_colors+1, dependencies);

    auto ev_do_c = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_fill_colors);
        cgh.depends_on(ev_fill_xcolors);
        cgh.parallel_for(
            sycl::range<3>(nx, ny, nz),
            [=](sycl::item<3> item) {
                const local_int_t xk = item.get_id(0);
                const local_int_t yk = item.get_id(1);
                const local_int_t zk = item.get_id(2);

                // specific names for colors below found by brute force,
                // trying to find one that produces most residual reduction in 50 iterations

                // old             : 7 5 2 3 6 1 4 0 : 3.55721431587827e-07, 55 iterations on 12 ranks
                // slightly better : 7 4 2 5 6 1 3 0 : 2.81343103130995e-07, 54 iterations on 12 ranks

                local_int_t color;
                if (xk % 2 == 0 && yk % 2 == 0 && zk % 2 == 0) {
                    color = 7;
                }
                else if (xk % 2 == 1 && yk % 2 == 0 && zk % 2 == 0) {
                    color = 6;
                }
                else if (xk % 2 == 0 && yk % 2 == 1 && zk % 2 == 0) {
                    color = 2;
                }
                else if (xk % 2 == 1 && yk % 2 == 1 && zk % 2 == 0) {
                    color = 3;
                }
                else if (xk % 2 == 0 && yk % 2 == 0 && zk % 2 == 1) {
                    color = 4;
                }
                else if (xk % 2 == 1 && yk % 2 == 0 && zk % 2 == 1) {
                    color = 1;
                }
                else if (xk % 2 == 0 && yk % 2 == 1 && zk % 2 == 1) {
                    color = 5;
                }
                else {
                    color = 0;
                }

                const local_int_t currentLocalRow = zk*nx*ny+yk*nx+xk;
                colors_dev[currentLocalRow] = color;
                // count rows per color; xcolors_dev[color+1] becomes the size of color
                device_atomic_ref_relaxed<local_int_t> xcolor_ref(xcolors_dev[color+1]);
                xcolor_ref++;
            });
    });

    nColors = num_direct_colors;

    auto ev_fill_perm = convert_colors_to_perm(queue, nrow, nColors, colors_dev, xcolors_dev, perm, invperm, {ev_do_c});

    ev_fill_perm.wait();
#ifdef HPCG_DEBUG
    std::cout << "PermuteGraph rearranged " << nrow << " rows into " << nColors <<  " colors" << std::endl;
#endif

    sycl::free(colors_dev, queue);

    return ev_fill_perm;
}


//
// Fallback used when the independent-set passes in colorize_graph stop making
// progress: every row that is still UNSET receives a color of its own.
//
// Colors c_id and c_id+1 belong to the pass that was in progress when the
// caller gave up. So fallback colors start at c_id+2 and are handed out
// sequentially in row order, which is why this runs as a single_task. Each
// fallback color holds exactly one row, so its count xcolors_dev[color+1] is
// set to 1.
//
// On return:
//   count_rows_colored_host[0]  number of rows colored here.
//   c_id                        index of the LAST color now in use, i.e.
//                               c_id_in + 1 + count, so the color count is c_id + 1.
// The function blocks until the device work completes.
//
// NOTE(review): xcolors_dev must have room for index c_id_in + 2 + count; its
// allocation is made by the caller and is not visible here.
//
sycl::event fill_colors_backup(sycl::queue &queue,
                               local_int_t nrow,
                               local_int_t &c_id, // updated in this function
                               local_int_t *colors_dev, // updated in this function
                               local_int_t *xcolors_dev, // updated in this function
                               local_int_t *count_rows_colored_host, // updated in this function
                               const std::vector<sycl::event> &dependencies)
{

    const local_int_t unset = UNSET;
    const local_int_t first_color = c_id + 2; // skip the pair of the interrupted pass

    local_int_t *count_rows_colored_dev = (local_int_t *)sparse_malloc_device(sizeof(local_int_t)*(1), queue);

    auto ev_fill_backup = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dependencies);
        cgh.single_task<class PermuteGraph_backup_fill_color>([=] () {
            local_int_t next_color = first_color;
            count_rows_colored_dev[0] = 0;
            for (local_int_t i = 0; i < nrow; i++) {
                if (colors_dev[i] == unset) { // fill in the rest of rows with own color sets
                    colors_dev[i] = next_color;
                    // Exactly one row per fallback color. Assign rather than
                    // increment, so counters beyond the caller's zero-filled
                    // prefix of xcolors_dev are never read.
                    xcolors_dev[next_color+1] = 1;
                    count_rows_colored_dev[0]++;
                    next_color++;
                }
            }
        });
    });

    auto ev_memcpy = queue.memcpy(count_rows_colored_host, count_rows_colored_dev, 1*sizeof(local_int_t), {ev_fill_backup});
    ev_memcpy.wait();

    // The last color used is first_color + count - 1 = c_id + 1 + count.
    // `c_id += count` was one short: it left the final fallback color outside
    // [0, nColors) for the caller, which then indexed past the end of the
    // per-color scratch array in convert_colors_to_perm.
    c_id += count_rows_colored_host[0] + 1;

#ifdef HPCG_DEBUG
    HPCG_fout << "filled remaining " << count_rows_colored_host[0] << " rows with their own colors for a total of " << c_id + 1 << " colors" <<  std::endl;
#endif

    sycl::free(count_rows_colored_dev, queue);

    return ev_memcpy;
}




//
// Jones-Plassmann-Luby style coloring.
//
// Every row gets a weight derived from a parity-based (generalized black-red)
// hash of its (x, y, z) grid position. The weights are unique because each one
// also adds the row index.
//
// Each pass over the still-uncolored rows assigns two colors at once. A row is
// compared only with neighbors that are local (column index < nrow), are not
// itself, and are either uncolored or were colored in the current pass:
//   - color c_id   goes to rows whose weight is below all such neighbors',
//   - color c_id+1 goes to rows whose weight is above all such neighbors'.
//
// Passes continue until every row is colored. If passes repeatedly add no
// rows (zero_set > 3), the remaining rows get one color each from
// fill_colors_backup(). Finally convert_colors_to_perm() turns the coloring
// into perm/invperm/xcolors_dev.
//
// On return nColors is the number of colors. All device work has been waited
// on, so the returned event has already completed.
//
sycl::event colorize_graph(sycl::queue &queue, const int rank,
                           local_int_t nx, local_int_t ny, local_int_t nz,
                           const local_int_t nrow, char *nonzerosInRow,
                           local_int_t **mtxIndL, double **matrixValues,
                           local_int_t *perm, local_int_t *invperm,
                           local_int_t &nColors, local_int_t *xcolors_dev,
                           const std::vector<sycl::event> dependencies)
{
    // Host-side landing buffer for the two per-pass counters read back from
    // xcolors_dev. Only entries [0] and [1] are read by this function.
    const int width = 64 / sizeof(local_int_t);

    local_int_t *count_rows_colored_host = (local_int_t *)sparse_malloc_host(sizeof(local_int_t)*2*width, queue);
    count_rows_colored_host[0] = 0;
    // [1] is read after the loop even when no pass runs (nrow == 0)
    count_rows_colored_host[1] = 0;
    count_rows_colored_host[width] = 0;

    // color unset value
    const local_int_t unset = UNSET;

    // Local grid size as an unsigned value, so the weight computation below is
    // done in local_uint_t and cannot overflow a signed local_int_t.
    const local_uint_t grid_size = static_cast<local_uint_t>(nx) * static_cast<local_uint_t>(ny) * static_cast<local_uint_t>(nz);

    // per-row weights (priorities) for the independent-set selection
    local_uint_t *weight_dev = (local_uint_t *)sparse_malloc_device(sizeof(local_uint_t)*nrow, queue);

    auto ev_fill_weight = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dependencies);
        cgh.parallel_for(
            sycl::range<3>(nx, ny, nz),
            [=](sycl::item<3> item) {
                const local_int_t xk = item.get_id(0);
                const local_int_t yk = item.get_id(1);
                const local_int_t zk = item.get_id(2);

                // hash-coloring below delivers near optimal actual coloring

                // The hash_color determines a weight which only indirectly leads
                // to the actual color.

                // Each pass of the coloring loop below assigns two actual colors:
                // one to local minima of the weight and one to local maxima.
                // Since hash_color 0 and 7 lead to the lowest and highest weights,
                // they end up mapping to actual colors 0 and 1.

                // Similarly, hash_color 1 and 6 correspond to actual colors 2 and 3,
                // and so on.

                local_int_t hash_color;
                if (xk % 2 == 0 && yk % 2 == 0 && zk % 2 == 0) {
                    hash_color = 4; // -> color 7
                }
                else if (xk % 2 == 1 && yk % 2 == 0 && zk % 2 == 0) {
                    hash_color = 2; // -> color 4
                }
                else if (xk % 2 == 0 && yk % 2 == 1 && zk % 2 == 0) {
                    hash_color = 1; // -> color 2
                }
                else if (xk % 2 == 1 && yk % 2 == 1 && zk % 2 == 0) {
                    hash_color = 5; // -> color 5
                }
                else if (xk % 2 == 0 && yk % 2 == 0 && zk % 2 == 1) {
                    hash_color = 3; // -> color 6
                }
                else if (xk % 2 == 1 && yk % 2 == 0 && zk % 2 == 1) {
                    hash_color = 7; // -> color 1
                }
                else if (xk % 2 == 0 && yk % 2 == 1 && zk % 2 == 1) {
                    hash_color = 6; // -> color 3
                }
                else {
                    hash_color = 0; // -> color 0
                }

                const local_int_t currentLocalRow = zk*nx*ny + yk*nx + xk;
                // hash_color * nx*ny*nz + row, evaluated in unsigned arithmetic:
                // in local_int_t this product can overflow (UB) for large local grids.
                const local_uint_t hash_value = static_cast<local_uint_t>(hash_color) * grid_size
                                              + static_cast<local_uint_t>(currentLocalRow);
                weight_dev[currentLocalRow] = hash_value;
            });
    });

    // allocate colors and set data to UNSET to start with
    local_int_t *colors_dev = (local_int_t *)sparse_malloc_device(sizeof(local_int_t)*(nrow), queue);
    auto ev_fill_colors = queue.fill(colors_dev, unset, nrow, dependencies);

    // Zero the first max_colors+1 color counters of the caller-provided array.
    // NOTE(review): passes with c_id+2 > max_colors would read counters that
    // were not zeroed here; the allocation size of xcolors_dev is not visible.
    const local_int_t max_colors = std::min<local_int_t>(nrow, 100);
    auto ev_fill_xcolors = queue.fill(xcolors_dev, static_cast<local_int_t>(0), max_colors+1, dependencies);

    local_int_t total_rows_colored = 0;

    // loop through colors, two per pass: c_id (minima) and c_id+1 (maxima)
    sycl::event ev_do_c = ev_fill_weight;
    local_int_t c_id = 0;
    local_int_t zero_set = 0;
    while (c_id < nrow) {
        ev_do_c = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_do_c);
            cgh.depends_on(ev_fill_colors);
            cgh.depends_on(ev_fill_xcolors);
            const local_int_t BLOCK_SIZE = 4; // rows handled by each work-item
            const local_int_t nBlocks = ceil_div(nrow, BLOCK_SIZE);

            auto kernel = [=](sycl::item<1> item) {
                const local_int_t block = item.get_id(0);
                local_int_t row = block * BLOCK_SIZE;
                const local_int_t row_en = sycl::min(row + BLOCK_SIZE, nrow);

                local_int_t min_colored = 0;
                local_int_t max_colored = 0;
                for (; row < row_en; ++row) {
                    if (colors_dev[row] != unset) continue; // colored in an earlier pass

                    bool write_max_color = true; // row's weight is above all considered neighbors'
                    bool write_min_color = true; // row's weight is below all considered neighbors'

                    const local_uint_t row_val = weight_dev[row]; // get value for row to compare to neighbors

                    // loop through neighbors to find independent set
                    for (int k = 0; k < nonzerosInRow[row]; ++k) {
                        const local_int_t neighbor_ind = mtxIndL[row][k];
                        if (neighbor_ind >= nrow) continue; // not a local row

                        const local_int_t neighbor_color = colors_dev[neighbor_ind];

                        // exclude from neighbor set, the nodes previously colored and yourself
                        const bool is_self = (neighbor_ind == row);
                        const bool is_colored_before = (neighbor_color != unset) && (neighbor_color != c_id) && (neighbor_color != (c_id+1));
                        if (is_self || is_colored_before) continue;

                        const local_uint_t neighbor_val = weight_dev[neighbor_ind]; // get neighbor value

                        // decide if we are in either of maximal sets
                        if (neighbor_val <= row_val) write_min_color = false;
                        if (neighbor_val >= row_val) write_max_color = false;

                        if (!write_min_color && !write_max_color) {
                            break; // early exit since this row is not part of this set of independent sets
                        }
                    }

                    // write if is in an independent set
                    if (write_min_color) {
                        colors_dev[row] = c_id; // set color for row
                        min_colored++;
                    }
                    else if (write_max_color) {
                        colors_dev[row] = c_id+1; // set color for row
                        max_colored++;
                    }
                }

                if (min_colored > 0) {
                    device_atomic_ref_relaxed<local_int_t> xcolor_ref(xcolors_dev[c_id+1]); // count of rows of color c_id
                    xcolor_ref += min_colored;
                }

                if (max_colored > 0) {
                    device_atomic_ref_relaxed<local_int_t> xcolor_ref(xcolors_dev[(c_id+1)+1]); // count of rows of color c_id+1
                    xcolor_ref += max_colored;
                }
            };
            cgh.parallel_for<class PermuteGraph_colorize>(sycl::range<1>(nBlocks), kernel);
        });

        // read back the counts of colors c_id and c_id+1
        queue.memcpy(count_rows_colored_host, xcolors_dev+c_id+1, 2*sizeof(local_int_t), {ev_do_c}).wait();

        total_rows_colored += count_rows_colored_host[0] + count_rows_colored_host[1]; // keep running total of rows assigned a color

#ifdef HPCG_DEBUG
        HPCG_fout << "rank: " << rank << " xcolors[" << c_id << "+1] = " << count_rows_colored_host[0] << ", xcolors[" << c_id+1 << "+1] = " << count_rows_colored_host[1]  << std::endl;
#endif
        if (total_rows_colored >= nrow) break;

        if (count_rows_colored_host[1] == 0) {
            zero_set++;
            if (count_rows_colored_host[0] == 0) zero_set++;
        }

        if (zero_set > 3) {
#ifdef HPCG_DEBUG
            HPCG_fout << "we had three iterations with 0 colors added to independent set" << std::endl;
            HPCG_fout << total_rows_colored << " rows out of " << nrow << " rows have been colored" << std::endl;
#endif
            ev_do_c = fill_colors_backup(queue, nrow, c_id, colors_dev, xcolors_dev, count_rows_colored_host, {ev_do_c});
            break; // exit from while loop
        }

        // haven't finished yet, so move forward two colors to next available color
        c_id += 2;

    } // while c_id < nrow

    if (zero_set > 3) {
        nColors = c_id+1;
    }
    else {
        if (count_rows_colored_host[1] == 0) {
            nColors = std::min(c_id + 1, nrow);
        }
        else {
            nColors = std::min(c_id + 2, nrow);
        }
    }

    auto ev_fill_perm = convert_colors_to_perm(queue, nrow, nColors, colors_dev, xcolors_dev, perm, invperm, {ev_do_c});

    ev_fill_perm.wait();
#ifdef HPCG_DEBUG
    std::cout << "PermuteGraph rearranged " << nrow << " rows into " << nColors << " colors" << std::endl;
#endif

//    printPerm(queue, rank, nrow, perm, invperm, colors_dev, weight_dev, {ev_fill_perm}).wait();
//    printGraph(queue, rank, nrow, nonzerosInRow, mtxIndL, matrixValues, perm, invperm, colors_dev, weight_dev, {ev_fill_perm}).wait();

    sycl::free(count_rows_colored_host, queue);
    sycl::free(weight_dev, queue);
    sycl::free(colors_dev, queue);

    return ev_fill_perm;
}





} // anonymous namespace



//
// Entry point: colors the rows of the local matrix and fills perm, invperm,
// xcolors_dev and nColors from that coloring. The parity-only scheme
// colorize_graph_direct is an alternative that is not used here.
//
sycl::event PermuteGraph(sycl::queue &queue, const int rank,
                         local_int_t nx, local_int_t ny, local_int_t nz,
                         const local_int_t nrow, char *nonzerosInRow,
                         local_int_t **mtxIndL, double **matrixValues,
                         local_int_t *perm, local_int_t *invperm,
                         local_int_t &nColors, local_int_t *xcolors_dev,
                         const std::vector<sycl::event> dependencies)
{
    // Alternative:
    // return colorize_graph_direct(queue, rank, nx, ny, nz, nrow, nonzerosInRow, mtxIndL, matrixValues,
    //                              perm, invperm, nColors, xcolors_dev, dependencies);
    return colorize_graph(queue, rank, nx, ny, nz,
                          nrow, nonzerosInRow, mtxIndL, matrixValues,
                          perm, invperm, nColors, xcolors_dev,
                          dependencies);
}
