Implement hasMetadata() and getModelMetadata() in the MetadataExtractor

PiperOrigin-RevId: 303829852
Change-Id: I062874670d8838a00162901a7c45012dad69a63a
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
index 3bd60ed..054ea0e 100644
--- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
@@ -25,6 +25,7 @@
 import org.tensorflow.lite.DataType;
 import org.tensorflow.lite.Tensor.QuantizationParams;
 import org.tensorflow.lite.schema.Tensor;
+import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
 import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
 
 /**
@@ -96,6 +97,11 @@
     zipFile = createZipFile(buffer);
   }
 
+  /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
+  public Boolean hasMetadata() {
+    return metadataInfo != null;
+  }
+
   /**
    * Gets the packed associated file with the specified {@code fileName}.
    *
@@ -154,6 +160,16 @@
     return modelInfo.getInputTensorType(inputIndex);
   }
 
+  /**
+   * Gets the root handler for the model metadata.
+   *
+   * @throws IllegalStateException if this model does not contain model metadata
+   */
+  public ModelMetadata getModelMetadata() {
+    assertMetadataInfo();
+    return metadataInfo.getModelMetadata();
+  }
+
   /** Gets the count of output tensors in the model. */
   public int getOutputTensorCount() {
     return modelInfo.getOutputTensorCount();
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
index 6a64193..ad13a30 100644
--- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
@@ -29,6 +29,9 @@
 
 /** Extracts model metadata information out of TFLite metadata FlatBuffer. */
 final class ModelMetadataInfo {
+  /** The root handler for the model metadata. */
+  private final ModelMetadata modelMetadata;
+
   /** Metadata array of input tensors. */
   private final List</* @Nullable */ TensorMetadata> inputsMetadata;
 
@@ -45,7 +48,7 @@
   ModelMetadataInfo(ByteBuffer buffer) {
     checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
 
-    ModelMetadata modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
+    modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
     checkArgument(
         modelMetadata.subgraphMetadataLength() > 0,
         "The metadata flatbuffer does not contain any subgraph metadata.");
@@ -73,6 +76,11 @@
     return inputsMetadata.get(inputIndex);
   }
 
+  /** Gets the root handler for the model metadata. */
+  ModelMetadata getModelMetadata() {
+    return modelMetadata;
+  }
+
   /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
   int getOutputTensorCount() {
     return outputsMetadata.size();