blob: 0f1ff81b480f43ff6e9f9e288d23535a4694b1c6 [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.apppredictionservice.cts;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import android.app.prediction.AppPredictionContext;
import android.app.prediction.AppPredictionSessionId;
import android.app.prediction.AppPredictor;
import android.app.prediction.AppTarget;
import android.app.prediction.AppTargetEvent;
import android.app.prediction.AppTargetId;
import android.os.Binder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
/**
 * Reports calls from the CTS prediction service back to the tests.
 *
 * <p>Each {@code onXxx} method records what was reported and counts down a one-shot latch. The
 * matching {@code awaitOnXxx} method blocks on that latch for up to {@code AWAIT_TIMEOUT_MS}
 * milliseconds and then installs a fresh latch, so that the next occurrence of the same call can
 * be awaited again.
 */
public class ServiceReporter extends Binder {

    /** Maximum time, in milliseconds, that any await method in this file blocks. */
    private static final long AWAIT_TIMEOUT_MS = 500;

    /** Sessions reported as created and not yet reported as destroyed, keyed by session id. */
    public HashMap<AppPredictionSessionId, AppPredictionContext> mSessions = new HashMap<>();
    /** Every event passed to {@link #onAppTargetEvent}, in arrival order. */
    public ArrayList<AppTargetEvent> mEvents = new ArrayList<>();
    /** Launch location from the most recent {@link #onLocationShown} call. */
    public String mLocationsShown;
    /** Target ids accumulated over all {@link #onLocationShown} calls. */
    public ArrayList<AppTargetId> mLocationsShownTargets = new ArrayList<>();
    /** Number of {@link #onRequestPredictionUpdate} calls seen so far. */
    public int mNumRequestedUpdates = 0;
    /** Set by {@link #onStartPredictionUpdates}, cleared by {@link #onStopPredictionUpdates}. */
    public boolean mPredictionUpdatesStarted = false;

    // The latches are replaced by the awaiting thread and counted down by the reporting thread,
    // so the references are volatile to make a replacement visible to the other side.
    private volatile CountDownLatch mCreateSessionLatch = new CountDownLatch(1);
    private volatile CountDownLatch mEventLatch = new CountDownLatch(1);
    private volatile CountDownLatch mLocationShownLatch = new CountDownLatch(1);
    private volatile CountDownLatch mSortLatch = new CountDownLatch(1);
    private volatile CountDownLatch mStartPredictionUpdatesLatch = new CountDownLatch(1);
    private volatile CountDownLatch mStopPredictionUpdatesLatch = new CountDownLatch(1);
    private volatile CountDownLatch mPredictionUpdateLatch = new CountDownLatch(1);
    private volatile CountDownLatch mDestroyLatch = new CountDownLatch(1);

    private PredictionsProvider mPredictionsProvider;
    private SortedPredictionsProvider mSortedPredictionsProvider;

    /** Sets the provider returned by {@link #getPredictionsProvider()}; may be {@code null}. */
    void setPredictionsProvider(PredictionsProvider cb) {
        mPredictionsProvider = cb;
    }

    /** Returns the provider last passed to {@link #setPredictionsProvider}, or {@code null}. */
    PredictionsProvider getPredictionsProvider() {
        return mPredictionsProvider;
    }

    /** Sets the provider returned by {@link #getSortedPredictionsProvider()}. */
    void setSortedPredictionsProvider(SortedPredictionsProvider cb) {
        mSortedPredictionsProvider = cb;
    }

    /** Returns the provider last passed to {@link #setSortedPredictionsProvider}. */
    SortedPredictionsProvider getSortedPredictionsProvider() {
        return mSortedPredictionsProvider;
    }

    /** Asserts that {@code sessionId} has been reported as created and not yet destroyed. */
    void assertActiveSession(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
    }

    /**
     * Returns the context recorded for {@code sessionId}, asserting first that the session is
     * active.
     */
    AppPredictionContext getPredictionContext(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
        return mSessions.get(sessionId);
    }

    /**
     * Records a newly created session. Asserts that both arguments are non-null and that the
     * session id has not already been recorded.
     */
    void onCreatePredictionSession(AppPredictionContext context,
            AppPredictionSessionId sessionId) {
        assertNotNull(context);
        assertNotNull(sessionId);
        assertFalse(mSessions.containsKey(sessionId));
        mSessions.put(sessionId, context);
        mCreateSessionLatch.countDown();
    }

    /**
     * Waits for {@link #onCreatePredictionSession} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnCreatePredictionSession() {
        try {
            return await(mCreateSessionLatch);
        } finally {
            mCreateSessionLatch = new CountDownLatch(1);
        }
    }

    /** Records {@code event} after asserting that {@code sessionId} is active. */
    void onAppTargetEvent(AppPredictionSessionId sessionId, AppTargetEvent event) {
        assertTrue(mSessions.containsKey(sessionId));
        mEvents.add(event);
        mEventLatch.countDown();
    }

    /**
     * Waits for {@link #onAppTargetEvent} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnAppTargetEvent() {
        try {
            return await(mEventLatch);
        } finally {
            mEventLatch = new CountDownLatch(1);
        }
    }

    /**
     * Records the launch location (overwriting the previous one) and appends {@code targetIds}
     * to {@link #mLocationsShownTargets}, after asserting that {@code sessionId} is active.
     */
    void onLocationShown(AppPredictionSessionId sessionId, String launchLocation,
            List<AppTargetId> targetIds) {
        assertTrue(mSessions.containsKey(sessionId));
        mLocationsShown = launchLocation;
        mLocationsShownTargets.addAll(targetIds);
        mLocationShownLatch.countDown();
    }

    /**
     * Waits for {@link #onLocationShown} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnLocationShown() {
        try {
            return await(mLocationShownLatch);
        } finally {
            mLocationShownLatch = new CountDownLatch(1);
        }
    }

    /**
     * Notes that a sort request arrived. Asserts that the session is active and that
     * {@code targets} and {@code callback} are non-null; the callback itself is not invoked here.
     */
    void onSortAppTargets(AppPredictionSessionId sessionId, List<AppTarget> targets,
            Consumer<List<AppTarget>> callback) {
        assertTrue(mSessions.containsKey(sessionId));
        assertNotNull(targets);
        assertNotNull(callback);
        mSortLatch.countDown();
    }

    /**
     * Waits for {@link #onSortAppTargets} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnSortAppTargets() {
        try {
            return await(mSortLatch);
        } finally {
            mSortLatch = new CountDownLatch(1);
        }
    }

    /** Marks prediction updates as started and releases any waiter. */
    void onStartPredictionUpdates() {
        mPredictionUpdatesStarted = true;
        // Without this count-down awaitOnStartPredictionUpdates() could never observe the call.
        mStartPredictionUpdatesLatch.countDown();
    }

    /**
     * Waits for {@link #onStartPredictionUpdates} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnStartPredictionUpdates() {
        try {
            return await(mStartPredictionUpdatesLatch);
        } finally {
            mStartPredictionUpdatesLatch = new CountDownLatch(1);
        }
    }

    /** Marks prediction updates as stopped and releases any waiter. */
    void onStopPredictionUpdates() {
        mPredictionUpdatesStarted = false;
        // Without this count-down awaitOnStopPredictionUpdates() could never observe the call.
        mStopPredictionUpdatesLatch.countDown();
    }

    /**
     * Waits for {@link #onStopPredictionUpdates} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnStopPredictionUpdates() {
        try {
            return await(mStopPredictionUpdatesLatch);
        } finally {
            mStopPredictionUpdatesLatch = new CountDownLatch(1);
        }
    }

    /** Counts an update request after asserting that {@code sessionId} is active. */
    void onRequestPredictionUpdate(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
        mNumRequestedUpdates++;
        mPredictionUpdateLatch.countDown();
    }

    /**
     * Waits for {@link #onRequestPredictionUpdate} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnRequestPredictionUpdate() {
        try {
            return await(mPredictionUpdateLatch);
        } finally {
            mPredictionUpdateLatch = new CountDownLatch(1);
        }
    }

    /** Forgets {@code sessionId} after asserting that it is currently active. */
    void onDestroyPredictionSession(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
        mSessions.remove(sessionId);
        mDestroyLatch.countDown();
    }

    /**
     * Waits for {@link #onDestroyPredictionSession} to be called.
     *
     * @return {@code true} if the call was seen before the timeout, {@code false} on timeout or
     *         interruption
     */
    boolean awaitOnDestroyPredictionSession() {
        try {
            return await(mDestroyLatch);
        } finally {
            mDestroyLatch = new CountDownLatch(1);
        }
    }

    /**
     * Simple holder for a target together with a launch location and an event type.
     *
     * <p>NOTE(review): nothing in this file uses this class. It is kept as a (non-static) inner
     * class so that any existing {@code reporter.new Event(...)} callers keep compiling.
     */
    public class Event {
        final AppTarget target;
        final int launchLocation;
        final int eventType;

        public Event(AppTarget target, int launchLocation, int eventType) {
            this.target = target;
            this.launchLocation = launchLocation;
            this.eventType = eventType;
        }
    }

    /**
     * Blocks until {@code latch} reaches zero or the timeout elapses.
     *
     * @return {@code true} only if the latch reached zero in time; {@code false} if the wait
     *         timed out or the thread was interrupted (the interrupt status is restored)
     */
    private boolean await(CountDownLatch latch) {
        try {
            // Propagate the latch result: a timeout must be reported as a failure.
            return latch.await(AWAIT_TIMEOUT_MS, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /**
     * Supplies a fixed list of targets through the reporter's {@link PredictionsProvider} hook
     * and verifies that the same list is later delivered to this object as a callback.
     */
    public static class RequestVerifier implements AppPredictor.Callback, PredictionsProvider,
            Consumer<List<AppTarget>> {
        private final ServiceReporter mReporter;
        // Written by the requesting thread and by the delivering callback; volatile so that each
        // side sees the other's latest value.
        private volatile CountDownLatch mReceivedLatch;
        private volatile List<AppTarget> mTargets;

        public RequestVerifier(ServiceReporter reporter) {
            mReporter = reporter;
            mReceivedLatch = new CountDownLatch(1);
        }

        /** Returns the targets currently held by this verifier; {@code sessionId} is ignored. */
        @Override
        public List<AppTarget> getTargets(AppPredictionSessionId sessionId) {
            return mTargets;
        }

        /**
         * Receives delivered targets. If an expected list is already set the delivered list must
         * equal it; otherwise the delivered list is stored for {@link #awaitTargets} to check.
         */
        @Override
        public void onTargetsAvailable(List<AppTarget> targets) {
            if (mTargets != null) {
                // Verify that the targets match (expected first, actual second).
                assertEquals(mTargets, targets);
            } else {
                // For the case where we didn't setup the request, save the targets so we can
                // verify them in awaitTargets()
                mTargets = targets;
            }
            mReceivedLatch.countDown();
        }

        /** {@link Consumer} entry point; delegates to {@link #onTargetsAvailable}. */
        @Override
        public void accept(List<AppTarget> appTargets) {
            onTargetsAvailable(appTargets);
        }

        /**
         * Installs this verifier as the reporter's predictions provider, triggers the request
         * and waits for {@code targets} to be delivered back. The provider is cleared again
         * before returning.
         *
         * @param targets the targets to serve and to expect in the callback
         * @param requestUpdateCb Callback called when the request is setup
         * @return {@code true} if the targets were delivered before the timeout
         */
        boolean requestAndWaitForTargets(List<AppTarget> targets, Runnable requestUpdateCb) {
            mTargets = targets;
            mReceivedLatch = new CountDownLatch(1);
            mReporter.setPredictionsProvider(this);
            requestUpdateCb.run();
            try {
                return awaitTargets(targets);
            } finally {
                mReporter.setPredictionsProvider(null);
            }
        }

        /**
         * Waits for a delivery and asserts that the held targets equal {@code targets}.
         *
         * @return {@code true} if a delivery arrived before the timeout; {@code false} on
         *         timeout or interruption (the interrupt status is restored)
         */
        boolean awaitTargets(List<AppTarget> targets) {
            try {
                boolean result = mReceivedLatch.await(AWAIT_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                assertEquals(targets, mTargets);
                return result;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
    }
}