blob: 8a4487d861072ac60509df7e6ab072f4bbd202e6 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.textclassifier;
import static com.google.common.truth.Truth.assertThat;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.expectThrows;
import android.app.RemoteAction;
import android.content.Context;
import android.content.Intent;
import android.net.Uri;
import android.os.Bundle;
import android.os.LocaleList;
import android.text.Spannable;
import android.text.SpannableString;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
import androidx.collection.LruCache;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SdkSuppress;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.testing.FakeContextBuilder;
import com.android.textclassifier.testing.TestingDeviceConfig;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/**
 * Tests for {@link TextClassifierImpl}.
 *
 * <p>{@link ModelFileManager} is a Mockito mock. By default every model lookup resolves to the
 * bundled test models from {@code TestDataUtils}. Individual tests may re-stub it to choose
 * specific model versions.
 */
@SmallTest
@RunWith(AndroidJUnit4.class)
public class TextClassifierImplTest {
  /** Conversation action type string expected for "copy" suggestions. */
  private static final String TYPE_COPY = "copy";

  private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
  /** Passed to {@link #isTextSelection} to skip the entity-type check. */
  private static final String NO_TYPE = null;

  @Mock private ModelFileManager modelFileManager;
  private Context context;
  private TestingDeviceConfig deviceConfig;
  private TextClassifierSettings settings;
  private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
  private TextClassifierImpl classifier;

  @Before
  public void setup() throws IOException {
    MockitoAnnotations.initMocks(this);
    this.context =
        new FakeContextBuilder()
            .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
            .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
            .build();
    this.deviceConfig = new TestingDeviceConfig();
    this.settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false);
    // Capacity of two annotator models. The cache tests below alternate between two model
    // versions and expect no evictions.
    this.annotatorModelCache = new LruCache<>(2);
    this.classifier =
        new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
    // Default stubs: each model type resolves to its bundled test model, whatever the other
    // lookup arguments are.
    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
        .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
    when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
        .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
    when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
        .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
  }

  @Test
  public void testSuggestSelection() throws IOException {
    String text = "Contact me at droid@android.com";
    String selected = "droid";
    String suggested = "droid@android.com";
    int startIndex = text.indexOf(selected);
    int endIndex = startIndex + selected.length();
    int smartStartIndex = text.indexOf(suggested);
    int smartEndIndex = smartStartIndex + suggested.length();
    TextSelection.Request request =
        new TextSelection.Request.Builder(text, startIndex, endIndex).build();

    TextSelection selection = classifier.suggestSelection(null, null, request);

    assertThat(
        selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
  }

  @Test
  public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
    String text = "Contact me at droid@android.com";
    String selected = "droid";
    int startIndex = text.indexOf(selected);
    int endIndex = startIndex + selected.length();
    TextSelection.Request request =
        new TextSelection.Request.Builder(text, startIndex, endIndex)
            .setDefaultLocales(LOCALES)
            .build();

    classifier.suggestSelection(null, null, request);

    // Only the model lookup matters here, not the selection result.
    verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
  }

  @Test
  public void testSuggestSelection_url() throws IOException {
    String text = "Visit http://www.android.com for more information";
    String selected = "http";
    String suggested = "http://www.android.com";
    int startIndex = text.indexOf(selected);
    int endIndex = startIndex + selected.length();
    int smartStartIndex = text.indexOf(suggested);
    int smartEndIndex = smartStartIndex + suggested.length();
    TextSelection.Request request =
        new TextSelection.Request.Builder(text, startIndex, endIndex).build();

    TextSelection selection = classifier.suggestSelection(null, null, request);

    assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
  }

  @Test
  public void testSmartSelection_withEmoji() throws IOException {
    String text = "\uD83D\uDE02 Hello.";
    String selected = "Hello";
    int startIndex = text.indexOf(selected);
    int endIndex = startIndex + selected.length();
    TextSelection.Request request =
        new TextSelection.Request.Builder(text, startIndex, endIndex).build();

    TextSelection selection = classifier.suggestSelection(null, null, request);

    // The selection must stay on "Hello", which comes after the surrogate-pair emoji.
    assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
  }

  @SdkSuppress(minSdkVersion = 31, codeName = "S")
  @Test
  public void testSuggestSelection_includeTextClassification() throws IOException {
    String text = "Visit http://www.android.com for more information";
    String suggested = "http://www.android.com";
    int startIndex = text.indexOf(suggested);
    TextSelection.Request request =
        new TextSelection.Request.Builder(text, startIndex, /* endIndex= */ startIndex + 1)
            .setIncludeTextClassification(true)
            .build();

    TextSelection selection = classifier.suggestSelection(null, null, request);

    assertThat(
        selection.getTextClassification(),
        isTextClassification(suggested, TextClassifier.TYPE_URL));
    assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW));
  }

  @SdkSuppress(minSdkVersion = 31, codeName = "S")
  @Test
  public void testSuggestSelection_notIncludeTextClassification() throws IOException {
    String text = "Visit http://www.android.com for more information";
    TextSelection.Request request =
        new TextSelection.Request.Builder(text, /* startIndex= */ 0, /* endIndex= */ 4)
            .setIncludeTextClassification(false)
            .build();

    TextSelection selection = classifier.suggestSelection(null, null, request);

    assertThat(selection.getTextClassification()).isNull();
  }

  @Test
  public void testClassifyText() throws IOException {
    String text = "Contact me at droid@android.com";
    String classifiedText = "droid@android.com";
    int startIndex = text.indexOf(classifiedText);
    int endIndex = startIndex + classifiedText.length();
    TextClassification.Request request =
        new TextClassification.Request.Builder(text, startIndex, endIndex).build();

    TextClassification classification =
        classifier.classifyText(/* sessionId= */ null, null, request);

    assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
  }

  @Test
  public void testClassifyText_url() throws IOException {
    String text = "Visit www.android.com for more information";
    String classifiedText = "www.android.com";
    int startIndex = text.indexOf(classifiedText);
    int endIndex = startIndex + classifiedText.length();
    TextClassification.Request request =
        new TextClassification.Request.Builder(text, startIndex, endIndex).build();

    TextClassification classification = classifier.classifyText(null, null, request);

    assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
    assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
  }

  @Test
  public void testClassifyText_address() throws IOException {
    String text = "Brandschenkestrasse 110, Zürich, Switzerland";
    TextClassification.Request request =
        new TextClassification.Request.Builder(text, 0, text.length()).build();

    TextClassification classification = classifier.classifyText(null, null, request);

    assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
  }

  @Test
  public void testClassifyText_url_inCaps() throws IOException {
    String text = "Visit HTTP://ANDROID.COM for more information";
    String classifiedText = "HTTP://ANDROID.COM";
    int startIndex = text.indexOf(classifiedText);
    int endIndex = startIndex + classifiedText.length();
    TextClassification.Request request =
        new TextClassification.Request.Builder(text, startIndex, endIndex).build();

    TextClassification classification = classifier.classifyText(null, null, request);

    assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
    assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
  }

  @Test
  public void testClassifyText_date() throws IOException {
    String text = "Let's meet on January 9, 2018.";
    String classifiedText = "January 9, 2018";
    int startIndex = text.indexOf(classifiedText);
    int endIndex = startIndex + classifiedText.length();
    TextClassification.Request request =
        new TextClassification.Request.Builder(text, startIndex, endIndex).build();

    TextClassification classification = classifier.classifyText(null, null, request);

    assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
    Bundle extras = classification.getExtras();
    List<Bundle> entities = ExtrasUtils.getEntities(extras);
    assertThat(entities).hasSize(1);
    assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE);
    ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification);
    actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras);
  }

  @Test
  public void testClassifyText_datetime() throws IOException {
    String text = "Let's meet 2018/01/01 10:30:20.";
    String classifiedText = "2018/01/01 10:30:20";
    int startIndex = text.indexOf(classifiedText);
    int endIndex = startIndex + classifiedText.length();
    TextClassification.Request request =
        new TextClassification.Request.Builder(text, startIndex, endIndex).build();

    TextClassification classification = classifier.classifyText(null, null, request);

    assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
  }

  @Test
  public void testClassifyText_foreignText() throws IOException {
    LocaleList originalLocales = LocaleList.getDefault();
    // Force an English default locale so the Japanese input is foreign to the device. This
    // changes process-wide state, so it is restored in the finally block. Otherwise a failing
    // assertion would leak the locale into later tests.
    LocaleList.setDefault(LocaleList.forLanguageTags("en"));
    try {
      String japaneseText = "これは日本語のテキストです";
      TextClassification.Request request =
          new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();

      TextClassification classification = classifier.classifyText(null, null, request);

      // Check the size before indexing, so an empty list gives an assertion failure rather than
      // an IndexOutOfBoundsException.
      assertEquals(1, classification.getActions().size());
      RemoteAction translateAction = classification.getActions().get(0);
      assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
      assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
      Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
      assertNoPackageInfoInExtras(intent);
      assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
      Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
      assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
      assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
      assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
      assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
      assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first);
    } finally {
      LocaleList.setDefault(originalLocales);
    }
  }

  @Test
  public void testGenerateLinks_phone() throws IOException {
    String text = "The number is +12122537077. See you tonight!";
    TextLinks.Request request = new TextLinks.Request.Builder(text).build();

    assertThat(
        classifier.generateLinks(null, null, request),
        isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
  }

  @Test
  public void testGenerateLinks_exclude() throws IOException {
    String text = "The number is +12122537077. See you tonight!";
    List<String> hints = ImmutableList.of();
    List<String> included = ImmutableList.of();
    List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
    TextLinks.Request request =
        new TextLinks.Request.Builder(text)
            .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
            .build();

    assertThat(
        classifier.generateLinks(null, null, request),
        not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
  }

  @Test
  public void testGenerateLinks_explicit_address() throws IOException {
    String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
    List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
    TextLinks.Request request =
        new TextLinks.Request.Builder(text)
            .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
            .build();

    assertThat(
        classifier.generateLinks(null, null, request),
        isTextLinksContaining(
            text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
  }

  @Test
  public void testGenerateLinks_exclude_override() throws IOException {
    String text = "You want apple@banana.com. See you tonight!";
    List<String> hints = ImmutableList.of();
    // The same type is both included and excluded; the test expects the exclusion to win.
    List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
    List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
    TextLinks.Request request =
        new TextLinks.Request.Builder(text)
            .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
            .build();

    assertThat(
        classifier.generateLinks(null, null, request),
        not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
  }

  @Test
  public void testGenerateLinks_maxLength() throws IOException {
    // Input of exactly the maximum allowed length must be accepted.
    char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
    Arrays.fill(manySpaces, ' ');
    TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();

    TextLinks links = classifier.generateLinks(null, null, request);

    assertThat(links.getLinks()).isEmpty();
  }

  @Test
  public void testApplyLinks_unsupportedCharacter() throws IOException {
    // U+202E (RIGHT-TO-LEFT OVERRIDE) makes the text render as a different URL than it contains.
    Spannable url = new SpannableString("\u202Emoc.diordna.com");
    TextLinks.Request request = new TextLinks.Request.Builder(url).build();

    assertEquals(
        TextLinks.STATUS_UNSUPPORTED_CHARACTER,
        classifier.generateLinks(null, null, request).apply(url, 0, null));
  }

  @Test
  public void testGenerateLinks_tooLong() {
    // One character over the limit must be rejected.
    char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
    Arrays.fill(manySpaces, ' ');
    TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();

    expectThrows(
        IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request));
  }

  @Test
  public void testGenerateLinks_entityData() throws IOException {
    String text = "The number is +12122537077.";
    Bundle extras = new Bundle();
    ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
    TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();

    TextLinks textLinks = classifier.generateLinks(null, null, request);

    assertThat(textLinks.getLinks()).hasSize(1);
    TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
    List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
    assertThat(entities).hasSize(1);
    Bundle entity = entities.get(0);
    assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
  }

  @Test
  public void testGenerateLinks_entityData_disabled() throws IOException {
    String text = "The number is +12122537077.";
    TextLinks.Request request = new TextLinks.Request.Builder(text).build();

    TextLinks textLinks = classifier.generateLinks(null, null, request);

    assertThat(textLinks.getLinks()).hasSize(1);
    TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
    List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
    assertThat(entities).isNull();
  }

  @Test
  public void testDetectLanguage() throws IOException {
    String text = "This is English text";
    TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();

    TextLanguage textLanguage = classifier.detectLanguage(null, null, request);

    assertThat(textLanguage, isTextLanguage("en"));
  }

  @Test
  public void testDetectLanguage_japanese() throws IOException {
    String text = "これは日本語のテキストです";
    TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();

    TextLanguage textLanguage = classifier.detectLanguage(null, null, request);

    assertThat(textLanguage, isTextLanguage("ja"));
  }

  @Test
  public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
    ConversationActions.Message message =
        new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
            .setText("Where are you?")
            .build();
    TextClassifier.EntityConfig typeConfig =
        new TextClassifier.EntityConfig.Builder()
            .includeTypesFromTextClassifier(false)
            .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
            .build();
    ConversationActions.Request request =
        new ConversationActions.Request.Builder(Collections.singletonList(message))
            .setMaxSuggestions(1)
            .setTypeConfig(typeConfig)
            .build();

    ConversationActions conversationActions =
        classifier.suggestConversationActions(null, null, request);

    assertThat(conversationActions.getConversationActions()).hasSize(1);
    ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
    assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
    assertThat(conversationAction.getTextReply()).isNotNull();
  }

  @Test
  public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
    ConversationActions.Message message =
        new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
            .setText("Where are you?")
            .build();
    TextClassifier.EntityConfig typeConfig =
        new TextClassifier.EntityConfig.Builder()
            .includeTypesFromTextClassifier(false)
            .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
            .build();
    ConversationActions.Request request =
        new ConversationActions.Request.Builder(Collections.singletonList(message))
            .setTypeConfig(typeConfig)
            .build();

    ConversationActions conversationActions =
        classifier.suggestConversationActions(null, null, request);

    List<ConversationAction> actions = conversationActions.getConversationActions();
    // Without setMaxSuggestions, more than one reply is expected.
    assertThat(actions.size()).isGreaterThan(1);
    for (ConversationAction conversationAction : actions) {
      assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
    }
  }

  @Test
  public void testSuggestConversationActions_openUrl() throws IOException {
    ConversationActions.Message message =
        new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
            .setText("Check this out: https://www.android.com")
            .build();
    TextClassifier.EntityConfig typeConfig =
        new TextClassifier.EntityConfig.Builder()
            .includeTypesFromTextClassifier(false)
            .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
            .build();
    ConversationActions.Request request =
        new ConversationActions.Request.Builder(Collections.singletonList(message))
            .setMaxSuggestions(1)
            .setTypeConfig(typeConfig)
            .build();

    ConversationActions conversationActions =
        classifier.suggestConversationActions(null, null, request);

    assertThat(conversationActions.getConversationActions()).hasSize(1);
    ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
    assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
    Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
    assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
    assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
    assertNoPackageInfoInExtras(actionIntent);
  }

  @Test
  public void testSuggestConversationActions_copy() throws IOException {
    ConversationActions.Message message =
        new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
            .setText("Authentication code: 12345")
            .build();
    TextClassifier.EntityConfig typeConfig =
        new TextClassifier.EntityConfig.Builder()
            .includeTypesFromTextClassifier(false)
            .setIncludedTypes(Collections.singletonList(TYPE_COPY))
            .build();
    ConversationActions.Request request =
        new ConversationActions.Request.Builder(Collections.singletonList(message))
            .setMaxSuggestions(1)
            .setTypeConfig(typeConfig)
            .build();

    ConversationActions conversationActions =
        classifier.suggestConversationActions(null, null, request);

    assertThat(conversationActions.getConversationActions()).hasSize(1);
    ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
    assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
    assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
    assertThat(conversationAction.getAction()).isNull();
    String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
    assertThat(code).isEqualTo("12345");
    assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
  }

  @Test
  public void testSuggestConversationActions_deduplicate() throws IOException {
    ConversationActions.Message message =
        new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
            .setText("a@android.com b@android.com")
            .build();
    ConversationActions.Request request =
        new ConversationActions.Request.Builder(Collections.singletonList(message))
            .setMaxSuggestions(3)
            .build();

    ConversationActions conversationActions =
        classifier.suggestConversationActions(null, null, request);

    // The message contains two email addresses, and the test expects no actions to be returned.
    assertThat(conversationActions.getConversationActions()).isEmpty();
  }

  @Test
  public void testUseCachedAnnotatorModelDisabled() throws IOException {
    // MULTI_ANNOTATOR_CACHE_ENABLED is not set in this test. The assertions below expect the
    // annotator model cache to be left untouched (all counters zero).
    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
    String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
    ModelFile annotatorModelA =
        new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
    ModelFile annotatorModelB =
        new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
    String englishText = "You can reach me on +12122537077.";
    String classifiedText = "+12122537077";
    TextClassification.Request request =
        new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();

    // Check modelFileA v701
    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
        .thenReturn(annotatorModelA);
    TextClassification classificationA = classifier.classifyText(null, null, request);
    assertThat(classificationA.getId()).contains("v701");
    assertThat(classificationA.getText()).contains(classifiedText);
    assertAnnotatorModelCacheStats(
        /* putCount= */ 0, /* evictionCount= */ 0, /* hitCount= */ 0, /* missCount= */ 0);

    // Check modelFileB v801
    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
        .thenReturn(annotatorModelB);
    TextClassification classificationB = classifier.classifyText(null, null, request);
    assertThat(classificationB.getId()).contains("v801");
    assertThat(classificationB.getText()).contains(classifiedText);
    assertAnnotatorModelCacheStats(
        /* putCount= */ 0, /* evictionCount= */ 0, /* hitCount= */ 0, /* missCount= */ 0);

    // Reload modelFileA v701
    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
        .thenReturn(annotatorModelA);
    TextClassification classificationAcached = classifier.classifyText(null, null, request);
    assertThat(classificationAcached.getId()).contains("v701");
    assertThat(classificationAcached.getText()).contains(classifiedText);
    assertAnnotatorModelCacheStats(
        /* putCount= */ 0, /* evictionCount= */ 0, /* hitCount= */ 0, /* missCount= */ 0);
  }

  @Test
  public void testUseCachedAnnotatorModelEnabled() throws IOException {
    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
    deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true);
    String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
    ModelFile annotatorModelA =
        new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
    ModelFile annotatorModelB =
        new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
    String englishText = "You can reach me on +12122537077.";
    String classifiedText = "+12122537077";
    TextClassification.Request request =
        new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();

    // Check modelFileA v701: first load is a miss followed by a put.
    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
        .thenReturn(annotatorModelA);
    TextClassification classification = classifier.classifyText(null, null, request);
    assertThat(classification.getId()).contains("v701");
    assertThat(classification.getText()).contains(classifiedText);
    assertAnnotatorModelCacheStats(
        /* putCount= */ 1, /* evictionCount= */ 0, /* hitCount= */ 0, /* missCount= */ 1);

    // Check modelFileB v801: a different model, so another miss and put.
    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
        .thenReturn(annotatorModelB);
    TextClassification classificationB = classifier.classifyText(null, null, request);
    assertThat(classificationB.getId()).contains("v801");
    assertThat(classificationB.getText()).contains(classifiedText);
    assertAnnotatorModelCacheStats(
        /* putCount= */ 2, /* evictionCount= */ 0, /* hitCount= */ 0, /* missCount= */ 2);

    // Reload modelFileA v701: served from the cache (one hit, no new put).
    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
        .thenReturn(annotatorModelA);
    TextClassification classificationAcached = classifier.classifyText(null, null, request);
    assertThat(classificationAcached.getId()).contains("v701");
    assertThat(classificationAcached.getText()).contains(classifiedText);
    assertAnnotatorModelCacheStats(
        /* putCount= */ 2, /* evictionCount= */ 0, /* hitCount= */ 1, /* missCount= */ 2);
  }

  /**
   * Asserts all four {@link #annotatorModelCache} counters at once, so a mismatch reports every
   * value together.
   */
  private void assertAnnotatorModelCacheStats(
      int putCount, int evictionCount, int hitCount, int missCount) {
    assertArrayEquals(
        new int[] {putCount, evictionCount, hitCount, missCount},
        new int[] {
          annotatorModelCache.putCount(),
          annotatorModelCache.evictionCount(),
          annotatorModelCache.hitCount(),
          annotatorModelCache.missCount()
        });
  }

  /** Asserts that {@code intent} has no explicit component and no target package set. */
  private static void assertNoPackageInfoInExtras(Intent intent) {
    assertThat(intent.getComponent()).isNull();
    assertThat(intent.getPackage()).isNull();
  }

  /**
   * Matches a {@link TextSelection} with the given bounds. When {@code type} is non-null, the top
   * entity must also equal it (after trimming, case-insensitively).
   */
  private static Matcher<TextSelection> isTextSelection(
      final int startIndex, final int endIndex, final String type) {
    return new BaseMatcher<TextSelection>() {
      @Override
      public boolean matches(Object o) {
        if (o instanceof TextSelection) {
          TextSelection selection = (TextSelection) o;
          return startIndex == selection.getSelectionStartIndex()
              && endIndex == selection.getSelectionEndIndex()
              && typeMatches(selection, type);
        }
        return false;
      }

      private boolean typeMatches(TextSelection selection, String type) {
        return type == null
            || (selection.getEntityCount() > 0
                && type.trim().equalsIgnoreCase(selection.getEntity(0)));
      }

      @Override
      public void describeTo(Description description) {
        description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type));
      }
    };
  }

  /**
   * Matches {@link TextLinks} whose first link covering exactly {@code substring} of {@code text}
   * has {@code type} as its top entity.
   */
  private static Matcher<TextLinks> isTextLinksContaining(
      final String text, final String substring, final String type) {
    return new BaseMatcher<TextLinks>() {

      @Override
      public void describeTo(Description description) {
        description
            .appendText("text=")
            .appendValue(text)
            .appendText(", substring=")
            .appendValue(substring)
            .appendText(", type=")
            .appendValue(type);
      }

      @Override
      public boolean matches(Object o) {
        if (o instanceof TextLinks) {
          for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
            if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) {
              // Guard the index so a link without entities is a mismatch, not an exception.
              return link.getEntityCount() > 0 && type.equals(link.getEntity(0));
            }
          }
        }
        return false;
      }
    };
  }

  /** Matches a {@link TextClassification} of exactly {@code text} whose top entity is {@code type}. */
  private static Matcher<TextClassification> isTextClassification(
      final String text, final String type) {
    return new BaseMatcher<TextClassification>() {
      @Override
      public boolean matches(Object o) {
        if (o instanceof TextClassification) {
          TextClassification result = (TextClassification) o;
          return text.equals(result.getText())
              && result.getEntityCount() > 0
              && type.equals(result.getEntity(0));
        }
        return false;
      }

      @Override
      public void describeTo(Description description) {
        description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type);
      }
    };
  }

  /** Matches a {@link TextClassification} for which {@code ExtrasUtils.findAction} finds {@code action}. */
  private static Matcher<TextClassification> containsIntentWithAction(final String action) {
    return new BaseMatcher<TextClassification>() {
      @Override
      public boolean matches(Object o) {
        if (o instanceof TextClassification) {
          TextClassification result = (TextClassification) o;
          return ExtrasUtils.findAction(result, action) != null;
        }
        return false;
      }

      @Override
      public void describeTo(Description description) {
        description.appendText("intent action=").appendValue(action);
      }
    };
  }

  /** Matches a {@link TextLanguage} whose top locale hypothesis has tag {@code languageTag}. */
  private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
    return new BaseMatcher<TextLanguage>() {
      @Override
      public boolean matches(Object o) {
        if (o instanceof TextLanguage) {
          TextLanguage result = (TextLanguage) o;
          return result.getLocaleHypothesisCount() > 0
              && languageTag.equals(result.getLocale(0).toLanguageTag());
        }
        return false;
      }

      @Override
      public void describeTo(Description description) {
        description.appendText("locale=").appendValue(languageTag);
      }
    };
  }

  /**
   * Matches a {@link ConversationAction} of type {@code actionType}. The confidence score must be
   * in [0, 1], and text-reply actions must carry a non-null reply.
   */
  private static Matcher<ConversationAction> isConversationAction(String actionType) {
    return new BaseMatcher<ConversationAction>() {
      @Override
      public boolean matches(Object o) {
        if (!(o instanceof ConversationAction)) {
          return false;
        }
        ConversationAction conversationAction = (ConversationAction) o;
        if (!actionType.equals(conversationAction.getType())) {
          return false;
        }
        if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)
            && conversationAction.getTextReply() == null) {
          return false;
        }
        float score = conversationAction.getConfidenceScore();
        return score >= 0 && score <= 1;
      }

      @Override
      public void describeTo(Description description) {
        description.appendText("actionType=").appendValue(actionType);
      }
    };
  }
}