Prevent a division by 0 in division ops.
PiperOrigin-RevId: 385223169
Change-Id: Ia4228960b5d2aa44480385f74bdd70d21a3613c3
diff --git a/tensorflow/lite/kernels/div.cc b/tensorflow/lite/kernels/div.cc
index f744b4b..51623a9 100644
--- a/tensorflow/lite/kernels/div.cc
+++ b/tensorflow/lite/kernels/div.cc
@@ -216,9 +216,23 @@
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
- if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ // TODO(b/193904910): This can written with C++ templates
+#define TF_LITE_CHECK_DIV_NON_ZERO(data_type) \
+ const auto* input2_data = GetTensorData<data_type>(input2); \
+ const size_t input2_elements = input2->bytes / sizeof(data_type); \
+ for (size_t i = 0; i < input2_elements; i++) { \
+ TF_LITE_ENSURE(context, input2_data[i] != 0); \
+ }
+
+ if (output->type == kTfLiteFloat32) {
+ // Div by zero seems ok in this case, just like in TF case infinities are
+ // returned. So we don't do a check at this point.
+ EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
+ } else if (output->type == kTfLiteInt32) {
+ TF_LITE_CHECK_DIV_NON_ZERO(int32_t);
EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8) {
+ TF_LITE_CHECK_DIV_NON_ZERO(uint8_t);
TF_LITE_ENSURE_OK(
context, EvalQuantized<kernel_type>(context, node, params, data, input1,
input2, output));
@@ -229,6 +243,7 @@
output->type);
return kTfLiteError;
}
+#undef TF_LITE_CHECK_DIV_NON_ZERO
return kTfLiteOk;
}