#include <cmath>
#include <cstdio>
#include <iostream>
#include <vector>
#include "gravity_sim_init.h"
using namespace std;
// ---- Compile-time simulation parameters ----------------------------------
static constexpr bool is3D = true;              // simulate/draw in 3 dimensions
static constexpr int total_particles = 4000;    // number of bodies
static constexpr int reset = 3000;              // frames between re-initialisations
static constexpr float min_dist = 3.f;
static constexpr int width = 768;               // extent of the simulated volume / window
static constexpr int height = 768;
static constexpr int depth = 768;
static constexpr float eps = 10.f;
static constexpr int gravity_constant = 20000;

// ---- Mutable mass statistics ----------------------------------------------
// Written by initial_conditions_galaxy(); read when choosing draw thresholds.
float mass_range = 0.f;
float min_mass = 0.f;
// Scatters the particles with normally distributed positions (mean `width`,
// std-dev `width` on every axis) and starts them at rest.
//
// mass, forces: not modified here.
// NOTE(review): `mass`, `min_mass` and `mass_range` are left untouched, but
// simulate() multiplies by `mass` — callers must initialise it some other way
// before simulating. TODO confirm intended mass distribution.
void initial_conditions_rand(
af::array &mass, vector<af::array> &pos, vector<af::array> &vels, vector<af::array> &forces) {
    for (size_t axis = 0; axis < pos.size(); ++axis) {
        pos[axis]  = af::randn(total_particles) * width + width;
        // Parenthesised like initial_conditions_galaxy: without the parentheses
        // the expression was (0 * randu) - 0.5, i.e. a constant -0.5 drift on
        // every axis instead of zero initial velocity.
        vels[axis] = 0 * (af::randu(total_particles) - 0.5);
    }
}
// Loads particle state from the `initial_cond_consts` table.
//
// After transposing, each column of the table is one attribute:
//   0 -> mass, 1..3 -> x/y/z position, 4..6 -> x/y/z velocity,
// each scaled by 1/32 and then by the volume extent (positions also offset).
// Updates the globals `min_mass` and `mass_range` from the new masses.
//
// forces: not modified here.
void initial_conditions_galaxy(
af::array &mass, vector<af::array> &pos, vector<af::array> &vels, vector<af::array> &forces) {
    // Work on a transposed local copy. The previous in-place
    // `initial_cond_consts = initial_cond_consts.T()` flipped the shared table's
    // orientation on every call, so the call after a reset indexed the wrong axis.
    af::array consts = initial_cond_consts.T();

    // Defaults for every axis; the axes covered by the table are overwritten below.
    for (size_t axis = 0; axis < pos.size(); ++axis) {
        pos[axis]  = af::randn(total_particles) * width + width;
        vels[axis] = 0 * (af::randu(total_particles) - 0.5);
    }

    mass    =  consts(af::span, 0);
    pos[0]  = (consts(af::span, 1) / 32 + 0.6) * width;
    pos[1]  = (consts(af::span, 2) / 32 + 0.3) * height;
    vels[0] = (consts(af::span, 4) / 32) * width;
    vels[1] = (consts(af::span, 5) / 32) * height;
    // Only touch the z axis when it exists; indexing element 2 of a 2-element
    // vector is undefined behaviour.
    if (pos.size() > 2 && vels.size() > 2) {
        pos[2]  = (consts(af::span, 3) / 32 + 0.5) * depth;
        vels[2] = (consts(af::span, 6) / 32) * depth;
    }

    // Shift every other particle (even indices) in x/y and give it extra x
    // velocity — presumably to split the set into two interacting groups; intent
    // not stated in the code.
    const af::seq even_particles(0, pos[0].dims(0) - 1, 2);
    pos[0](even_particles)  -= 0.4 * width;
    pos[1](even_particles)  += 0.4 * height;
    vels[0](even_particles) += 4;

    const float lightest = af::min<float>(mass);
    min_mass   = lightest;
    mass_range = af::max<float>(mass) - lightest;
}
// Maps 2D particle positions to linear pixel indices of the form
// x * height + y, truncating each coordinate to an unsigned integer first.
af::array ids_from_pos(vector<af::array> &pos) {
    const af::array column = pos[0].as(u32);
    const af::array row    = pos[1].as(u32);
    return (column * height) + row;
}
// Rotates the particle cloud about the centre of the volume and projects it
// onto the x/y screen plane, returning linear pixel indices x * height + y
// (the same layout ids_from_pos produces).
//
// Rx, Ry, Rz: angles in radians about the x, y and z axes, applied in that order.
// filter:     optional boolean mask over the particles; when non-empty only the
//             particles where it is true are projected. Defaults to all particles,
//             so existing 4-argument callers are unaffected.
//
// Requires pos.size() >= 3.
af::array ids_from_3D(vector<af::array> &pos,
float Rx,
float Ry,
float Rz,
const af::array &filter = af::array()) {
    const bool use_filter = !filter.isempty();
    const af::array keep = use_filter ? af::where(filter) : af::array();
    auto select = [&](const af::array &coords) -> af::array {
        return use_filter ? af::lookup(coords, keep) : coords;
    };

    // Coordinates relative to the centre of the volume so rotations pivot there.
    const af::array xc = select(pos[0]) - width / 2;
    const af::array yc = select(pos[1]) - height / 2;
    const af::array zc = select(pos[2]) - depth / 2;

    const float cos_x = std::cos(Rx), sin_x = std::sin(Rx);
    const float cos_y = std::cos(Ry), sin_y = std::sin(Ry);
    const float cos_z = std::cos(Rz), sin_z = std::sin(Rz);

    // Rotation about x. The z term must mix in *y*; the previous version used
    // the z coordinate twice, which is not a rotation.
    const af::array y0 = yc * cos_x + zc * sin_x;
    const af::array z0 = zc * cos_x - yc * sin_x;

    // Rotation about y (y unchanged).
    const af::array x1 = xc * cos_y - z0 * sin_y;
    const af::array y1 = y0;

    // Rotation about z.
    af::array x2 = x1 * cos_z + y1 * sin_z;
    af::array y2 = y1 * cos_z - x1 * sin_z;

    // Back to screen coordinates.
    x2 += width / 2;
    y2 += height / 2;

    // A rotated cube can poke outside the screen; clamp before the u32 cast so
    // negative or oversized coordinates cannot produce out-of-range indices.
    x2 = af::min(af::max(x2, 0.0), width - 1.0);
    y2 = af::min(af::max(y2, 0.0), height - 1.0);

    return (x2.as(u32) * height) + y2.as(u32);
}
// Advances the system by one explicit-Euler step of length dt.
//
// Positions are moved by the current velocities, then each particle c gets an
// acceleration of gravity_constant * sum_r m_r * (p_r - p_c) / |p_r - p_c|^3,
// which is added to its velocity.
//
// mass:   per-particle masses (used as the column vector in the matmul).
// forces: scratch output; on return forces[i] holds the i-th acceleration
//         component of every particle.
void simulate(
af::array &mass, vector<af::array> &pos, vector<af::array> &vels, vector<af::array> &forces,
float dt) {
    for (size_t axis = 0; axis < pos.size(); ++axis) {
        pos[axis] += vels[axis] * dt;
    }

    const dim_t n = pos[0].dims(0);

    // Pairwise displacements: diff[axis](r, c) = p_r - p_c along that axis.
    vector<af::array> diff(pos.size());
    af::array dist = af::constant(0, n, n, pos[0].type());  // was used undeclared
    for (size_t axis = 0; axis < pos.size(); ++axis) {
        const af::array columns = af::tile(pos[axis], 1, static_cast<unsigned>(n));
        diff[axis] = columns - af::transpose(columns);
        dist += diff[axis] * diff[axis];
    }

    // |r|, clamped to min_dist so that the self-interaction on the diagonal
    // (0 / 0 -> NaN) and very close pairs cannot blow up, then cubed. The old
    // code cubed the *squared* distance (|r|^6) and had no clamp.
    dist = af::sqrt(dist);
    dist = af::max(dist, min_dist);
    dist *= dist * dist;

    for (size_t axis = 0; axis < pos.size(); ++axis) {
        forces[axis] = diff[axis] / dist;
        forces[axis].eval();

        // Column c of the transpose dotted with mass sums the mass-weighted pull
        // of every particle r on particle c.
        forces[axis] = af::matmul(forces[axis].T(), mass);
        forces[axis] *= gravity_constant;
        forces[axis].eval();

        vels[axis] += forces[axis] * dt;
        vels[axis].eval();
    }
}
// Keeps particles inside the simulated volume. Along each axis, any particle
// outside [0, extent - 1] has that velocity component negated and its position
// clamped back onto the boundary.
//
// is3D: when true the z axis (pos[2]/vels[2], extent `depth`) is handled too.
void collisions(vector<af::array> &pos, vector<af::array> &vels, bool is3D) {
    const auto bounce = [&pos, &vels](size_t axis, int extent) {
        const double upper = extent - 1;
        // 1 where the particle is out of bounds, 0 otherwise — converted to the
        // velocity's type explicitly instead of relying on bool * scalar promotion.
        const af::array outside =
            (pos[axis] > upper || pos[axis] < 0.0).as(vels[axis].type());
        vels[axis] *= 1 - 2 * outside;  // -1 flips, +1 keeps
        // The clamped positions were referenced but never computed before.
        pos[axis] = af::min(af::max(pos[axis], 0.0), upper);
    };

    bounce(0, width);
    bounce(1, height);
    if (is3D) {
        bounce(2, depth);
    }
}
int main(int argc, char *argv[])
{
    try {
        af::Window myWindow(width, height, 
"Gravity Simulation using ArrayFire");
         int frame_count = 0;
        
        const int dims = (is3D)? 3 : 2;
        vector<af::array> pos(dims);
        vector<af::array> vels(dims);
        vector<af::array> forces(dims);
        
        initial_conditions_galaxy(mass, pos, vels, forces);
        while(!myWindow.
close()) {
             ids = (is3D)? ids_from_3D(pos, 0, 0+frame_count/150.f, 0, mid) :  ids_from_pos(pos);
            
            image(ids) += 4.f;
            mid = mass(
span) > (min_mass + 2*mass_range/3);
            ids = (is3D)? ids_from_3D(pos, 0, 0+frame_count/150.f, 0, mid) :  ids_from_pos(pos);
            
            image(ids) += 4.f;
            ids = (is3D)? ids_from_3D(pos, 0, 0+frame_count/150.f, 0) :  ids_from_pos(pos);
            
            image(ids) += 4.f;
            frame_count++;
            
            if(frame_count % reset == 0) {
                initial_conditions_galaxy(mass, pos, vels, forces);
            }
            
            simulate(mass, pos, vels, forces, dt);
            
            collisions(pos, vels, is3D);
        }
        fprintf(stderr, 
"%s\n", e.
what());
        throw;
    }
    return 0;
}