Implement hasMetadata() and getModelMetadata() in the MetadataExtractor
PiperOrigin-RevId: 303829852
Change-Id: I062874670d8838a00162901a7c45012dad69a63a
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
index 3bd60ed..054ea0e 100644
--- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
@@ -25,6 +25,7 @@
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Tensor.QuantizationParams;
import org.tensorflow.lite.schema.Tensor;
+import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
/**
@@ -96,6 +97,11 @@
zipFile = createZipFile(buffer);
}
+ /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
+ public Boolean hasMetadata() {
+ return metadataInfo != null;
+ }
+
/**
* Gets the packed associated file with the specified {@code fileName}.
*
@@ -154,6 +160,16 @@
return modelInfo.getInputTensorType(inputIndex);
}
+ /**
+ * Gets the root handler for the model metadata.
+ *
+ * @throws IllegalStateException if this model does not contain model metadata
+ */
+ public ModelMetadata getModelMetadata() {
+ assertMetadataInfo();
+ return metadataInfo.getModelMetadata();
+ }
+
/** Gets the count of output tensors in the model. */
public int getOutputTensorCount() {
return modelInfo.getOutputTensorCount();
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
index 6a64193..ad13a30 100644
--- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
@@ -29,6 +29,9 @@
/** Extracts model metadata information out of TFLite metadata FlatBuffer. */
final class ModelMetadataInfo {
+ /** The root handler for the model metadata. */
+ private final ModelMetadata modelMetadata;
+
/** Metadata array of input tensors. */
private final List</* @Nullable */ TensorMetadata> inputsMetadata;
@@ -45,7 +48,7 @@
ModelMetadataInfo(ByteBuffer buffer) {
checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
- ModelMetadata modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
+ modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
checkArgument(
modelMetadata.subgraphMetadataLength() > 0,
"The metadata flatbuffer does not contain any subgraph metadata.");
@@ -73,6 +76,11 @@
return inputsMetadata.get(inputIndex);
}
+ /** Gets the root handler for the model metadata. */
+ ModelMetadata getModelMetadata() {
+ return modelMetadata;
+ }
+
/** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
int getOutputTensorCount() {
return outputsMetadata.size();