Include input tensor size in NNAPI compilation cache key
PiperOrigin-RevId: 340384758
Change-Id: I5bb0d7465522b386a08beee8ff12c6a42526f440
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 43c3abc..f5a9b85 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -496,9 +496,9 @@
}
// Compute the hash of a TfLiteIntArray.
-uint64_t GetHash(const TfLiteIntArray* int_array) {
+uint64_t GetHash(const TfLiteIntArray* int_array, uint64_t combine_with = 0) {
constexpr auto kHashConst = 0x9e3779b97f4a7800ULL;
- uint64_t result = 0;
+ uint64_t result = combine_with;
for (auto i : TfLiteIntArrayView(int_array)) {
result = result ^ (i + kHashConst + (result << 10) + (result >> 4));
}
@@ -3559,15 +3559,27 @@
// token.
// TODO(b/133342794): use a generic token generator class.
uint64_t token_parts[4];
- // bits from model_token.
+ // Create bits from model_token.
+ // TODO(b/172237993): should not use std::hash, as that is not
+ // guaranteed to be stable across program invocations.
token_parts[0] = std::hash<std::string>{}(model_token);
- // bits from params->nodes_to_replace.
+ // Create bits from params->nodes_to_replace.
token_parts[1] = GetHash(params->nodes_to_replace);
- // bits from params->input_tensors.
+ // Create bits from params->input_tensors. These include the input tensor
+ // sizes, as the cached compilations are size-dependent.
token_parts[2] = GetHash(params->input_tensors);
+ for (int i : TfLiteIntArrayView(params->input_tensors)) {
+ if (i != kTfLiteOptionalTensor) {
+ TfLiteTensor* t = &context->tensors[i];
+ TF_LITE_ENSURE(context, t->dims);
+ token_parts[2] = GetHash(t->dims, token_parts[2]);
+ }
+ }
// bits from params->output_tensors.
token_parts[3] = GetHash(params->output_tensors);
// NNAPI requires the token to be 256bit long.
+ // TODO(b/172238515): get token size from header instead of
+ // hardcoding.
std::vector<uint8_t> nnapi_cache_token(32, 0);
// Copy the token bits.
uint8_t* p = reinterpret_cast<uint8_t*>(token_parts);
@@ -4478,8 +4490,9 @@
builder.AddTensorInput(input_index, hybrid_op));
break;
case kTfLiteInt64: {
- // We made sure that dimensions are constant and fit into int32 in
- // Map(), so we can safely create a new tensor with casted values.
+ // We made sure that dimensions are constant and fit into int32
+ // in Map(), so we can safely create a new tensor with casted
+ // values.
const int dims_size = dims_tensor.dims->data[0];
std::vector<int32_t> dims_int32(dims_size);
std::copy(dims_tensor.data.i64, dims_tensor.data.i64 + dims_size,