Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dpnp/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_acceptance_fn_divide,
_acceptance_fn_negative,
_acceptance_fn_reciprocal,
_acceptance_fn_round,
_acceptance_fn_subtract,
_resolve_weak_types_all_py_ints,
)
Expand Down Expand Up @@ -1723,7 +1724,11 @@
"""

round = UnaryElementwiseFunc(
"round", ti._round_result_type, ti._round, _round_docstring
"round",
ti._round_result_type,
ti._round,
_round_docstring,
acceptance_fn=_acceptance_fn_round,
)
del _round_docstring

Expand Down
8 changes: 8 additions & 0 deletions dpnp/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
return True


def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev):
# for boolean input, prefer floating-point output over integral
if arg_dtype.kind == "b" and res_dt.kind != "f":
return False
return True


def _acceptance_fn_subtract(
arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev
):
Expand Down Expand Up @@ -970,6 +977,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"_find_buf_dtype2",
"_to_device_supported_dtype",
"_acceptance_fn_default_unary",
"_acceptance_fn_round",
"_acceptance_fn_reciprocal",
"_acceptance_fn_default_binary",
"_acceptance_fn_divide",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ template <typename T>
struct RoundOutputType
{
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
td_ns::TypeMapResultEntry<T, std::uint8_t>,
Comment thread
ndgrigorian marked this conversation as resolved.
td_ns::TypeMapResultEntry<T, std::uint16_t>,
td_ns::TypeMapResultEntry<T, std::uint32_t>,
Expand Down
Loading