#include "bifrost_nir.h"

#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following 18
 * transforms:
 *    ('fmul', 'a', 2.0) => ('fadd', 'a', 'a')
 *    ('fmin', ('fmax', 'a', -1.0), 1.0) => ('fsat_signed_mali', 'a')
 *    ('fmax', ('fmin', 'a', 1.0), -1.0) => ('fsat_signed_mali', 'a')
 *    ('fmax', 'a', 0.0) => ('fclamp_pos_mali', 'a')
 *    ('fabs', ('fddx', 'a')) => ('fabs', ('fddx_must_abs_mali', 'a'))
 *    ('fabs', ('fddy', 'b')) => ('fabs', ('fddy_must_abs_mali', 'b'))
 *    ('b2f16', ('inot', 'a@8')) => ('b8csel', 'a', 0.0, 1.0)
 *    ('b2f16', 'a@8') => ('b8csel', 'a', 1.0, 0.0)
 *    ('b2f32', ('inot', 'a@8')) => ('b8csel', 'a', 0.0, 1.0)
 *    ('b2f32', 'a@8') => ('b8csel', 'a', 1.0, 0.0)
 *    ('b2f16', ('inot', 'a@16')) => ('b16csel', 'a', 0.0, 1.0)
 *    ('b2f16', 'a@16') => ('b16csel', 'a', 1.0, 0.0)
 *    ('b2f32', ('inot', 'a@16')) => ('b16csel', 'a', 0.0, 1.0)
 *    ('b2f32', 'a@16') => ('b16csel', 'a', 1.0, 0.0)
 *    ('b2f16', ('inot', 'a@32')) => ('b32csel', 'a', 0.0, 1.0)
 *    ('b2f16', 'a@32') => ('b32csel', 'a', 1.0, 0.0)
 *    ('b2f32', ('inot', 'a@32')) => ('b32csel', 'a', 0.0, 1.0)
 *    ('b2f32', 'a@32') => ('b32csel', 'a', 1.0, 0.0)
 */


/*
 * Flat pool of search and replace values shared by every transform in this
 * pass.
 *
 * Entries refer to one another by array index. An expression's source list
 * (e.g. `{ 0, 1 }` on the fmul at [2]) and the first two fields of each
 * bifrost_nir_lower_algebraic_late_transforms entry both index into this
 * array. A sub-pattern that appears in several rules is emitted once and then
 * reused; the "-> N in the cache" comments record each such reuse. The `[N]`
 * markers give each entry's index.
 *
 * The second member of each value header is the bit size the value must
 * have. The `a@8`/`a@16`/`a@32` variables carry 8/16/32, and the unsized
 * values carry -1 (presumably "unconstrained").
 *
 * Constants are stored as 64-bit double bit patterns even when the constant
 * itself is 16- or 32-bit (e.g. [27] is 16-bit 1.0 = 0x3ff0000000000000).
 *
 * NOTE(review): the positional fields after the header presumably follow
 * nir_search_variable / nir_search_constant / nir_search_expression in
 * nir_search.h. Confirm there before editing by hand. If this file is
 * generator output, change the generator input instead.
 */
static const nir_search_value_union bifrost_nir_lower_algebraic_late_values[] = {
   /* ('fmul', 'a', 2.0) => ('fadd', 'a', 'a') */
   /* [0] */
   { .variable = {
      { nir_search_value_variable, -1 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [1] */
   { .constant = {
      { nir_search_value_constant, -1 },
      nir_type_float, { 0x4000000000000000 /* 2.0 */ },
   } },
   /* [2] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fmul,
      0, 1,
      { 0, 1 },
      -1,
   } },

   /* replace0_0 -> 0 in the cache */
   /* replace0_1 -> 0 in the cache */
   /* [3] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fadd,
      -1, 0,
      { 0, 0 },
      -1,
   } },

   /* ('fmin', ('fmax', 'a', -1.0), 1.0) => ('fsat_signed_mali', 'a') */
   /* search1_0_0 -> 0 in the cache */
   /* [4] */
   { .constant = {
      { nir_search_value_constant, -1 },
      nir_type_float, { 0xbff0000000000000 /* -1.0 */ },
   } },
   /* [5] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fmax,
      1, 1,
      { 0, 4 },
      -1,
   } },
   /* [6] */
   { .constant = {
      { nir_search_value_constant, -1 },
      nir_type_float, { 0x3ff0000000000000 /* 1.0 */ },
   } },
   /* [7] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fmin,
      0, 2,
      { 5, 6 },
      -1,
   } },

   /* replace1_0 -> 0 in the cache */
   /* [8] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fsat_signed_mali,
      -1, 0,
      { 0 },
      -1,
   } },

   /* ('fmax', ('fmin', 'a', 1.0), -1.0) => ('fsat_signed_mali', 'a') */
   /* search2_0_0 -> 0 in the cache */
   /* search2_0_1 -> 6 in the cache */
   /* [9] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fmin,
      1, 1,
      { 0, 6 },
      -1,
   } },
   /* search2_1 -> 4 in the cache */
   /* [10] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fmax,
      0, 2,
      { 9, 4 },
      -1,
   } },

   /* replace2_0 -> 0 in the cache */
   /* replace2 -> 8 in the cache */

   /* ('fmax', 'a', 0.0) => ('fclamp_pos_mali', 'a') */
   /* search3_0 -> 0 in the cache */
   /* [11] */
   { .constant = {
      { nir_search_value_constant, -1 },
      nir_type_float, { 0x0 /* 0.0 */ },
   } },
   /* [12] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fmax,
      0, 1,
      { 0, 11 },
      -1,
   } },

   /* replace3_0 -> 0 in the cache */
   /* [13] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fclamp_pos_mali,
      -1, 0,
      { 0 },
      -1,
   } },

   /* ('fabs', ('fddx', 'a')) => ('fabs', ('fddx_must_abs_mali', 'a')) */
   /* search4_0_0 -> 0 in the cache */
   /* [14] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fddx,
      -1, 0,
      { 0 },
      -1,
   } },
   /* [15] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fabs,
      -1, 0,
      { 14 },
      -1,
   } },

   /* replace4_0_0 -> 0 in the cache */
   /* [16] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fddx_must_abs_mali,
      -1, 0,
      { 0 },
      -1,
   } },
   /* [17] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fabs,
      -1, 0,
      { 16 },
      -1,
   } },

   /* ('fabs', ('fddy', 'b')) => ('fabs', ('fddy_must_abs_mali', 'b')) */
   /* [18] 'b' is this rule's first (and only) variable, hence index 0. */
   { .variable = {
      { nir_search_value_variable, -1 },
      0, /* b */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [19] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fddy,
      -1, 0,
      { 18 },
      -1,
   } },
   /* [20] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fabs,
      -1, 0,
      { 19 },
      -1,
   } },

   /* replace5_0_0 -> 18 in the cache */
   /* [21] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fddy_must_abs_mali,
      -1, 0,
      { 18 },
      -1,
   } },
   /* [22] */
   { .expression = {
      { nir_search_value_expression, -1 },
      false,
      false,
      false,
      nir_op_fabs,
      -1, 0,
      { 21 },
      -1,
   } },

   /* ('b2f16', ('inot', 'a@8')) => ('b8csel', 'a', 0.0, 1.0) */
   /* [23] 8-bit 'a', shared by every a@8 rule below. */
   { .variable = {
      { nir_search_value_variable, 8 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [24] */
   { .expression = {
      { nir_search_value_expression, 8 },
      false,
      false,
      false,
      nir_op_inot,
      -1, 0,
      { 23 },
      -1,
   } },
   /* [25] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b2f16,
      -1, 0,
      { 24 },
      -1,
   } },

   /* replace6_0 -> 23 in the cache */
   /* [26] 16-bit 0.0, reused by every b2f16 replacement. */
   { .constant = {
      { nir_search_value_constant, 16 },
      nir_type_float, { 0x0 /* 0.0 */ },
   } },
   /* [27] 16-bit 1.0, reused by every b2f16 replacement. */
   { .constant = {
      { nir_search_value_constant, 16 },
      nir_type_float, { 0x3ff0000000000000 /* 1.0 */ },
   } },
   /* [28] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b8csel,
      -1, 0,
      { 23, 26, 27 },
      -1,
   } },

   /* ('b2f16', 'a@8') => ('b8csel', 'a', 1.0, 0.0) */
   /* search7_0 -> 23 in the cache */
   /* [29] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b2f16,
      -1, 0,
      { 23 },
      -1,
   } },

   /* replace7_0 -> 23 in the cache */
   /* replace7_1 -> 27 in the cache */
   /* replace7_2 -> 26 in the cache */
   /* [30] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b8csel,
      -1, 0,
      { 23, 27, 26 },
      -1,
   } },

   /* ('b2f32', ('inot', 'a@8')) => ('b8csel', 'a', 0.0, 1.0) */
   /* search8_0_0 -> 23 in the cache */
   /* search8_0 -> 24 in the cache */
   /* [31] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b2f32,
      -1, 0,
      { 24 },
      -1,
   } },

   /* replace8_0 -> 23 in the cache */
   /* [32] 32-bit 0.0, reused by every b2f32 replacement. */
   { .constant = {
      { nir_search_value_constant, 32 },
      nir_type_float, { 0x0 /* 0.0 */ },
   } },
   /* [33] 32-bit 1.0, reused by every b2f32 replacement. */
   { .constant = {
      { nir_search_value_constant, 32 },
      nir_type_float, { 0x3ff0000000000000 /* 1.0 */ },
   } },
   /* [34] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b8csel,
      -1, 0,
      { 23, 32, 33 },
      -1,
   } },

   /* ('b2f32', 'a@8') => ('b8csel', 'a', 1.0, 0.0) */
   /* search9_0 -> 23 in the cache */
   /* [35] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b2f32,
      -1, 0,
      { 23 },
      -1,
   } },

   /* replace9_0 -> 23 in the cache */
   /* replace9_1 -> 33 in the cache */
   /* replace9_2 -> 32 in the cache */
   /* [36] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b8csel,
      -1, 0,
      { 23, 33, 32 },
      -1,
   } },

   /* ('b2f16', ('inot', 'a@16')) => ('b16csel', 'a', 0.0, 1.0) */
   /* [37] 16-bit 'a', shared by every a@16 rule below. */
   { .variable = {
      { nir_search_value_variable, 16 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [38] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_inot,
      -1, 0,
      { 37 },
      -1,
   } },
   /* [39] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b2f16,
      -1, 0,
      { 38 },
      -1,
   } },

   /* replace10_0 -> 37 in the cache */
   /* replace10_1 -> 26 in the cache */
   /* replace10_2 -> 27 in the cache */
   /* [40] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b16csel,
      -1, 0,
      { 37, 26, 27 },
      -1,
   } },

   /* ('b2f16', 'a@16') => ('b16csel', 'a', 1.0, 0.0) */
   /* search11_0 -> 37 in the cache */
   /* [41] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b2f16,
      -1, 0,
      { 37 },
      -1,
   } },

   /* replace11_0 -> 37 in the cache */
   /* replace11_1 -> 27 in the cache */
   /* replace11_2 -> 26 in the cache */
   /* [42] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b16csel,
      -1, 0,
      { 37, 27, 26 },
      -1,
   } },

   /* ('b2f32', ('inot', 'a@16')) => ('b16csel', 'a', 0.0, 1.0) */
   /* search12_0_0 -> 37 in the cache */
   /* search12_0 -> 38 in the cache */
   /* [43] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b2f32,
      -1, 0,
      { 38 },
      -1,
   } },

   /* replace12_0 -> 37 in the cache */
   /* replace12_1 -> 32 in the cache */
   /* replace12_2 -> 33 in the cache */
   /* [44] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b16csel,
      -1, 0,
      { 37, 32, 33 },
      -1,
   } },

   /* ('b2f32', 'a@16') => ('b16csel', 'a', 1.0, 0.0) */
   /* search13_0 -> 37 in the cache */
   /* [45] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b2f32,
      -1, 0,
      { 37 },
      -1,
   } },

   /* replace13_0 -> 37 in the cache */
   /* replace13_1 -> 33 in the cache */
   /* replace13_2 -> 32 in the cache */
   /* [46] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b16csel,
      -1, 0,
      { 37, 33, 32 },
      -1,
   } },

   /* ('b2f16', ('inot', 'a@32')) => ('b32csel', 'a', 0.0, 1.0) */
   /* [47] 32-bit 'a', shared by every a@32 rule below. */
   { .variable = {
      { nir_search_value_variable, 32 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [48] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_inot,
      -1, 0,
      { 47 },
      -1,
   } },
   /* [49] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b2f16,
      -1, 0,
      { 48 },
      -1,
   } },

   /* replace14_0 -> 47 in the cache */
   /* replace14_1 -> 26 in the cache */
   /* replace14_2 -> 27 in the cache */
   /* [50] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b32csel,
      -1, 0,
      { 47, 26, 27 },
      -1,
   } },

   /* ('b2f16', 'a@32') => ('b32csel', 'a', 1.0, 0.0) */
   /* search15_0 -> 47 in the cache */
   /* [51] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b2f16,
      -1, 0,
      { 47 },
      -1,
   } },

   /* replace15_0 -> 47 in the cache */
   /* replace15_1 -> 27 in the cache */
   /* replace15_2 -> 26 in the cache */
   /* [52] */
   { .expression = {
      { nir_search_value_expression, 16 },
      false,
      false,
      false,
      nir_op_b32csel,
      -1, 0,
      { 47, 27, 26 },
      -1,
   } },

   /* ('b2f32', ('inot', 'a@32')) => ('b32csel', 'a', 0.0, 1.0) */
   /* search16_0_0 -> 47 in the cache */
   /* search16_0 -> 48 in the cache */
   /* [53] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b2f32,
      -1, 0,
      { 48 },
      -1,
   } },

   /* replace16_0 -> 47 in the cache */
   /* replace16_1 -> 32 in the cache */
   /* replace16_2 -> 33 in the cache */
   /* [54] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b32csel,
      -1, 0,
      { 47, 32, 33 },
      -1,
   } },

   /* ('b2f32', 'a@32') => ('b32csel', 'a', 1.0, 0.0) */
   /* search17_0 -> 47 in the cache */
   /* [55] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b2f32,
      -1, 0,
      { 47 },
      -1,
   } },

   /* replace17_0 -> 47 in the cache */
   /* replace17_1 -> 33 in the cache */
   /* replace17_2 -> 32 in the cache */
   /* [56] */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      nir_op_b32csel,
      -1, 0,
      { 47, 33, 32 },
      -1,
   } },

};



/*
 * Transform list, grouped per automaton state; each group ends with a
 * sentinel.
 *
 * Each entry is { search, replace, condition }. The first two fields index
 * bifrost_nir_lower_algebraic_late_values. For example, { 2, 3, 0 } pairs the
 * fmul(a, 2.0) search expression at [2] with the fadd(a, a) replacement at
 * [3].
 * NOTE(review): the third field appears to index condition_flags. Every entry
 * here uses 0, which bifrost_nir_lower_algebraic_late() sets to true. Confirm
 * against struct transform.
 *
 * bifrost_nir_lower_algebraic_late_transform_offsets maps a state to the
 * index of the first entry in its group. Index 0 is a bare sentinel, so an
 * offset of 0 means "no transforms". A rule can appear in several groups
 * (e.g. { 12, 13, 0 } in both states 4 and 9) because one instruction may
 * satisfy several patterns.
 */
static const struct transform bifrost_nir_lower_algebraic_late_transforms[] = {
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [1] state 2: fmul(a, 2.0) */
   { 2, 3, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [3] state 4: fmax(a, 0.0) */
   { 12, 13, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [5] state 7: b2f16/b2f32 of a@8, a@16 and a@32 */
   { 29, 30, 0 },
   { 35, 36, 0 },
   { 41, 42, 0 },
   { 45, 46, 0 },
   { 51, 52, 0 },
   { 55, 56, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [12] state 9: fmax(fmin(a, 1.0), -1.0), plus fmax(a, 0.0) */
   { 10, 8, 0 },
   { 12, 13, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [15] state 10: fmin(fmax(a, -1.0), 1.0) */
   { 7, 8, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [17] state 11: fabs(fddx(a)) */
   { 15, 17, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [19] state 12: fabs(fddy(b)) */
   { 20, 22, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* [21] state 13: b2f(inot(a@N)) rules interleaved with the plain b2f(a@N)
    * rules, since the inot result is itself a valid 'a@N' source. */
   { 25, 28, 0 },
   { 29, 30, 0 },
   { 31, 34, 0 },
   { 35, 36, 0 },
   { 39, 40, 0 },
   { 41, 42, 0 },
   { 43, 44, 0 },
   { 45, 46, 0 },
   { 49, 50, 0 },
   { 51, 52, 0 },
   { 53, 54, 0 },
   { 55, 56, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

};

/*
 * Per-opcode transitions of the matching automaton (states 0..13).
 *
 * For each opcode, `filter` has one entry per automaton state. It collapses
 * a source's state into one of `num_filtered_states` classes. `table` then
 * maps the combination of source classes to the state given to the
 * instruction itself. fddx, fddy and inot have a NULL filter and a
 * single-entry table, so they always produce that one state.
 *
 * NOTE(review): the combination index is presumably the source classes
 * combined in itertools.product order (src0 most significant). Confirm
 * against nir_algebraic_automaton() in nir_search.c. The two-source tables
 * below are symmetric, so source order does not affect the result here.
 *
 * No table here ever produces state 1. It is presumably the state given to
 * constants (CONST_STATE in nir_search.c; confirm). That would explain why
 * the fmul/fmin/fmax filters single it out.
 *
 * Opcodes not listed are left zero-initialised by the designated
 * initializer.
 */
static const struct per_op_table bifrost_nir_lower_algebraic_late_pass_op_table[nir_num_search_ops] = {
   /* Class 1 = state 1. An fmul with at least one state-1 source becomes
    * state 2, which carries the fmul(a, 2.0) rule. */
   [nir_op_fmul] = {
      .filter = (const uint16_t []) {
         0,
         1,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
      },
      
      .num_filtered_states = 2,
      .table = (const uint16_t []) {
      
         0,
         2,
         2,
         2,
      },
   },
   /* Class 1 = state 1; class 2 = states 4 and 9 (the fmax-with-state-1
    * states). One state-1 source with the other in class 2 gives state 10
    * (the fmin(fmax(a, -1.0), 1.0) candidate). Any other fmin with a
    * state-1 source gives state 3, which has no rules of its own but feeds
    * the fmax filter. */
   [nir_op_fmin] = {
      .filter = (const uint16_t []) {
         0,
         1,
         0,
         0,
         2,
         0,
         0,
         0,
         0,
         2,
         0,
         0,
         0,
         0,
      },
      
      .num_filtered_states = 3,
      .table = (const uint16_t []) {
      
         0,
         3,
         0,
         3,
         3,
         10,
         0,
         10,
         0,
      },
   },
   /* Mirror of fmin. Class 1 = state 1; class 2 = states 3 and 10 (the
    * fmin-with-state-1 states). One state-1 source with the other in class 2
    * gives state 9 (the fmax(fmin(a, 1.0), -1.0) candidate, which also keeps
    * fmax(a, 0.0)). Any other fmax with a state-1 source gives state 4 (the
    * fmax(a, 0.0) candidate). */
   [nir_op_fmax] = {
      .filter = (const uint16_t []) {
         0,
         1,
         0,
         2,
         0,
         0,
         0,
         0,
         0,
         0,
         2,
         0,
         0,
         0,
      },
      
      .num_filtered_states = 3,
      .table = (const uint16_t []) {
      
         0,
         4,
         0,
         4,
         4,
         9,
         0,
         9,
         0,
      },
   },
   /* fabs of an fddx result (state 5) gives state 11; fabs of an fddy
    * result (state 6) gives state 12; anything else gives state 0. */
   [nir_op_fabs] = {
      .filter = (const uint16_t []) {
         0,
         0,
         0,
         0,
         0,
         1,
         2,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
      },
      
      .num_filtered_states = 3,
      .table = (const uint16_t []) {
      
         0,
         11,
         12,
      },
   },
   /* Every fddx is state 5. It only matters as an fabs source. */
   [nir_op_fddx] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         5,
      },
   },
   /* Every fddy is state 6. It only matters as an fabs source. */
   [nir_op_fddy] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         6,
      },
   },
   /* Keyed by the size-generic nir_search_op_b2f: the rules reached from
    * states 7 and 13 cover both b2f16 and b2f32. A b2f whose source is an
    * inot (state 8) gives state 13; any other b2f gives state 7. */
   [nir_search_op_b2f] = {
      .filter = (const uint16_t []) {
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         1,
         0,
         0,
         0,
         0,
         0,
      },
      
      .num_filtered_states = 2,
      .table = (const uint16_t []) {
      
         7,
         13,
      },
   },
   /* Every inot is state 8. It only matters as a b2f source. */
   [nir_op_inot] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         8,
      },
   },
};

/* Mapping from state index to offset in transforms (0 being no transforms).
 *
 * One entry per automaton state 0..13. Non-zero entries point at the first
 * rule of a sentinel-terminated group in
 * bifrost_nir_lower_algebraic_late_transforms:
 *    state 2  -> [1]   fmul(a, 2.0)
 *    state 4  -> [3]   fmax(a, 0.0)
 *    state 7  -> [5]   b2f16/b2f32 of a@8, a@16, a@32
 *    state 9  -> [12]  fmax(fmin(a, 1.0), -1.0) and fmax(a, 0.0)
 *    state 10 -> [15]  fmin(fmax(a, -1.0), 1.0)
 *    state 11 -> [17]  fabs(fddx(a))
 *    state 12 -> [19]  fabs(fddy(b))
 *    state 13 -> [21]  b2f(inot(a@N)) plus the plain b2f(a@N) rules
 * States 0, 1, 3, 5, 6 and 8 have no rules. States 1, 3, 5, 6 and 8 matter
 * only as inputs to other opcodes' filters in the pass op table.
 */
static const uint16_t bifrost_nir_lower_algebraic_late_transform_offsets[] = {
   0,
   0,
   1,
   0,
   3,
   0,
   0,
   5,
   0,
   12,
   15,
   17,
   19,
   21,
};

/*
 * Everything nir_algebraic_impl() needs for this pass, bundled into a single
 * descriptor: the value pool, the rule list with its per-state offsets, and
 * the per-opcode automaton transitions. The pass uses no custom expression
 * or variable condition callbacks, so both are left NULL.
 */
static const nir_algebraic_table bifrost_nir_lower_algebraic_late_table = {
   .values = bifrost_nir_lower_algebraic_late_values,
   .pass_op_table = bifrost_nir_lower_algebraic_late_pass_op_table,
   .transforms = bifrost_nir_lower_algebraic_late_transforms,
   .transform_offsets = bifrost_nir_lower_algebraic_late_transform_offsets,
   .variable_cond = NULL,
   .expression_cond = NULL,
};

/*
 * Apply the late Bifrost algebraic rules to every function of @shader that
 * has an implementation.
 *
 * Returns true when nir_algebraic_impl() reports progress on at least one
 * function implementation.
 */
bool
bifrost_nir_lower_algebraic_late(nir_shader *shader)
{
   /* All rules in this pass reference condition slot 0, which is always
    * enabled. */
   bool condition_flags[1] = { true };
   bool made_progress = false;
   const shader_info *info = &shader->info;
   const nir_shader_compiler_options *options = shader->options;

   /* No rule condition in this pass reads these; silence unused warnings. */
   (void) info;
   (void) options;

   /* Guard against the value pool and the index tables drifting apart. */
   STATIC_ASSERT(57 == ARRAY_SIZE(bifrost_nir_lower_algebraic_late_values));

   nir_foreach_function(func, shader) {
      /* Declarations without a body have nothing to rewrite. */
      if (func->impl != NULL) {
         made_progress |= nir_algebraic_impl(func->impl, condition_flags,
                                             &bifrost_nir_lower_algebraic_late_table);
      }
   }

   return made_progress;
}

