
Merge pull request #11 from dijopaul/main
Namespace update as per review comments
cad-audio authored Oct 8, 2024
2 parents 8064895 + a3581f1 · commit fd955cf
Showing 7 changed files with 88 additions and 347 deletions.
14 changes: 7 additions & 7 deletions backends/cadence/aot/functions_hifi.yaml
@@ -25,7 +25,7 @@
 - op: add.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::add_out
+      kernel_name: impl::HiFi::add_out
 
 - op: bmm.out
   kernels:
@@ -45,12 +45,12 @@
 - op: div.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out
+      kernel_name: impl::HiFi::div_out
 
 - op: div.out_mode
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out_mode
+      kernel_name: impl::HiFi::div_out_mode
 
 - op: embedding.out
   kernels:
@@ -65,7 +65,7 @@
 - op: mul.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::mul_out
+      kernel_name: impl::HiFi::mul_out
 
 - op: permute_copy.out
   kernels:
@@ -75,7 +75,7 @@
 - op: sigmoid.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::sigmoid_out
+      kernel_name: impl::HiFi::sigmoid_out
 
 - op: slice_copy.Tensor_out
   kernels:
@@ -90,12 +90,12 @@
 - op: sub.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::sub_out
+      kernel_name: impl::HiFi::sub_out
 
 - op: tanh.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::tanh_out
+      kernel_name: impl::HiFi::tanh_out
 
 - op: view_copy.out
   kernels:
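
For context, each kernel_name entry is the fully qualified C++ symbol that ExecuTorch's kernel registration binds the op to, so the YAML change only holds up if a matching function exists under the new namespace. The following is a minimal sketch of the declaration such an entry points at, assuming the out-variant signature used in op_add.cpp below; the header path is an assumption, and note that the .cpp files in this commit define the kernels one level deeper, under impl::HiFi::native (the mapping from the shorter YAML name onto that symbol is not shown in this diff).

// Sketch only, not the generated registration code. The header path and
// exec_aten aliases are assumptions based on the includes visible below.
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace impl {
namespace HiFi {

// add.out: writes a + alpha * b into the preallocated `out` tensor.
::exec_aten::Tensor& add_out(
    ::executorch::runtime::KernelRuntimeContext& ctx,
    const ::exec_aten::Tensor& a,
    const ::exec_aten::Tensor& b,
    const ::exec_aten::Scalar& alpha,
    ::exec_aten::Tensor& out);

} // namespace HiFi
} // namespace impl
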
102 changes: 17 additions & 85 deletions backends/cadence/hifi/operators/op_add.cpp
@@ -14,11 +14,19 @@
#include <executorch/runtime/platform/assert.h>
#include <executorch/backends/cadence/hifi/kernels/kernels.h>

-namespace torch {
-namespace executor {
+using exec_aten::Scalar;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::runtime::can_cast;
+using executorch::runtime::CppTypeToScalarType;
+using executorch::runtime::KernelRuntimeContext;
+using torch::executor::Error;
+
+namespace impl {
+namespace HiFi {
namespace native {
-namespace {

+namespace {
template <
bool can_cast,
typename CTYPE_A,
@@ -35,7 +43,7 @@ template <
struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void
run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
@@ -89,7 +97,7 @@ Tensor& add_out(

ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
-ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+ScalarType alpha_type = torch::executor::native::utils::get_scalar_dtype(alpha);
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
ScalarType out_type = out.scalar_type();

@@ -98,7 +106,7 @@
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);

float alpha_val;
-utils::extract_scalar(alpha, &alpha_val);
+torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

constexpr auto name = "add.out";
constexpr int kNnlibMaxDim = 4; /*fallback if broadcast and dim > 4 */
@@ -168,7 +176,7 @@ Tensor& add_out(
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
CTYPE_IN alpha_val;
-utils::extract_scalar(alpha, &alpha_val);
+torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
AddInner<
Expand All @@ -184,83 +192,7 @@ Tensor& add_out(
return out;
}

-Tensor& add_scalar_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Scalar& b,
-    const Scalar& alpha,
-    Tensor& out) {
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx,
-      executorch::runtime::tensor_is_realhbbf16_type(out),
-      InvalidArgument,
-      out);
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-  ET_KERNEL_CHECK(
-      ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
-
-  /*When Half first compute the result in float precision
-    and then downcast to half*/
-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
-  }
-
-  constexpr auto name = "add.Scalar_out";
-
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_IN = typename utils::promote_type_with_scalar_type<
-          CTYPE_A,
-          CTYPE_B,
-          /*half_to_float*/ true>::type;
-      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-
-      CTYPE_B b_val;
-      utils::extract_scalar(b, &b_val);
-      CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-
-      CTYPE_IN alpha_val;
-      utils::extract_scalar(alpha, &alpha_val);
-
-      using CTYPE_OUT = typename std::conditional<
-          std::is_same<CTYPE_A, internal::F2>::value,
-          internal::F2,
-          CTYPE_IN>::type;
-
-      apply_unary_map_fn(
-          [b_casted, alpha_val](const CTYPE_A val_a) {
-            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-            CTYPE_IN value = a_casted + alpha_val * b_casted;
-            return static_cast<CTYPE_OUT>(value);
-          },
-          a.const_data_ptr<CTYPE_A>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          out.numel());
-    });
-  });
-
-  return out;
-}

} // namespace native
-} // namespace executor
-} // namespace torch
+} // namespace HiFi
+} // namespace impl
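
Taken together, the op_add.cpp changes follow one pattern: drop the torch::executor wrapper namespaces, open impl::HiFi::native instead, import the required types with using-declarations, and fully qualify any helper that still lives in torch::executor. A condensed sketch of the resulting file shape, with a hypothetical some_op_out standing in for the real kernels and the include set abbreviated:

// Condensed sketch of the post-migration layout; some_op_out is
// illustrative and the includes are abbreviated.
#include <executorch/backends/cadence/hifi/kernels/kernels.h>

using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

namespace impl {
namespace HiFi {
namespace native {

Tensor& some_op_out(KernelRuntimeContext& ctx, const Tensor& a, Tensor& out) {
  // Helpers left behind in torch::executor must now be spelled out in
  // full, because this code no longer sits inside that namespace.
  torch::executor::apply_unary_map_fn(
      [](float v) { return v; },
      a.const_data_ptr<float>(),
      out.mutable_data_ptr<float>(),
      out.numel());
  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
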
114 changes: 12 additions & 102 deletions backends/cadence/hifi/operators/op_div.cpp
@@ -15,8 +15,14 @@
#include <cmath>
#include <executorch/backends/cadence/hifi/kernels/kernels.h>

-namespace torch {
-namespace executor {
+using exec_aten::Scalar;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::aten::RuntimeContext;
+using torch::executor::Error;
+
+namespace impl {
+namespace HiFi {
namespace native {

namespace {
@@ -127,7 +133,7 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out", CTYPE_B, [&]() {
ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -242,7 +248,7 @@ Tensor& div_out_mode(
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out_mode", CTYPE_B, [&]() {
ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out_mode", CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES(out_type, ctx, "div.out_mode", CTYPE_OUT, [&]() {
-apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[mode](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -265,103 +271,7 @@ Tensor& div_out_mode(
return out;
}

-Tensor& div_scalar_out(
-    RuntimeContext& ctx,
-    const Tensor& a,
-    const Scalar& b,
-    Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.Scalar_out", CTYPE, [&]() {
-        CTYPE_B b_val;
-        utils::extract_scalar(b, &b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-        apply_unary_map_fn(
-            [b_casted](const CTYPE_A val_a) {
-              CTYPE a_casted = static_cast<CTYPE>(val_a);
-              CTYPE value = a_casted / b_casted;
-              return static_cast<CTYPE>(value);
-            },
-            a.const_data_ptr<CTYPE_A>(),
-            out.mutable_data_ptr<CTYPE>(),
-            out.numel());
-      });
-    });
-  });
-
-  return out;
-}
-
-Tensor& div_scalar_mode_out(
-    RuntimeContext& ctx,
-    const Tensor& a,
-    const Scalar& b,
-    exec_aten::optional<exec_aten::string_view> mode,
-    Tensor& out) {
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-
-  constexpr auto name = "div.Scalar_mode_out";
-
-  ET_SWITCH_REALB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES(out_type, ctx, name, CTYPE, [&]() {
-        CTYPE_B b_val;
-        utils::extract_scalar(b, &b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-        apply_unary_map_fn(
-            [b_casted, mode](const CTYPE_A val_a) {
-              CTYPE a_casted = static_cast<CTYPE>(val_a);
-              CTYPE value = a_casted / b_casted;
-              if (mode.has_value() && mode.value() == "trunc") {
-                value = std::trunc(value);
-              } else if (mode.has_value() && mode.value() == "floor") {
-                value = utils::floor_divide(a_casted, b_casted);
-              }
-              return value;
-            },
-            a.const_data_ptr<CTYPE_A>(),
-            out.mutable_data_ptr<CTYPE>(),
-            out.numel());
-      });
-    });
-  });
-
-  return out;
-}

} // namespace native
-} // namespace executor
-} // namespace torch
+} // namespace HiFi
+} // namespace impl
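
The rename also ripples out to anything that referred to these kernels by their old qualified names, such as tests or direct call sites. A hypothetical caller updated for the new namespace (tensor setup elided; RuntimeContext is the alias this file already imports):

// Hypothetical call site, for illustration only; a, b, and out are
// assumed to be valid tensors prepared by the surrounding runtime code.
void call_div(executorch::aten::RuntimeContext& ctx,
              const exec_aten::Tensor& a,
              const exec_aten::Tensor& b,
              exec_aten::Tensor& out) {
  // Before this commit: torch::executor::native::div_out(ctx, a, b, out);
  impl::HiFi::native::div_out(ctx, a, b, out);
}
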
