1 #include <ATen/native/Activation.h> 4 #include <ATen/cpu/vec256/vec256.h> 5 #include <ATen/native/TensorIterator.h> 6 #include <ATen/native/cpu/Loops.h> 8 namespace at {
namespace native {
// Threshold kernel: for each element, yields `value` when the input x is
// <= `threshold`, and otherwise passes through the lambda's second operand
// `other` unchanged. Instantiated for every type covered by
// AT_DISPATCH_ALL_TYPES (dispatch name "threshold_cpu").
// NOTE(review): this chunk is missing the call that consumes the two lambdas
// below (presumably cpu_kernel_vec from ATen/native/cpu/Loops.h, which is
// included at the top) as well as the closing braces -- confirm against the
// full file before editing.
11 static void threshold_kernel(TensorIterator& iter, Scalar threshold_scalar, Scalar value_scalar) {
12 AT_DISPATCH_ALL_TYPES(iter.dtype(),
"threshold_cpu", [&] {
13 using Vec = Vec256<scalar_t>;
// Convert the Scalar arguments to the dispatched element type once, up front,
// so the per-element lambdas compare and assign purely in scalar_t.
14 scalar_t threshold = threshold_scalar.to<scalar_t>();
15 scalar_t value = value_scalar.to<scalar_t>();
// Scalar path. For floating types, a NaN x compares false against
// `threshold`, so NaN inputs take the `other` branch instead of being
// replaced by `value`.
18 [&](scalar_t x, scalar_t other) -> scalar_t {
19 return x <= threshold ? value : other;
// Vectorized path; must produce the same result as the scalar lambda above.
// The mask `x <= Vec(threshold)` presumably selects Vec(value) in lanes where
// it is set and `other` elsewhere (Vec256::blendv semantics -- confirm).
21 [&](Vec x, Vec other) -> Vec {
22 return Vec::blendv(other, Vec(value), x <= Vec(threshold));
// Register threshold_kernel as the implementation behind threshold_stub.
// NOTE(review): threshold_stub and the REGISTER_DISPATCH macro are declared
// outside this chunk (presumably ATen/native/Activation.h and the dispatch
// stub header). Keep the explicit `&`, since the macro may rely on decltype
// of this argument.
29 REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
// ... Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) mode.