Added AddBroadcastReshapeLayer method to TfLiteParser
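
When the two inputs of an ADD or MUL operator have a different number of
dimensions, the parser now inserts a Reshape layer in front of the lower-rank
input, padding its shape with leading 1s up to the rank of the other input so
that the backend's broadcast rules can be applied. For example, using the
shapes from the new unit tests, multiplying a [ 1, 2, 2, 3 ] tensor by a
[ 1 ] tensor reshapes the second input to [ 1, 1, 1, 1 ] before the
multiplication.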

Change-Id: I6027f6dcdb3ed23505f0a9c780bd3e3d45d3daff
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 0e8d3c5..c45e794 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -463,6 +463,63 @@
     m_SubgraphConnections.clear();
 }
 
+void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
+                                            size_t operatorIndex,
+                                            IConnectableLayer* layer)
+{
+    CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+    BOOST_ASSERT(layer != nullptr);
+
+    const auto & subGraphPtr = m_Model->subgraphs[subgraphIndex];
+    const auto & operatorPtr = subGraphPtr->operators[operatorIndex];
+
+    BOOST_ASSERT(operatorPtr->inputs.size() > 1);
+
+    uint32_t reshapedInputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[0]);
+    TensorRawPtr tensorPtr = subGraphPtr->tensors[reshapedInputId].get();
+    uint32_t inputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[1]);
+    TensorRawPtr tensorPtr1 = subGraphPtr->tensors[inputId].get();
+
+    armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr);
+    armnn::TensorInfo inputTensorInfo    = ToTensorInfo(tensorPtr1);
+
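+    // Make sure the lower-rank tensor is the one that gets reshaped: if the second input has fewer
+    // dimensions than the first, swap the two inputs.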
+    if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions())
+    {
+        std::swap(reshapedInputId, inputId);
+
+        reshapedTensorInfo = ToTensorInfo(tensorPtr1);
+        inputTensorInfo = ToTensorInfo(tensorPtr);
+    }
+
+    uint32_t numDimensions = inputTensorInfo.GetNumDimensions();
+
+    std::vector<unsigned int> reshapedDim;
+    for (unsigned int i = 0; i < reshapedTensorInfo.GetNumDimensions(); ++i)
+    {
+        reshapedDim.push_back(reshapedTensorInfo.GetShape()[i]);
+    }
+
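+    // Pad the lower-rank shape with leading 1s up to the rank of the other input; the original
+    // dimensions are right-aligned, matching the usual broadcast semantics.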
+    std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
+    std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
+
+    reshapedTensorInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
+
+    std::string layerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName());
+    armnn::ReshapeDescriptor desc;
+    desc.m_TargetShape = reshapedTensorInfo.GetShape();
+    armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
+
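+    // The reshaped tensor feeds input slot 0 of the target layer and the unmodified tensor feeds
+    // slot 1. This can swap the original operand order, which is safe for the commutative Add and
+    // Mul operators this helper is used for.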
+    reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+    reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
+    RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId});
+
+    armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1));
+    RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot);
+}
+
 INetworkPtr TfLiteParser::CreateNetworkFromBinaryFile(const char* graphFile)
 {
     ResetParser();
@@ -1008,6 +1065,9 @@
     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(outputs.size(), 1);
 
+    armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
+    armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]);
+
     auto layerName = boost::str(boost::format("Add:%1%:%2%") % subgraphIndex % operatorIndex);
     IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str());
 
@@ -1015,7 +1075,14 @@
     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
 
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
-    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
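+    // Inputs with different ranks need a broadcast Reshape in front of the lower-rank input.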
+    if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions())
+    {
+        AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer);
+    }
+    else
+    {
+        RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+    }
 
     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
 
@@ -1036,6 +1103,9 @@
     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(outputs.size(), 1);
 
+    armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
+    armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]);
+
     auto layerName = boost::str(boost::format("Mul:%1%:%2%") % subgraphIndex % operatorIndex);
     IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str());
 
@@ -1043,7 +1113,14 @@
     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
 
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
-    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
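+    // Inputs with different ranks need a broadcast Reshape in front of the lower-rank input.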
+    if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions())
+    {
+        AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer);
+    }
+    else
+    {
+        RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+    }
 
     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
 
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 6c26437..34ae07f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -124,6 +124,10 @@
 
     void ResetParser();
 
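+    /// Inserts a Reshape layer in front of the lower-rank input of a two-input operator,
+    /// padding its shape with leading 1s so that both inputs can be broadcast together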
+    void AddBroadcastReshapeLayer(size_t subgraphIndex,
+                                  size_t operatorIndex,
+                                  armnn::IConnectableLayer* layer);
+
     /// Attach an activation layer to the one passed as a parameter
     armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
                                                       unsigned int outputSlot,
diff --git a/src/armnnTfLiteParser/test/Multiplication.cpp b/src/armnnTfLiteParser/test/Multiplication.cpp
index f7e2edd..dabf868 100644
--- a/src/armnnTfLiteParser/test/Multiplication.cpp
+++ b/src/armnnTfLiteParser/test/Multiplication.cpp
@@ -108,4 +108,40 @@
                                                                  45.0f, 50.0f, 55.0f } } });
 }
 
+struct MultiplicationBroadcastFixture4D1D : public MultiplicationFixture
+{
+    MultiplicationBroadcastFixture4D1D() : MultiplicationFixture("[ 1, 2, 2, 3 ]", "[ 1 ]", "[ 1, 2, 2, 3 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast4D1D, MultiplicationBroadcastFixture4D1D)
+{
+    RunTest<4, float>(0, {{ "inputTensor1", { 0.0f,  1.0f,  2.0f,
+                                              3.0f,  4.0f,  5.0f,
+                                              6.0f,  7.0f,  8.0f,
+                                              9.0f, 10.0f, 11.0f } },
+                         { "inputTensor2", { 5.0f } } },
+                         {{ "outputTensor", { 0.0f,  5.0f, 10.0f,
+                                             15.0f, 20.0f, 25.0f,
+                                             30.0f, 35.0f, 40.0f,
+                                             45.0f, 50.0f, 55.0f } } });
+}
+
+struct MultiplicationBroadcastFixture1D4D : public MultiplicationFixture
+{
+    MultiplicationBroadcastFixture1D4D() : MultiplicationFixture("[ 1 ]", "[ 1, 2, 2, 3 ]", "[ 1, 2, 2, 3 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast1D4D, MultiplicationBroadcastFixture1D4D)
+{
+    RunTest<4, float>(0, {{ "inputTensor1", { 3.0f } },
+                          { "inputTensor2", { 0.0f,  1.0f,  2.0f,
+                                              3.0f,  4.0f,  5.0f,
+                                              6.0f,  7.0f,  8.0f,
+                                              9.0f, 10.0f, 11.0f } } },
+                         {{ "outputTensor", { 0.0f,  3.0f,  6.0f,
+                                              9.0f, 12.0f, 15.0f,
+                                             18.0f, 21.0f, 24.0f,
+                                             27.0f, 30.0f, 33.0f } } });
+}
+
 BOOST_AUTO_TEST_SUITE_END()