blob: 018482c7e82db39d72c7c472f124246c087a0bfe [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.task.vision.segmenter;
import com.google.auto.value.AutoValue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.tensorflow.lite.support.image.TensorImage;
/**
 * Represents the segmentation result of an {@link ImageSegmenter}.
 *
 * <p>Instances are created through {@link #create}, which stores unmodifiable copies of the masks
 * and colored labels it is given.
 */
@AutoValue
public abstract class Segmentation {
  /**
   * Creates a {@link Segmentation} object.
   *
   * <p>{@link Segmentation} provides two types of outputs as indicated through {@link OutputType}:
   *
   * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
   * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of each
   * pixel in this mask represents the class to which the pixel in the mask belongs. The pixel
   * values are in 1:1 correspondence with the colored labels, i.e. a pixel with value {@code i} is
   * associated with {@code coloredLabels.get(i)}.
   *
   * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
   * are in 1:1 correspondence with the colored labels, i.e. {@code masks.get(i)} is associated with
   * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
   * shape (height, width), in row major order. The value of each pixel in these masks represents
   * the confidence score for this particular class.
   *
   * <p>IMPORTANT: segmentation masks are not directly suited for display, in particular:
   *
   * <ul>
   *   <li>they are relative to the unrotated input frame, i.e. <em>not</em> taking into account
   *       the {@code Orientation} flag of the input FrameBuffer;
   *   <li>their dimensions are intrinsic to the model, i.e. <em>not</em> dependent on the input
   *       FrameBuffer dimensions.
   * </ul>
   *
   * <p>Example of such post-processing, assuming:
   *
   * <ul>
   *   <li>an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
   *       will be rotated 90° clockwise during preprocessing to make it "upright");
   *   <li>a model outputting masks of size 224x224.
   * </ul>
   *
   * <p>In order to be directly displayable on top of the input image assumed to be displayed
   * <em>with</em> the {@code Orientation} flag taken into account (according to the <a
   * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to be
   * re-scaled to 640 x 480, then rotated 90° clockwise.
   *
   * <p>Both lists are copied before validation, so later changes to the caller's lists do not
   * affect the returned object. The accessors return unmodifiable views of those copies.
   *
   * @param outputType how {@code masks} should be interpreted (category mask or confidence masks)
   * @param masks the segmentation masks; copied into an unmodifiable list
   * @param coloredLabels the labels associated with the mask classes; copied into an unmodifiable
   *     list
   * @return a new {@link Segmentation} holding unmodifiable copies of {@code masks} and {@code
   *     coloredLabels}
   * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
   *     {@code outputType}
   */
  static Segmentation create(
      OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
    // Copy first, then validate the copies. Validating the caller's lists and copying afterwards
    // would leave a window in which the stored contents differ from what was checked
    // (Effective Java, Item 50).
    List<TensorImage> masksCopy =
        Collections.unmodifiableList(new ArrayList<TensorImage>(masks));
    List<ColoredLabel> coloredLabelsCopy =
        Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels));
    outputType.assertMasksMatchColoredLabels(masksCopy, coloredLabelsCopy);
    return new AutoValue_Segmentation(outputType, masksCopy, coloredLabelsCopy);
  }

  /** Returns the {@link OutputType} that describes how {@link #getMasks()} is to be interpreted. */
  public abstract OutputType getOutputType();

  // As an open source project, we've been trying to avoid depending on common Java libraries,
  // such as Guava, because they may introduce conflicts with clients who also happen to use those
  // libraries. Therefore, instead of using ImmutableList for the two lists below, we convert each
  // List into an unmodifiableList in create() to make it less vulnerable.

  /**
   * Returns the segmentation masks as an unmodifiable list.
   *
   * <p>For {@link OutputType#CATEGORY_MASK} this holds a single category mask. For {@link
   * OutputType#CONFIDENCE_MASK} it holds one confidence mask per colored label, with {@code
   * getMasks().get(i)} associated with {@code getColoredLabels().get(i)}.
   */
  public abstract List<TensorImage> getMasks();

  /**
   * Returns the colored labels as an unmodifiable list. Index {@code i} corresponds to class
   * {@code i} in the masks returned by {@link #getMasks()}.
   */
  public abstract List<ColoredLabel> getColoredLabels();
}