diff --git a/dpnp/tensor/_elementwise_funcs.py b/dpnp/tensor/_elementwise_funcs.py index 5d38cad0c2a..4040f33bf38 100644 --- a/dpnp/tensor/_elementwise_funcs.py +++ b/dpnp/tensor/_elementwise_funcs.py @@ -33,6 +33,7 @@ _acceptance_fn_divide, _acceptance_fn_negative, _acceptance_fn_reciprocal, + _acceptance_fn_round, _acceptance_fn_subtract, _resolve_weak_types_all_py_ints, ) @@ -1723,7 +1724,11 @@ """ round = UnaryElementwiseFunc( - "round", ti._round_result_type, ti._round, _round_docstring + "round", + ti._round_result_type, + ti._round, + _round_docstring, + acceptance_fn=_acceptance_fn_round, ) del _round_docstring diff --git a/dpnp/tensor/_type_utils.py b/dpnp/tensor/_type_utils.py index 3da9e799476..b03ca1e1c79 100644 --- a/dpnp/tensor/_type_utils.py +++ b/dpnp/tensor/_type_utils.py @@ -133,6 +133,13 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev): return True +def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev): + # for boolean input, prefer floating-point output over integral + if arg_dtype.kind == "b" and res_dt.kind != "f": + return False + return True + + def _acceptance_fn_subtract( arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev ): @@ -970,6 +977,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q): "_find_buf_dtype2", "_to_device_supported_dtype", "_acceptance_fn_default_unary", + "_acceptance_fn_round", "_acceptance_fn_reciprocal", "_acceptance_fn_default_binary", "_acceptance_fn_divide", diff --git a/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 18867a09bce..b20166a4d50 100644 --- a/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -116,7 +116,6 @@ template struct RoundOutputType { using value_type = typename std::disjunction< - td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry,