//==----------- dot_product.hpp ------- SYCL dot-product --------------------==//
//
// Copyright (C) 2022 Intel Corporation
//
// This software and the related documents are Intel copyrighted materials, and
// your use of them is governed by the express license under which they were
// provided to you ("License"). Unless the License provides otherwise, you may not
// use, modify, copy, publish, distribute, disclose or transmit this software or
// the related documents without Intel's prior written permission.
//
// This software and the related documents are provided as is, with no express
// or implied warranties, other than those that are expressly stated in the
// License.
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// DP4A extension

#pragma once

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace intel {

/// Overlays four signed 8-bit lanes (`s`) on a single `int` (`i`).
///
/// The lanes are spelled `signed char` rather than plain `char`: the
/// signedness of plain `char` is implementation-defined, so with `char` the
/// lanes would silently become unsigned on targets where `char` is unsigned.
/// `signed char` has the same size and alignment as `char`, so the layout of
/// the union is unchanged.
union Us {
  signed char s[4]; ///< The four lanes, each in the range [-128, 127].
  int i;            ///< Whole-word view of the same storage.
};
/// Overlays four unsigned 8-bit lanes (`s`) on a single `int` (`i`), so that
/// the same storage can be viewed either lane-by-lane or as one word.
union Uu {
  // The four lanes, each in the range [0, 255].
  unsigned char s[4];
  // Whole-word view of the same storage.
  // NOTE(review): declared as signed `int` although the lanes are unsigned;
  // confirm whether `unsigned int` was intended.
  int i;
};

/// Dot product of four packed signed 8-bit lanes, plus an accumulator.
///
/// Each of \p pa and \p pb is treated as four 8-bit two's-complement lanes
/// packed into one 32-bit word. The lanes are multiplied pairwise (lane k of
/// \p pa with lane k of \p pb), the four products are summed, and \p c is
/// added to that sum.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// \param pa Four packed signed 8-bit lanes.
/// \param pb Four packed signed 8-bit lanes.
/// \param c  Accumulator added to the dot product.
/// \return   sum over k of pa[k] * pb[k], plus \p c.
inline int dot_acc(int pa, int pb, int c) {
  // Work on the raw bit patterns. Lanes are extracted with shifts instead of
  // reinterpreting the arguments through a union pointer, which avoids both
  // the aliasing violation and any dependence on the signedness of `char`.
  const unsigned int bits_a = static_cast<unsigned int>(pa);
  const unsigned int bits_b = static_cast<unsigned int>(pb);

  int sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // (byte ^ 0x80) - 0x80 sign-extends an 8-bit two's-complement value.
    const int ea =
        static_cast<int>(((bits_a >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    const int eb =
        static_cast<int>(((bits_b >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    sum += ea * eb;
  }
  // The accumulator is added last, matching the original evaluation order.
  return sum + c;
}

/// Dot product of four packed unsigned 8-bit lanes, plus an accumulator.
///
/// Each of \p pa and \p pb is treated as four unsigned 8-bit lanes packed into
/// one 32-bit word. The lanes are multiplied pairwise (lane k of \p pa with
/// lane k of \p pb), the four products are summed, and \p c is added to that
/// sum.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// \param pa Four packed unsigned 8-bit lanes.
/// \param pb Four packed unsigned 8-bit lanes.
/// \param c  Accumulator added to the dot product.
/// \return   sum over k of pa[k] * pb[k], plus \p c.
inline int dot_acc(unsigned int pa, unsigned int pb, int c) {
  int sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // Lanes are extracted with shifts rather than by reinterpreting the
    // arguments through a union pointer, which avoids the aliasing violation.
    const int ea = static_cast<int>((pa >> shift) & 0xFFu);
    const int eb = static_cast<int>((pb >> shift) & 0xFFu);
    sum += ea * eb;
  }
  // The accumulator is added last, matching the original evaluation order.
  return sum + c;
}

/// Dot product of four signed lanes with four unsigned lanes, plus an
/// accumulator.
///
/// \p pa is treated as four 8-bit two's-complement lanes and \p pb as four
/// unsigned 8-bit lanes, each packed into one 32-bit word. The lanes are
/// multiplied pairwise (lane k of \p pa with lane k of \p pb), the four
/// products are summed, and \p c is added to that sum.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// \param pa Four packed signed 8-bit lanes.
/// \param pb Four packed unsigned 8-bit lanes.
/// \param c  Accumulator added to the dot product.
/// \return   sum over k of pa[k] * pb[k], plus \p c.
inline int dot_acc(int pa, unsigned int pb, int c) {
  // Work on the raw bit pattern of the signed operand. Lanes are extracted
  // with shifts instead of reinterpreting the arguments through a union
  // pointer, which avoids both the aliasing violation and any dependence on
  // the signedness of `char`.
  const unsigned int bits_a = static_cast<unsigned int>(pa);

  int sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // (byte ^ 0x80) - 0x80 sign-extends an 8-bit two's-complement value.
    const int ea =
        static_cast<int>(((bits_a >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    const int eb = static_cast<int>((pb >> shift) & 0xFFu);
    sum += ea * eb;
  }
  // The accumulator is added last, matching the original evaluation order.
  return sum + c;
}

/// Dot product of four unsigned lanes with four signed lanes, plus an
/// accumulator.
///
/// \p pa is treated as four unsigned 8-bit lanes and \p pb as four 8-bit
/// two's-complement lanes, each packed into one 32-bit word. The lanes are
/// multiplied pairwise (lane k of \p pa with lane k of \p pb), the four
/// products are summed, and \p c is added to that sum.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// \param pa Four packed unsigned 8-bit lanes.
/// \param pb Four packed signed 8-bit lanes.
/// \param c  Accumulator added to the dot product.
/// \return   sum over k of pa[k] * pb[k], plus \p c.
inline int dot_acc(unsigned int pa, int pb, int c) {
  // Work on the raw bit pattern of the signed operand. Lanes are extracted
  // with shifts instead of reinterpreting the arguments through a union
  // pointer, which avoids both the aliasing violation and any dependence on
  // the signedness of `char`.
  const unsigned int bits_b = static_cast<unsigned int>(pb);

  int sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    const int ea = static_cast<int>((pa >> shift) & 0xFFu);
    // (byte ^ 0x80) - 0x80 sign-extends an 8-bit two's-complement value.
    const int eb =
        static_cast<int>(((bits_b >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    sum += ea * eb;
  }
  // The accumulator is added last, matching the original evaluation order.
  return sum + c;
}

/// Dot product of two 4-element signed 8-bit vectors, plus an accumulator.
///
/// Multiplies the elements of \p a and \p b pairwise (s0 with s0, ..., s3 with
/// s3), sums the four products and adds \p c.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// NOTE(review): the width in which the products and sums are evaluated
/// depends on what `vec::s0()` etc. return and on their `operator*` /
/// `operator+` — confirm they widen to `int` rather than staying 8-bit.
///
/// \param a First operand, four signed 8-bit elements.
/// \param b Second operand, four signed 8-bit elements.
/// \param c Accumulator added to the dot product.
/// \return  a.s0*b.s0 + a.s1*b.s1 + a.s2*b.s2 + a.s3*b.s3 + c.
inline int dot_acc(vec<int8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
         c;
}

/// Dot product of two 4-element unsigned 8-bit vectors, plus an accumulator.
///
/// Multiplies the elements of \p a and \p b pairwise (s0 with s0, ..., s3 with
/// s3), sums the four products and adds \p c.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// NOTE(review): the width in which the products and sums are evaluated
/// depends on what `vec::s0()` etc. return and on their `operator*` /
/// `operator+` — confirm they widen to `int` rather than staying 8-bit.
///
/// \param a First operand, four unsigned 8-bit elements.
/// \param b Second operand, four unsigned 8-bit elements.
/// \param c Accumulator added to the dot product.
/// \return  a.s0*b.s0 + a.s1*b.s1 + a.s2*b.s2 + a.s3*b.s3 + c.
inline int dot_acc(vec<uint8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
         c;
}

/// Dot product of a 4-element unsigned 8-bit vector with a 4-element signed
/// 8-bit vector, plus an accumulator.
///
/// Multiplies the elements of \p a and \p b pairwise (s0 with s0, ..., s3 with
/// s3), sums the four products and adds \p c.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// NOTE(review): the width and signedness in which the mixed products are
/// evaluated depend on what `vec::s0()` etc. return and on their `operator*` /
/// `operator+` — confirm they widen to `int` rather than staying 8-bit.
///
/// \param a First operand, four unsigned 8-bit elements.
/// \param b Second operand, four signed 8-bit elements.
/// \param c Accumulator added to the dot product.
/// \return  a.s0*b.s0 + a.s1*b.s1 + a.s2*b.s2 + a.s3*b.s3 + c.
inline int dot_acc(vec<uint8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
         c;
}

/// Dot product of a 4-element signed 8-bit vector with a 4-element unsigned
/// 8-bit vector, plus an accumulator.
///
/// Multiplies the elements of \p a and \p b pairwise (s0 with s0, ..., s3 with
/// s3), sums the four products and adds \p c.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// NOTE(review): the width and signedness in which the mixed products are
/// evaluated depend on what `vec::s0()` etc. return and on their `operator*` /
/// `operator+` — confirm they widen to `int` rather than staying 8-bit.
///
/// \param a First operand, four signed 8-bit elements.
/// \param b Second operand, four unsigned 8-bit elements.
/// \param c Accumulator added to the dot product.
/// \return  a.s0*b.s0 + a.s1*b.s1 + a.s2*b.s2 + a.s3*b.s3 + c.
inline int dot_acc(vec<int8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
         c;
}

} // namespace intel
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
