From 7034378af021b6cedb7a32511d277b2b1d9ab3ec Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 14 Apr 2026 11:27:09 -0700 Subject: [PATCH 1/3] Add device-aware output dtype for dpt.round() with bool dtype --- dpnp/tensor/_elementwise_funcs.py | 7 ++++++- dpnp/tensor/_type_utils.py | 18 ++++++++++++++---- .../kernels/elementwise_functions/round.hpp | 2 +- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/dpnp/tensor/_elementwise_funcs.py b/dpnp/tensor/_elementwise_funcs.py index 5d38cad0c2a5..4040f33bf38e 100644 --- a/dpnp/tensor/_elementwise_funcs.py +++ b/dpnp/tensor/_elementwise_funcs.py @@ -33,6 +33,7 @@ _acceptance_fn_divide, _acceptance_fn_negative, _acceptance_fn_reciprocal, + _acceptance_fn_round, _acceptance_fn_subtract, _resolve_weak_types_all_py_ints, ) @@ -1723,7 +1724,11 @@ """ round = UnaryElementwiseFunc( - "round", ti._round_result_type, ti._round, _round_docstring + "round", + ti._round_result_type, + ti._round, + _round_docstring, + acceptance_fn=_acceptance_fn_round, ) del _round_docstring diff --git a/dpnp/tensor/_type_utils.py b/dpnp/tensor/_type_utils.py index 3da9e7994760..a29391972d53 100644 --- a/dpnp/tensor/_type_utils.py +++ b/dpnp/tensor/_type_utils.py @@ -133,6 +133,13 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev): return True +def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev): + # for boolean input, prefer floating-point output over integral + if arg_dtype.char == "?" and res_dt.kind in "biu": + return False + return True + + def _acceptance_fn_subtract( arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev ): @@ -188,17 +195,19 @@ def _dtype_supported_by_device_impl( def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn): + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + res_dt = query_fn(arg_dtype) if res_dt: - return None, res_dt + if _dtype_supported_by_device_impl(res_dt, _fp16, _fp64): + return None, res_dt - _fp16 = sycl_dev.has_aspect_fp16 - _fp64 = sycl_dev.has_aspect_fp64 all_dts = _all_data_types(_fp16, _fp64) for buf_dt in all_dts: if _can_cast(arg_dtype, buf_dt, _fp16, _fp64): res_dt = query_fn(buf_dt) - if res_dt: + if res_dt and _dtype_supported_by_device_impl(res_dt, _fp16, _fp64): acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev) if acceptable: return buf_dt, res_dt @@ -970,6 +979,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q): "_find_buf_dtype2", "_to_device_supported_dtype", "_acceptance_fn_default_unary", + "_acceptance_fn_round", "_acceptance_fn_reciprocal", "_acceptance_fn_default_binary", "_acceptance_fn_divide", diff --git a/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 18867a09bcef..cc6cb7baf3ad 100644 --- a/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -116,7 +116,7 @@ template struct RoundOutputType { using value_type = typename std::disjunction< - td_ns::TypeMapResultEntry, + // td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, From 772e981439a00f63fe85919d8201817752066007 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 14 Apr 2026 11:53:25 -0700 Subject: [PATCH 2/3] Remove bool type mapping from round kernel --- .../libtensor/include/kernels/elementwise_functions/round.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index cc6cb7baf3ad..b20166a4d505 100644 --- a/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -116,7 +116,6 @@ template struct RoundOutputType { using value_type = typename std::disjunction< - // td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, From 4a2309060695e35525d94eca5cff02844e1bdb06 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 15 Apr 2026 03:02:15 -0700 Subject: [PATCH 3/3] Update _acceptance_fn_round and revert _find_buf_dtype changes --- dpnp/tensor/_type_utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/dpnp/tensor/_type_utils.py b/dpnp/tensor/_type_utils.py index a29391972d53..b03ca1e1c79d 100644 --- a/dpnp/tensor/_type_utils.py +++ b/dpnp/tensor/_type_utils.py @@ -135,7 +135,7 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev): def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev): # for boolean input, prefer floating-point output over integral - if arg_dtype.char == "?" and res_dt.kind in "biu": + if arg_dtype.kind == "b" and res_dt.kind != "f": return False return True @@ -195,19 +195,17 @@ def _dtype_supported_by_device_impl( def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn): - _fp16 = sycl_dev.has_aspect_fp16 - _fp64 = sycl_dev.has_aspect_fp64 - res_dt = query_fn(arg_dtype) if res_dt: - if _dtype_supported_by_device_impl(res_dt, _fp16, _fp64): - return None, res_dt + return None, res_dt + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 all_dts = _all_data_types(_fp16, _fp64) for buf_dt in all_dts: if _can_cast(arg_dtype, buf_dt, _fp16, _fp64): res_dt = query_fn(buf_dt) - if res_dt and _dtype_supported_by_device_impl(res_dt, _fp16, _fp64): + if res_dt: acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev) if acceptable: return buf_dt, res_dt