blob: 0e3842c3b50c118d197c2791b23587a7547c8afd [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.textclassifier;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.verify;
import android.content.Context;
import android.os.CancellationSignal;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextLinks.TextLink;
import android.view.textclassifier.TextSelection;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.internal.os.StatsdConfigProto.StatsdConfig;
import com.android.os.AtomsProto;
import com.android.os.AtomsProto.Atom;
import com.android.os.AtomsProto.TextClassifierApiUsageReported;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.statsd.StatsdTestUtils;
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
import com.android.textclassifier.downloader.ModelDownloadManager;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class DefaultTextClassifierServiceTest {
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
private static final long SHORT_TIMEOUT_MS = 1000;
private static final String SESSION_ID = "abcdef";
private TestInjector testInjector;
private DefaultTextClassifierService defaultTextClassifierService;
@Mock private TextClassifierService.Callback<TextClassification> textClassificationCallback;
@Mock private TextClassifierService.Callback<TextSelection> textSelectionCallback;
@Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
@Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
@Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
defaultTextClassifierService.onCreate();
}
@Before
public void setupStatsdTestUtils() throws Exception {
StatsdTestUtils.cleanup(CONFIG_ID);
StatsdConfig.Builder builder =
StatsdConfig.newBuilder()
.setId(CONFIG_ID)
.addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_CLASSIFIER_API_USAGE_REPORTED_FIELD_NUMBER);
StatsdTestUtils.pushConfig(builder.build());
}
@After
public void tearDown() throws Exception {
StatsdTestUtils.cleanup(CONFIG_ID);
}
@Test
public void classifyText_success() throws Exception {
String text = "www.android.com";
TextClassification.Request request =
new TextClassification.Request.Builder(text, 0, text.length()).build();
defaultTextClassifierService.onClassifyText(
TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textClassificationCallback);
ArgumentCaptor<TextClassification> captor = ArgumentCaptor.forClass(TextClassification.class);
verify(textClassificationCallback).onSuccess(captor.capture());
assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.SUCCESS);
}
@Test
public void suggestSelection_success() throws Exception {
String text = "Visit http://www.android.com for more information";
String selected = "http";
String suggested = "http://www.android.com";
int start = text.indexOf(selected);
int end = start + suggested.length();
TextSelection.Request request = new TextSelection.Request.Builder(text, start, end).build();
defaultTextClassifierService.onSuggestSelection(
TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textSelectionCallback);
ArgumentCaptor<TextSelection> captor = ArgumentCaptor.forClass(TextSelection.class);
verify(textSelectionCallback).onSuccess(captor.capture());
assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
verifyApiUsageLog(ApiType.SUGGEST_SELECTION, ResultType.SUCCESS);
}
@Test
public void generateLinks_success() throws Exception {
String text = "Visit http://www.android.com for more information";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
defaultTextClassifierService.onGenerateLinks(
TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textLinksCallback);
ArgumentCaptor<TextLinks> captor = ArgumentCaptor.forClass(TextLinks.class);
verify(textLinksCallback).onSuccess(captor.capture());
assertThat(captor.getValue().getLinks()).hasSize(1);
TextLink textLink = captor.getValue().getLinks().iterator().next();
assertThat(textLink.getEntityCount()).isGreaterThan(0);
assertThat(textLink.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
verifyApiUsageLog(ApiType.GENERATE_LINKS, ResultType.SUCCESS);
}
@Test
public void detectLanguage_success() throws Exception {
String text = "ピカチュウ";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
defaultTextClassifierService.onDetectLanguage(
TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textLanguageCallback);
ArgumentCaptor<TextLanguage> captor = ArgumentCaptor.forClass(TextLanguage.class);
verify(textLanguageCallback).onSuccess(captor.capture());
assertThat(captor.getValue().getLocaleHypothesisCount()).isGreaterThan(0);
assertThat(captor.getValue().getLocale(0).toLanguageTag()).isEqualTo("ja");
verifyApiUsageLog(ApiType.DETECT_LANGUAGES, ResultType.SUCCESS);
}
@Test
public void suggestConversationActions_success() throws Exception {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Checkout www.android.com")
.build();
ConversationActions.Request request =
new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
defaultTextClassifierService.onSuggestConversationActions(
TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
conversationActionsCallback);
ArgumentCaptor<ConversationActions> captor = ArgumentCaptor.forClass(ConversationActions.class);
verify(conversationActionsCallback).onSuccess(captor.capture());
List<ConversationAction> conversationActions = captor.getValue().getConversationActions();
assertThat(conversationActions.size()).isGreaterThan(0);
assertThat(conversationActions.get(0).getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
verifyApiUsageLog(ApiType.SUGGEST_CONVERSATION_ACTIONS, ResultType.SUCCESS);
}
@Test
public void missingModelFile_onFailureShouldBeCalled() throws Exception {
testInjector.setModelFileManager(
new ModelFileManagerImpl(
ApplicationProvider.getApplicationContext(),
ImmutableList.of(),
testInjector.createTextClassifierSettings()));
defaultTextClassifierService.onCreate();
TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
defaultTextClassifierService.onClassifyText(
TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textClassificationCallback);
verify(textClassificationCallback).onFailure(Mockito.anyString());
verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.FAIL);
}
private static void verifyApiUsageLog(
AtomsProto.TextClassifierApiUsageReported.ApiType expectedApiType,
AtomsProto.TextClassifierApiUsageReported.ResultType expectedResultApiType)
throws Exception {
ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
ImmutableList<TextClassifierApiUsageReported> loggedEvents =
ImmutableList.copyOf(
loggedAtoms.stream()
.map(Atom::getTextClassifierApiUsageReported)
.collect(Collectors.toList()));
assertThat(loggedEvents).hasSize(1);
TextClassifierApiUsageReported loggedEvent = loggedEvents.get(0);
assertThat(loggedEvent.getLatencyMillis()).isGreaterThan(0L);
assertThat(loggedEvent.getApiType()).isEqualTo(expectedApiType);
assertThat(loggedEvent.getResultType()).isEqualTo(expectedResultApiType);
assertThat(loggedEvent.getSessionId()).isEqualTo(SESSION_ID);
}
private static final class TestInjector implements DefaultTextClassifierService.Injector {
private final Context context;
private ModelFileManager modelFileManager;
private TestInjector(Context context) {
this.context = Preconditions.checkNotNull(context);
}
private void setModelFileManager(ModelFileManager modelFileManager) {
this.modelFileManager = modelFileManager;
}
@Override
public Context getContext() {
return context;
}
@Override
public ModelFileManager createModelFileManager(
TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
if (modelFileManager == null) {
return TestDataUtils.createModelFileManagerForTesting(context);
}
return modelFileManager;
}
@Override
public TextClassifierSettings createTextClassifierSettings() {
return new TextClassifierSettings();
}
@Override
public TextClassifierImpl createTextClassifierImpl(
TextClassifierSettings settings, ModelFileManager modelFileManager) {
return new TextClassifierImpl(context, settings, modelFileManager);
}
@Override
public ListeningExecutorService createNormPriorityExecutor() {
return MoreExecutors.newDirectExecutorService();
}
@Override
public ListeningExecutorService createLowPriorityExecutor() {
return MoreExecutors.newDirectExecutorService();
}
@Override
public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
TextClassifierSettings settings, Executor executor) {
return new TextClassifierApiUsageLogger(
/* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
}
}
}