blob: d31fe6e108cc6ef38dca061bfa28851e189e13cf [file] [log] [blame]
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.view.textclassifier.cts;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.LocaleList;
import android.os.Parcel;
import android.os.Parcelable;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
import androidx.core.os.BuildCompat;
import androidx.test.InstrumentationRegistry;
import androidx.test.filters.SmallTest;
import com.google.common.collect.Range;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
/**
 * Smoke tests for the public {@link TextClassifier} API.
 *
 * <p>The class is parameterized: every test runs once for each classifier flavour returned by
 * {@link #textClassifierTypes()}. {@link #setup()} maps the flavour name to the concrete
 * {@link TextClassifier} instance under test. The tests only check that results are structurally
 * valid (non-null, indices and confidence scores within range); they do not check which entities
 * a particular classifier detects.
 */
@SmallTest
@RunWith(Parameterized.class)
public class TextClassifierTest {
    private static final String BUNDLE_KEY = "key";
    private static final String BUNDLE_VALUE = "value";
    // Extras attached to the language detection request.
    private static final Bundle BUNDLE = new Bundle();
    static {
        BUNDLE.putString(BUNDLE_KEY, BUNDLE_VALUE);
    }

    private static final LocaleList LOCALES = LocaleList.forLanguageTags("en");
    // Character range of TEXT passed to the selection / classification requests.
    private static final int START = 1;
    private static final int END = 3;
    // This text has lots of things that are probably entities in many cases.
    private static final String TEXT = "An email address is test@example.com. A phone number"
            + " might be +12122537077. Somebody lives at 123 Main Street, Mountain View, CA,"
            + " and there's good stuff at https://www.android.com :)";

    private static final TextSelection.Request TEXT_SELECTION_REQUEST =
            new TextSelection.Request.Builder(TEXT, START, END)
                    .setDefaultLocales(LOCALES)
                    .build();
    private static final TextClassification.Request TEXT_CLASSIFICATION_REQUEST =
            new TextClassification.Request.Builder(TEXT, START, END)
                    .setDefaultLocales(LOCALES)
                    .build();
    private static final TextLanguage.Request TEXT_LANGUAGE_REQUEST =
            new TextLanguage.Request.Builder(TEXT)
                    .setExtras(BUNDLE)
                    .build();
    private static final ConversationActions.Message FIRST_MESSAGE =
            new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_SELF)
                    .setText(TEXT)
                    .build();
    private static final ConversationActions.Message SECOND_MESSAGE =
            new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
                    .setText(TEXT)
                    .build();
    private static final ConversationActions.Request CONVERSATION_ACTIONS_REQUEST =
            new ConversationActions.Request.Builder(
                    Arrays.asList(FIRST_MESSAGE, SECOND_MESSAGE)).build();

    // Names of the classifier flavours the test is parameterized over.
    private static final String CURRENT = "current";
    private static final String SESSION = "session";
    private static final String DEFAULT = "default";
    private static final String NO_OP = "no_op";

    /** Returns the classifier flavour names; each one yields a separate run of every test. */
    @Parameterized.Parameters(name = "{0}")
    public static Iterable<Object> textClassifierTypes() {
        return Arrays.asList(CURRENT, SESSION, DEFAULT, NO_OP);
    }

    /** Flavour name injected by the {@link Parameterized} runner. */
    @Parameterized.Parameter
    public String mTextClassifierType;

    /** Classifier under test; assigned in {@link #setup()}. */
    private TextClassifier mClassifier;

    /** Resets the manager's classifier and selects the instance for the current flavour. */
    @Before
    public void setup() {
        TextClassificationManager manager = InstrumentationRegistry.getTargetContext()
                .getSystemService(TextClassificationManager.class);
        manager.setTextClassifier(null); // Resets the classifier.
        switch (mTextClassifierType) {
            case CURRENT:
                mClassifier = manager.getTextClassifier();
                break;
            case SESSION:
                mClassifier = manager.createTextClassificationSession(
                        new TextClassificationContext.Builder(
                                InstrumentationRegistry.getTargetContext().getPackageName(),
                                TextClassifier.WIDGET_TYPE_TEXTVIEW)
                                .build());
                break;
            case NO_OP:
                mClassifier = TextClassifier.NO_OP;
                break;
            case DEFAULT:
            default:
                // Any other value is treated like DEFAULT.
                mClassifier = TextClassifierService.getDefaultTextClassifierImplementation(
                        InstrumentationRegistry.getTargetContext());
                break;
        }
    }

    /** Destroys the classifier used by the test, if one was created. */
    @After
    public void tearDown() {
        // setup() may have failed before assigning the classifier; do not pile an NPE on top of
        // the original failure.
        if (mClassifier != null) {
            mClassifier.destroy();
        }
    }

    /** Destroying a classifier must not crash; a session must report itself as destroyed. */
    @Test
    public void testTextClassifierDestroy() {
        mClassifier.destroy();
        if (mTextClassifierType.equals(SESSION)) {
            assertTrue(mClassifier.isDestroyed());
        }
    }

    /** The maximum text length accepted by generateLinks must be non-negative. */
    @Test
    @Ignore("b/315110905")
    public void testGetMaxGenerateLinksTextLength() {
        // TODO(b/143249163): Verify the value obtained from TextClassificationConstants.
        assertTrue(mClassifier.getMaxGenerateLinksTextLength() >= 0);
    }

    /** {@code suggestSelection(Request)} returns a structurally valid selection. */
    @Test
    public void testSmartSelection() {
        assertValidResult(mClassifier.suggestSelection(TEXT_SELECTION_REQUEST));
    }

    /** The four-argument {@code suggestSelection} overload returns a valid selection. */
    @Test
    public void testSuggestSelectionWith4Param() {
        assertValidResult(mClassifier.suggestSelection(TEXT, START, END, LOCALES));
    }

    /** {@code classifyText(Request)} returns a structurally valid classification. */
    @Test
    public void testClassifyText() {
        assertValidResult(mClassifier.classifyText(TEXT_CLASSIFICATION_REQUEST));
    }

    /** The four-argument {@code classifyText} overload returns a valid classification. */
    @Test
    public void testClassifyTextWith4Param() {
        assertValidResult(mClassifier.classifyText(TEXT, START, END, LOCALES));
    }

    /** {@code generateLinks} returns structurally valid links for the sample text. */
    @Test
    public void testGenerateLinks() {
        assertValidResult(mClassifier.generateLinks(new TextLinks.Request.Builder(TEXT).build()));
    }

    /** {@code suggestConversationActions} returns structurally valid actions. */
    @Test
    public void testSuggestConversationActions() {
        ConversationActions conversationActions =
                mClassifier.suggestConversationActions(CONVERSATION_ACTIONS_REQUEST);
        assertValidResult(conversationActions);
    }

    /** {@code detectLanguage} returns a structurally valid result. */
    @Test
    public void testLanguageDetection() {
        assertValidResult(mClassifier.detectLanguage(TEXT_LANGUAGE_REQUEST));
    }

    /** A null language detection request is expected to raise a {@link RuntimeException}. */
    @Test(expected = RuntimeException.class)
    public void testLanguageDetection_nullRequest() {
        assertValidResult(mClassifier.detectLanguage(null));
    }

    @Test
    public void testOnSelectionEvent() {
        // Doesn't crash.
        mClassifier.onSelectionEvent(
                SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, 0));
    }

    @Test
    public void testOnTextClassifierEvent() {
        // Doesn't crash.
        mClassifier.onTextClassifierEvent(
                new TextClassifierEvent.ConversationActionsEvent.Builder(
                        TextClassifierEvent.TYPE_SMART_ACTION)
                        .build());
    }

    /** A hints-only config keeps its hint and leaves the supplied entity list unchanged. */
    @Test
    public void testResolveEntityListModifications_only_hints() {
        TextClassifier.EntityConfig entityConfig = TextClassifier.EntityConfig.createWithHints(
                Arrays.asList("some_hint"));
        assertEquals(1, entityConfig.getHints().size());
        assertTrue(entityConfig.getHints().contains("some_hint"));
        // The API returns a Collection; copy into a set so the comparison does not depend on
        // the concrete collection type returned.
        assertEquals(new HashSet<String>(Arrays.asList("foo", "bar")),
                new HashSet<String>(entityConfig.resolveEntityListModifications(
                        Arrays.asList("foo", "bar"))));
    }

    /** Included entities are added to, and excluded entities removed from, the supplied list. */
    @Test
    public void testResolveEntityListModifications_include_exclude() {
        TextClassifier.EntityConfig entityConfig = TextClassifier.EntityConfig.create(
                Arrays.asList("some_hint"),
                Arrays.asList("a", "b", "c"),
                Arrays.asList("b", "d", "x"));
        assertEquals(1, entityConfig.getHints().size());
        assertTrue(entityConfig.getHints().contains("some_hint"));
        assertEquals(new HashSet<String>(Arrays.asList("a", "c", "w")),
                new HashSet<String>(entityConfig.resolveEntityListModifications(
                        Arrays.asList("c", "w", "x"))));
    }

    /** An explicit entity list has no hints and replaces whatever list is supplied. */
    @Test
    public void testResolveEntityListModifications_explicit() {
        TextClassifier.EntityConfig entityConfig =
                TextClassifier.EntityConfig.createWithExplicitEntityList(Arrays.asList("a", "b"));
        assertEquals(Collections.emptyList(), entityConfig.getHints());
        assertEquals(new HashSet<String>(Arrays.asList("a", "b")),
                new HashSet<String>(entityConfig.resolveEntityListModifications(
                        Arrays.asList("w", "x"))));
    }

    /** A config with included and excluded types behaves the same after a parcel round trip. */
    @Test
    public void testEntityConfig_full() {
        TextClassifier.EntityConfig entityConfig =
                new TextClassifier.EntityConfig.Builder()
                        .setIncludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
                        .setExcludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_CALL_PHONE))
                        .build();
        TextClassifier.EntityConfig recovered =
                parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);
        assertFullEntityConfig(entityConfig);
        assertFullEntityConfig(recovered);
    }

    /**
     * Same as {@link #testEntityConfig_full()} but with
     * {@code includeTypesFromTextClassifier(false)}.
     */
    @Test
    public void testEntityConfig_full_notIncludeTypesFromTextClassifier() {
        TextClassifier.EntityConfig entityConfig =
                new TextClassifier.EntityConfig.Builder()
                        .includeTypesFromTextClassifier(false)
                        .setIncludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
                        .setExcludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_CALL_PHONE))
                        .build();
        TextClassifier.EntityConfig recovered =
                parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);
        assertFullEntityConfig_notIncludeTypesFromTextClassifier(entityConfig);
        assertFullEntityConfig_notIncludeTypesFromTextClassifier(recovered);
    }

    /** A config built with no options behaves the same after a parcel round trip. */
    @Test
    public void testEntityConfig_minimal() {
        TextClassifier.EntityConfig entityConfig =
                new TextClassifier.EntityConfig.Builder().build();
        TextClassifier.EntityConfig recovered =
                parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);
        assertMinimalEntityConfig(entityConfig);
        assertMinimalEntityConfig(recovered);
    }

    /** Asserts that a confidence score lies in the closed interval [0, 1]. */
    private static void assertValidConfidenceScore(float confidenceScore) {
        assertTrue("Confidence score out of [0, 1]: " + confidenceScore,
                confidenceScore >= 0 && confidenceScore <= 1);
    }

    /**
     * Asserts that a selection is non-null, has a non-empty, non-negative range and that every
     * entity has a confidence score in [0, 1]. On S and above it additionally asserts that the
     * selection carries no {@link TextClassification}.
     */
    private static void assertValidResult(TextSelection selection) {
        assertNotNull(selection);
        assertTrue(selection.getSelectionStartIndex() >= 0);
        assertTrue(selection.getSelectionEndIndex() > selection.getSelectionStartIndex());
        assertTrue(selection.getEntityCount() >= 0);
        for (int i = 0; i < selection.getEntityCount(); i++) {
            final String entity = selection.getEntity(i);
            assertNotNull(entity);
            assertValidConfidenceScore(selection.getConfidenceScore(entity));
        }
        if (BuildCompat.isAtLeastS()) {
            assertThat(selection.getTextClassification()).isNull();
        }
    }

    /**
     * Asserts that a classification is non-null, that every entity has a confidence score in
     * [0, 1], and that its action list is non-null.
     */
    private static void assertValidResult(TextClassification classification) {
        assertNotNull(classification);
        assertTrue(classification.getEntityCount() >= 0);
        for (int i = 0; i < classification.getEntityCount(); i++) {
            final String entity = classification.getEntity(i);
            assertNotNull(entity);
            assertValidConfidenceScore(classification.getConfidenceScore(entity));
        }
        assertNotNull(classification.getActions());
    }

    /**
     * Asserts that every link has at least one entity, a non-negative start not greater than
     * its end, and confidence scores in [0, 1].
     */
    private static void assertValidResult(TextLinks links) {
        assertNotNull(links);
        for (TextLinks.TextLink link : links.getLinks()) {
            assertTrue(link.getEntityCount() > 0);
            assertTrue(link.getStart() >= 0);
            assertTrue(link.getStart() <= link.getEnd());
            for (int i = 0; i < link.getEntityCount(); i++) {
                final String entityType = link.getEntity(i);
                assertNotNull(entityType);
                assertValidConfidenceScore(link.getConfidenceScore(entityType));
            }
        }
    }

    /**
     * Asserts that a language result is non-null, has non-null extras, and that every locale
     * hypothesis is non-null with a confidence score in [0, 1].
     */
    private static void assertValidResult(TextLanguage language) {
        assertNotNull(language);
        assertNotNull(language.getExtras());
        assertTrue(language.getLocaleHypothesisCount() >= 0);
        for (int i = 0; i < language.getLocaleHypothesisCount(); i++) {
            final ULocale locale = language.getLocale(i);
            assertNotNull(locale);
            assertValidConfidenceScore(language.getConfidenceScore(locale));
        }
    }

    /**
     * Asserts that the result and its action list are non-null and that every action has a
     * non-null type and a confidence score in [0, 1].
     */
    private static void assertValidResult(ConversationActions conversationActions) {
        assertNotNull(conversationActions);
        List<ConversationAction> conversationActionsList =
                conversationActions.getConversationActions();
        assertNotNull(conversationActionsList);
        for (ConversationAction conversationAction : conversationActionsList) {
            assertThat(conversationAction.getType()).isNotNull();
            assertThat(conversationAction.getConfidenceScore()).isIn(Range.closed(0f, 1.0f));
        }
    }

    /**
     * Checks the config built by {@link #testEntityConfig_full_notIncludeTypesFromTextClassifier}:
     * classifier-supplied types are ignored, so only the explicitly included type is resolved.
     */
    private static void assertFullEntityConfig_notIncludeTypesFromTextClassifier(
            TextClassifier.EntityConfig typeConfig) {
        List<String> extraTypesFromTextClassifier = Arrays.asList(
                ConversationAction.TYPE_CALL_PHONE,
                ConversationAction.TYPE_CREATE_REMINDER);
        Collection<String> resolvedTypes =
                typeConfig.resolveEntityListModifications(extraTypesFromTextClassifier);
        assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isFalse();
        assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList()))
                .containsExactly(ConversationAction.TYPE_OPEN_URL);
        assertThat(resolvedTypes).containsExactly(ConversationAction.TYPE_OPEN_URL);
    }

    /**
     * Checks the config built by {@link #testEntityConfig_full}: the included type is always
     * present, the excluded type is dropped, and other classifier-supplied types are kept.
     */
    private static void assertFullEntityConfig(TextClassifier.EntityConfig typeConfig) {
        List<String> extraTypesFromTextClassifier = Arrays.asList(
                ConversationAction.TYPE_CALL_PHONE,
                ConversationAction.TYPE_CREATE_REMINDER);
        Collection<String> resolvedTypes =
                typeConfig.resolveEntityListModifications(extraTypesFromTextClassifier);
        assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isTrue();
        assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList()))
                .containsExactly(ConversationAction.TYPE_OPEN_URL);
        assertThat(resolvedTypes).containsExactly(
                ConversationAction.TYPE_OPEN_URL, ConversationAction.TYPE_CREATE_REMINDER);
    }

    /**
     * Checks the config built by {@link #testEntityConfig_minimal}: classifier types are
     * included and the supplied list is passed through unchanged.
     */
    private static void assertMinimalEntityConfig(TextClassifier.EntityConfig typeConfig) {
        assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isTrue();
        assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList())).isEmpty();
        assertThat(typeConfig.resolveEntityListModifications(
                Collections.singletonList(ConversationAction.TYPE_OPEN_URL))).containsExactly(
                ConversationAction.TYPE_OPEN_URL);
    }

    /**
     * Writes {@code parcelable} to a {@link Parcel} and reads it back through {@code creator}.
     *
     * @param parcelable the object to round-trip
     * @param creator the creator used to rebuild the object from the parcel
     * @return the object recreated from the parcel
     */
    private static <T extends Parcelable> T parcelizeDeparcelize(
            T parcelable, Parcelable.Creator<T> creator) {
        Parcel parcel = Parcel.obtain();
        try {
            parcelable.writeToParcel(parcel, 0);
            // Rewind so the creator reads from the beginning of what was just written.
            parcel.setDataPosition(0);
            return creator.createFromParcel(parcel);
        } finally {
            // Return the parcel to the pool even if (de)serialization throws.
            parcel.recycle();
        }
    }
}