#include "cl_helpers.h"
#include <mutex>
#include <complex>
#include <cmath>
#include <vector>
#include <iostream>
#include <iterator>
#include <algorithm>
// Width and height handed to the forge::Window constructor in main().
const unsigned DIMX = 1000;
const unsigned DIMY = 800;

// Range of the curve parameter z: the kernel evaluates z = ZMIN + i * DX.
static const float ZMIN = 0.1f;
static const float ZMAX = 10.f;

// Step between two consecutive samples of z.
const float DX = 0.005;

// Number of samples (and work-items) covering [ZMIN, ZMAX] with step DX.
// The float result is truncated on conversion to unsigned.
static const unsigned ZSIZE = static_cast<unsigned>((ZMAX - ZMIN) / DX + 1);

using namespace std;

// NOTE(review): nothing visible in this file consumes this macro; presumably it
// is meant to precede an #include of a Forge compute-copy header -- confirm.
#define USE_FORGE_OPENCL_COPY_HELPERS
// OpenCL C source of the "generateCurve" kernel.
//
// Work-item i (for i < SIZE) computes z = zmin + i*dx and stores one point as
// three consecutive floats in `out`:
//   out[3*i + 0] = cos(z*t + t) / z
//   out[3*i + 1] = sin(z*t + t) / z
//   out[3*i + 2] = z + 0.1*sin(t)
// Work-items with i >= SIZE write nothing.
//
// NOTE(review): `0.1` carries no `f` suffix, unlike the rest of the float
// math here -- check whether double-precision evaluation is intended.
static const std::string sincos_surf_kernel = R"CLC(kernel void generateCurve(global float* out, const float t, const float dx, const float zmin, const unsigned SIZE)
{
    int offset = get_global_id(0);

    float z = zmin + offset*dx;
    if (offset < SIZE) {
       out[offset*3 + 0] = cos(z*t+t)/z;
       out[offset*3 + 1] = sin(z*t+t)/z;
       out[offset*3 + 2] = z + 0.1*sin(t);
    }
}
)CLC";
/**
 * Divide a by b, rounding the quotient up.
 *
 * Intended for non-negative a and positive b (e.g. "how many groups of size b
 * are needed to cover a items"). For other sign combinations the result is
 * whatever truncating division of (a + b - 1) by b yields, exactly as before.
 *
 * @param a  dividend
 * @param b  divisor; must not be 0
 * @return   (a + b - 1) / b, computed without intermediate int overflow
 */
inline int divup(int a, int b)
{
    // Widen before adding: a + b - 1 can exceed INT_MAX even when the final
    // quotient fits in an int (e.g. divup(INT_MAX, 2)), which would be
    // undefined behaviour if evaluated in int.
    const long long numerator = static_cast<long long>(a) + b - 1;
    return static_cast<int>(numerator / b);
}
/**
 * Enqueue one run of the "generateCurve" kernel that fills devOut for time t.
 *
 * On the first call the program in sincos_surf_kernel is built against the
 * context that `queue` belongs to, and the resulting kernel object is kept in
 * function-local statics; later calls reuse it. The function only enqueues the
 * launch -- it does not itself wait for the kernel to finish.
 *
 * @param devOut  device buffer receiving the generated points
 * @param queue   command queue the kernel is enqueued on
 * @param t       time parameter forwarded to the kernel
 */
void kernel(cl::Buffer& devOut, cl::CommandQueue& queue, float t)
{
    static std::once_flag buildOnce;
    static cl::Program    program;
    static cl::Kernel     curveKernel;

    // Build the program and look up the kernel exactly once.
    std::call_once(buildOnce, [queue]() {
        program     = cl::Program(queue.getInfo<CL_QUEUE_CONTEXT>(), sincos_surf_kernel, true);
        curveKernel = cl::Kernel(program, "generateCurve");
    });

    // Bind arguments in the order of the kernel signature:
    // (out, t, dx, zmin, SIZE).
    curveKernel.setArg(0, devOut);
    curveKernel.setArg(1, t);
    curveKernel.setArg(2, DX);
    curveKernel.setArg(3, ZMIN);
    curveKernel.setArg(4, ZSIZE);

    // One work-item per sample; no explicit offset.
    NDRange launchSize(ZSIZE);
    queue.enqueueNDRangeKernel(curveKernel, cl::NullRange, launchSize);
}
/**
 * Entry point of the three-dimensional line plot demo.
 *
 * Creates a DIMX x DIMY forge::Window, obtains an OpenCL context for it via
 * createCLGLContext(), creates a command queue on the context's first device,
 * allocates a device buffer of ZSIZE * 3 floats and then calls kernel() with
 * an increasing time value t inside a `do` loop. A cl::Error is reported on
 * stdout as "what(err)". Returns 0 after the try/catch.
 *
 * NOTE(review): as it stands this function is not well-formed -- see the
 * notes at the `do` loop below. It looks as if several lines were lost.
 */
int main(void)
{
    try {
        
        forge::Window wnd(DIMX, DIMY, 
"Three dimensional line plot demo");
         
        // `context` and `queue` are not declared in this function; presumably
        // they are globals provided by cl_helpers.h -- TODO confirm.
        context = createCLGLContext(wnd);
        Device device = context.getInfo<CL_CONTEXT_DEVICES>()[0];
        queue = CommandQueue(context, device);
        // Three floats per sample, matching the out[offset*3 + 0..2] writes
        // performed by the generateCurve kernel.
        cl::Buffer devOut(context, CL_MEM_READ_WRITE, sizeof(float) * ZSIZE * 3);
        static float t=0;
        // Generate the curve once for t = 0 before entering the loop.
        kernel(devOut, queue, t);
        
        // NOTE(review): this `do {` has no matching `} while (...);` anywhere
        // in the function, so the loop has no visible exit condition.
        do {
            // Advance the time parameter and regenerate the curve.
            t+=0.01;
            kernel(devOut, queue, t);
        // NOTE(review): `handle` is not declared anywhere in this function.
        releaseGLBuffer(handle);
        // NOTE(review): `err` is used here before any catch clause introduces
        // it; presumably this statement was the body of an earlier
        // `catch (...) {` whose header is missing -- verify against upstream.
        std::cout << err.
what() << 
"(" << err.
err() << 
")" << std::endl;
    } catch (cl::Error err) {
        // Report OpenCL failures as "<what>(<error code>)".
        std::cout << err.what() << "(" << err.err() << ")" << std::endl;
    }
    return 0;
}