Support non-unit batch with FC operator in GPU delegate
PiperOrigin-RevId: 299779651
Change-Id: I68da12096ea646144bc9cd70dbf4eca0aaa55c91
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index d1521ff..b2bd6f0 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -1281,7 +1281,8 @@
conv = graph->NewNode(); // reset conv pointer!
Value<TensorRef<BHWC>>* reshaped_value = graph->NewValue();
reshaped_value->tensor.type = DataType::FLOAT32;
- reshaped_value->tensor.shape = BHWC(1, 1, 1, weights.shape.w);
+ reshaped_value->tensor.shape =
+ BHWC(input->tensor.shape.b, 1, 1, weights.shape.w);
RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
reshape->operation.type = ToString(OperationType::RESHAPE);
ReshapeAttributes attr;