Implement a single training run workflow when fcp job triggered
Bug: 276931193
Test: atest
Change-Id: I913671d743082220816a213b1dea20ad78594dd1
diff --git a/federatedcompute/apk/Android.bp b/federatedcompute/apk/Android.bp
index ed5ef23..2c6a970 100644
--- a/federatedcompute/apk/Android.bp
+++ b/federatedcompute/apk/Android.bp
@@ -48,6 +48,7 @@
plugins: ["auto_value_plugin"],
static_libs: [
"flatbuffers-java",
+ "androidx.concurrent_concurrent-futures",
"federated-compute-java-proto-lite",
"guava",
"modules-utils-preconditions",
diff --git a/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java b/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java
index 92ed6ec..f1c830d 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java
@@ -18,25 +18,20 @@
/** Constants used internally in the FederatedCompute APK. */
public class Constants {
- public static final String EXTRA_COLLECTION_NAME = "android.federatedcompute.collection_name";
- public static final String EXTRA_EXAMPLE_ITERATOR_CRITERIA =
- "android.federatedcompute.example_iterator_criteria";
- public static final String EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN =
- "android.federatedcompute.example_iterator_resumption_token";
public static final String EXTRA_EXAMPLE_STORE_ITERATOR_BINDER =
"android.federatedcompute.example_store_iterator_binder";
- public static final String EXTRA_RESULT_HANDLING_SERVICE_BINDER =
- "android.federatedcompute.result_handling_service_binder";
public static final String EXTRA_INPUT_CHECKPOINT_FD =
"android.federatedcompute.input_checkpoint_fd";
public static final String EXTRA_OUTPUT_CHECKPOINT_FD =
"android.federatedcompute.output_checkpoint_fd";
- public static final String EXTRA_POPULATION_NAME = "android.federatedcompute.population_name";
public static final String EXTRA_FL_RUNNER_RESULT = "android.federatedcompute.fl_runner_result";
public static final String EXTRA_JOB_ID = "android.federatedcompute.job_id";
public static final String EXTRA_EXAMPLE_SELECTOR = "android.federatedcompute.example_selector";
public static final String EXTRA_CLIENT_ONLY_PLAN = "android.federatedcompute.client_only_plan";
+ public static final String ISOLATED_TRAINING_SERVICE_NAME =
+ "com.android.federatedcompute.service.training.IsolatedTrainingService";
+
private Constants() {}
}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java b/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java
new file mode 100644
index 0000000..e2b7d94
--- /dev/null
+++ b/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.federatedcompute.services.common;
+
+import android.os.ParcelFileDescriptor;
+
+import com.android.federatedcompute.internal.util.LogUtil;
+
+import java.io.BufferedInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+/** Utils related to {@link File} and {@link ParcelFileDescriptor}. */
+public class FileUtils {
+ private static final String TAG = FileUtils.class.getSimpleName();
+
+ private static final int BUFFER_SIZE = 1024;
+
+ /** Create {@link ParcelFileDescriptor} based on the input file. */
+ public static ParcelFileDescriptor createTempFileDescriptor(String fileName) {
+ ParcelFileDescriptor fileDescriptor;
+ try {
+ fileDescriptor =
+ ParcelFileDescriptor.open(
+ new File(fileName), ParcelFileDescriptor.MODE_READ_ONLY);
+ } catch (IOException e) {
+ LogUtil.e(TAG, e, "Failed to createTempFileDescriptor %s", fileName);
+ throw new RuntimeException(e);
+ }
+ return fileDescriptor;
+ }
+
+ /** Create a temporary file based on provided name and extension. */
+ public static String createTempFile(String name, String extension) {
+ String fileName;
+ try {
+ File tempFile = File.createTempFile(name, extension);
+ fileName = tempFile.getAbsolutePath();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return fileName;
+ }
+
+ /** Write the provided data to the file. */
+ public static void writeToFile(String fileName, byte[] data) throws IOException {
+ FileOutputStream out = new FileOutputStream(fileName);
+ out.write(data);
+ out.close();
+ }
+
+ /** Read the input file content to a byte array. */
+ public static byte[] readFileAsByteArray(String filePath) throws IOException {
+ File file = new File(filePath);
+ long fileLength = file.length();
+ ByteArrayOutputStream outputStream = new ByteArrayOutputStream((int) fileLength);
+ try (BufferedInputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
+ byte[] buffer = new byte[BUFFER_SIZE];
+ for (int len = inputStream.read(buffer); len > 0; len = inputStream.read(buffer)) {
+ outputStream.write(buffer, 0, len);
+ }
+ } catch (IOException e) {
+ LogUtil.e(TAG, e, "Failed to read the content of binary file %s", filePath);
+ throw e;
+ }
+ return outputStream.toByteArray();
+ }
+}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java
index 22dcf9b..76a1b77 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java
@@ -89,8 +89,8 @@
}
/** Returns all recorded {@link ExampleConsumption}. */
- public synchronized List<ExampleConsumption> finishRecordingAndGet() {
- List<ExampleConsumption> exampleConsumptions = new ArrayList<>();
+ public synchronized ArrayList<ExampleConsumption> finishRecordingAndGet() {
+ ArrayList<ExampleConsumption> exampleConsumptions = new ArrayList<>();
for (SingleQueryRecorder recorder : mSingleQueryRecorders) {
exampleConsumptions.add(recorder.finishRecordingAndGet());
}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java
index ccb9df8..4d029b3 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java
@@ -27,8 +27,6 @@
import android.util.Pair;
import com.android.federatedcompute.internal.util.LogUtil;
-import com.android.federatedcompute.services.common.Constants;
-import com.android.federatedcompute.services.common.ErrorStatusException;
import com.android.federatedcompute.services.common.Flags;
import com.google.common.util.concurrent.SettableFuture;
@@ -53,8 +51,7 @@
@Override
public IExampleStoreIterator getExampleStoreIterator(
- String packageName, ExampleSelector exampleSelector)
- throws InterruptedException, ErrorStatusException {
+ String packageName, ExampleSelector exampleSelector) throws InterruptedException {
String collection = exampleSelector.getCollectionUri();
byte[] criteria = exampleSelector.getCriteria().toByteArray();
byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray();
@@ -71,9 +68,10 @@
IExampleStoreService exampleStoreService =
mExampleStoreServiceProvider.getExampleStoreService();
Bundle bundle = new Bundle();
- bundle.putString(Constants.EXTRA_COLLECTION_NAME, collection);
- bundle.putByteArray(Constants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
- bundle.putByteArray(Constants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
+ bundle.putString(ClientConstants.EXTRA_COLLECTION_NAME, collection);
+ bundle.putByteArray(
+ ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
+ bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
SettableFuture<Pair<IExampleStoreIterator, Integer>> iteratorOrFailureFuture =
SettableFuture.create();
try {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java b/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java
index 0e747bf..e82bfe6 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java
@@ -18,6 +18,8 @@
import android.annotation.Nullable;
+import com.android.internal.util.Preconditions;
+
import com.google.internal.federated.plan.ClientOnlyPlan;
import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment;
@@ -25,20 +27,23 @@
* The result after client calls TaskAssignemnt API. It includes init checkpoint data and plan data.
*/
public class CheckinResult {
- private byte[] mCheckpointData = null;
+ private String mInputCheckpoint = null;
private ClientOnlyPlan mPlanData = null;
private TaskAssignment mTaskAssignment = null;
public CheckinResult(
- byte[] checkpointData, ClientOnlyPlan planData, TaskAssignment taskAssignment) {
- this.mCheckpointData = checkpointData;
+ String inputCheckpoint, ClientOnlyPlan planData, TaskAssignment taskAssignment) {
+ this.mInputCheckpoint = inputCheckpoint;
this.mPlanData = planData;
this.mTaskAssignment = taskAssignment;
}
@Nullable
- public byte[] getCheckpointData() {
- return mCheckpointData;
+ public String getInputCheckpointFile() {
+ Preconditions.checkArgument(
+ mInputCheckpoint != null && !mInputCheckpoint.isEmpty(),
+ "Input checkpoint file should not be none or empty");
+ return mInputCheckpoint;
}
@Nullable
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
index a77f804..f622f7f 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
@@ -18,6 +18,9 @@
import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor;
import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getLightweightExecutor;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFile;
+import static com.android.federatedcompute.services.common.FileUtils.readFileAsByteArray;
+import static com.android.federatedcompute.services.common.FileUtils.writeToFile;
import static com.android.federatedcompute.services.http.HttpClientUtil.HTTP_OK_STATUS;
import com.android.federatedcompute.internal.util.LogUtil;
@@ -42,17 +45,11 @@
import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction;
import com.google.protobuf.InvalidProtocolBufferException;
-import java.io.BufferedInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
import java.util.HashMap;
/** Implements a single session of HTTP-based federated compute protocol. */
public final class HttpFederatedProtocol {
public static final String TAG = "HttpFederatedProtocol";
- private static final int BUFFER_SIZE = 1024;
private final String mClientVersion;
private final String mPopulationName;
@@ -204,9 +201,10 @@
return Futures.immediateFailedFuture(
new IllegalStateException("Could not parse ClientOnlyPlan proto", e));
}
+ String inputCheckpointFile = createTempFile("input", ".ckp");
+ writeToFile(inputCheckpointFile, checkpointDataResponse.getPayload());
return Futures.immediateFuture(
- new CheckinResult(
- checkpointDataResponse.getPayload(), clientOnlyPlan, taskAssignment));
+ new CheckinResult(inputCheckpointFile, clientOnlyPlan, taskAssignment));
} catch (Exception e) {
return Futures.immediateFailedFuture(e);
@@ -282,22 +280,6 @@
}
}
- private byte[] readFileAsByteArray(String filePath) throws IOException {
- File file = new File(filePath);
- long fileLength = file.length();
- ByteArrayOutputStream outputStream = new ByteArrayOutputStream((int) fileLength);
- try (BufferedInputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
- byte[] buffer = new byte[BUFFER_SIZE];
- for (int len = inputStream.read(buffer); len > 0; len = inputStream.read(buffer)) {
- outputStream.write(buffer, 0, len);
- }
- } catch (IOException e) {
- LogUtil.e(TAG, e, "Failed to read the content of binary file %s", filePath);
- throw e;
- }
- return outputStream.toByteArray();
- }
-
private ListenableFuture<FederatedComputeHttpResponse> fetchTaskResource(Resource resource) {
switch (resource.getResourceCase()) {
case URI:
diff --git a/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java b/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java
index d2909d2..7d9edf2 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java
@@ -35,7 +35,6 @@
import com.android.federatedcompute.services.common.Flags;
import com.android.federatedcompute.services.common.MonotonicClock;
import com.android.federatedcompute.services.common.PhFlags;
-import com.android.federatedcompute.services.common.TrainingResult;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.FederatedTrainingTaskDao;
import com.android.federatedcompute.services.data.fbs.SchedulingMode;
@@ -46,6 +45,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
import com.google.intelligence.fcp.client.engine.TaskRetry;
import java.util.Arrays;
@@ -55,14 +55,13 @@
/** Handles scheduling training tasks e.g. calling into JobScheduler, maintaining datastore. */
public class FederatedComputeJobManager {
private static final String TAG = FederatedComputeJobManager.class.getSimpleName();
-
+ private static volatile FederatedComputeJobManager sSingletonInstance;
@NonNull private final Context mContext;
private final FederatedTrainingTaskDao mFederatedTrainingTaskDao;
private final JobSchedulerHelper mJobSchedulerHelper;
- private static volatile FederatedComputeJobManager sSingletonInstance;
private final FederatedJobIdGenerator mJobIdGenerator;
- private Clock mClock;
private final Flags mFlags;
+ private final Clock mClock;
@VisibleForTesting
FederatedComputeJobManager(
@@ -72,7 +71,7 @@
JobSchedulerHelper jobSchedulerHelper,
@NonNull Clock clock,
Flags flag) {
- this.mContext = context;
+ this.mContext = context.getApplicationContext();
this.mFederatedTrainingTaskDao = federatedTrainingTaskDao;
this.mJobIdGenerator = jobIdGenerator;
this.mJobSchedulerHelper = jobSchedulerHelper;
@@ -285,7 +284,7 @@
String populationName,
TrainingIntervalOptions trainingIntervalOptions,
TaskRetry taskRetry,
- @TrainingResult int trainingResult) {
+ ContributionResult trainingResult) {
boolean result =
rescheduleFederatedTaskAfterTraining(
jobId, populationName, trainingIntervalOptions, taskRetry, trainingResult);
@@ -300,7 +299,7 @@
String populationName,
TrainingIntervalOptions intervalOptions,
TaskRetry taskRetry,
- @TrainingResult int trainingResult) {
+ ContributionResult trainingResult) {
FederatedTrainingTask existingTask =
mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationAndJobId(
populationName, jobId);
@@ -310,7 +309,7 @@
if (existingTask == null) {
return true;
}
- boolean hasContributed = trainingResult == TrainingResult.SUCCESS;
+ boolean hasContributed = trainingResult == ContributionResult.SUCCESS;
if (intervalOptions != null
&& intervalOptions.schedulingMode() == SchedulingMode.ONE_TIME
&& hasContributed) {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
index 33de4b4..fe45813 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
@@ -16,70 +16,144 @@
package com.android.federatedcompute.services.training;
+import static com.android.federatedcompute.services.common.Constants.ISOLATED_TRAINING_SERVICE_NAME;
+import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor;
+import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getLightweightExecutor;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFile;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFileDescriptor;
+
import android.annotation.NonNull;
import android.annotation.Nullable;
+import android.content.ComponentName;
import android.content.Context;
+import android.content.Intent;
+import android.federatedcompute.aidl.IExampleStoreCallback;
+import android.federatedcompute.aidl.IExampleStoreIterator;
+import android.federatedcompute.aidl.IExampleStoreService;
+import android.federatedcompute.common.ClientConstants;
+import android.federatedcompute.common.ExampleConsumption;
+import android.os.Bundle;
+import android.os.ParcelFileDescriptor;
+import androidx.concurrent.futures.CallbackToFutureAdapter;
+
+import com.android.federatedcompute.internal.util.AbstractServiceBinder;
import com.android.federatedcompute.internal.util.LogUtil;
-import com.android.federatedcompute.services.common.Flags;
-import com.android.federatedcompute.services.common.PhFlags;
-import com.android.federatedcompute.services.common.TrainingResult;
+import com.android.federatedcompute.services.common.Constants;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
+import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
+import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
+import com.android.federatedcompute.services.http.CheckinResult;
+import com.android.federatedcompute.services.http.HttpFederatedProtocol;
import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager;
-import com.android.federatedcompute.services.scheduling.SchedulingUtil;
+import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService;
+import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback;
+import com.android.federatedcompute.services.training.util.ComputationResult;
+import com.android.federatedcompute.services.training.util.ListenableSupplier;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker.Condition;
+import com.android.internal.annotations.GuardedBy;
+import com.android.internal.util.Preconditions;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
+import com.google.intelligence.fcp.client.RetryInfo;
import com.google.intelligence.fcp.client.engine.TaskRetry;
+import com.google.internal.federated.plan.ClientOnlyPlan;
+import com.google.internal.federated.plan.ExampleSelector;
+import com.google.protobuf.InvalidProtocolBufferException;
-import javax.annotation.concurrent.GuardedBy;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
/** The worker to execute federated computation jobs. */
public class FederatedComputeWorker {
private static final String TAG = FederatedComputeWorker.class.getSimpleName();
+ private static volatile FederatedComputeWorker sWorker;
+ private final Object mLock = new Object();
+ private final AtomicBoolean mInterruptFlag = new AtomicBoolean(false);
+ private final ListenableSupplier<Boolean> mInterruptSupplier =
+ new ListenableSupplier<>(mInterruptFlag::get);
+ private final Context mContext;
+ @Nullable private final FederatedComputeJobManager mJobManager;
+ @Nullable private final TrainingConditionsChecker mTrainingConditionsChecker;
+ private final ComputationRunner mComputationRunner;
+ private final ResultCallbackHelper mResultCallbackHelper;
+ @NonNull private final Injector mInjector;
- static final Object LOCK = new Object();
-
- @GuardedBy("LOCK")
+ @GuardedBy("mLock")
@Nullable
private TrainingRun mActiveRun = null;
- @Nullable private final FederatedComputeJobManager mJobManager;
- private static volatile FederatedComputeWorker sFederatedComputeWorker;
- private final Flags mFlags;
+ private HttpFederatedProtocol mHttpFederatedProtocol;
+ private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder;
+ private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder;
@VisibleForTesting
- public FederatedComputeWorker(FederatedComputeJobManager jobManager, Flags flags) {
+ public FederatedComputeWorker(
+ Context context,
+ FederatedComputeJobManager jobManager,
+ TrainingConditionsChecker trainingConditionsChecker,
+ ComputationRunner computationRunner,
+ ResultCallbackHelper resultCallbackHelper,
+ Injector injector) {
+ this.mContext = context.getApplicationContext();
this.mJobManager = jobManager;
- this.mFlags = flags;
+ this.mTrainingConditionsChecker = trainingConditionsChecker;
+ this.mComputationRunner = computationRunner;
+ this.mInjector = injector;
+ this.mResultCallbackHelper = resultCallbackHelper;
}
/** Gets an instance of {@link FederatedComputeWorker}. */
@NonNull
public static FederatedComputeWorker getInstance(Context context) {
- if (sFederatedComputeWorker == null) {
+ if (sWorker == null) {
synchronized (FederatedComputeWorker.class) {
- if (sFederatedComputeWorker == null) {
- sFederatedComputeWorker =
+ if (sWorker == null) {
+ sWorker =
new FederatedComputeWorker(
+ context,
FederatedComputeJobManager.getInstance(context),
- PhFlags.getInstance());
+ TrainingConditionsChecker.getInstance(context),
+ new ComputationRunner(context),
+ new ResultCallbackHelper(context),
+ new Injector());
}
}
}
- return sFederatedComputeWorker;
+ return sWorker;
}
/** Starts a training run with the given job Id. */
- public boolean startTrainingRun(int jobId) {
+ public ListenableFuture<FLRunnerResult> startTrainingRun(int jobId) {
LogUtil.d(TAG, "startTrainingRun()");
FederatedTrainingTask trainingTask = mJobManager.onTrainingStarted(jobId);
if (trainingTask == null) {
LogUtil.i(TAG, "Could not find task to run for job ID %s", jobId);
- return false;
+ return Futures.immediateFuture(null);
}
- synchronized (LOCK) {
- // Only allow one concurrent federated computation job.
+ if (!checkTrainingConditions(trainingTask.getTrainingConstraints())) {
+ mJobManager.onTrainingCompleted(
+ jobId,
+ trainingTask.populationName(),
+ trainingTask.getTrainingIntervalOptions(),
+ /* taskRetry= */ null,
+ ContributionResult.FAIL);
+ LogUtil.i(TAG, "Training conditions not satisfied (before bindService)!");
+ return Futures.immediateFuture(null);
+ }
+
+ synchronized (mLock) {
+ // Only allows one concurrent job running.
if (mActiveRun != null) {
LogUtil.i(
TAG,
@@ -91,55 +165,469 @@
trainingTask.populationName(),
trainingTask.getTrainingIntervalOptions(),
/* taskRetry= */ null,
- TrainingResult.FAIL);
- return false;
+ ContributionResult.FAIL);
+ return Futures.immediateFuture(null);
}
+
TrainingRun run = new TrainingRun(jobId, trainingTask);
- this.mActiveRun = run;
- doTraining(run);
- // TODO: get retry info from federated server.
- TaskRetry taskRetry = SchedulingUtil.generateTransientErrorTaskRetry(mFlags);
- finish(this.mActiveRun, taskRetry, TrainingResult.SUCCESS);
+ mActiveRun = run;
+ ListenableFuture<FLRunnerResult> runCompletedFuture = doTraining(run);
+ run.mFuture = runCompletedFuture;
+ return runCompletedFuture;
}
- return true;
}
- /** Cancels the running job if present. */
- public void cancelActiveRun() {
- LogUtil.d(TAG, "cancelActiveRun()");
- synchronized (LOCK) {
+ private ListenableFuture<FLRunnerResult> doTraining(TrainingRun run) {
+ try {
+ // 1. Communicate with remote federated compute server to start task assignment and
+ // download client plan and initial model checkpoint. Note: use bLocking executors for
+ // http requests.
+ mHttpFederatedProtocol =
+ getHttpFederatedProtocol(run.mTask.serverAddress(), run.mTask.populationName());
+ ListenableFuture<CheckinResult> checkinResultFuture =
+ mHttpFederatedProtocol.issueCheckin();
+
+ // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector.
+ ListenableFuture<IExampleStoreIterator> iteratorFuture =
+ FluentFuture.from(checkinResultFuture)
+ .transform(
+ result -> {
+ // Set active run's task name.
+ String taskName = result.getTaskAssignment().getTaskName();
+ Preconditions.checkArgument(
+ !taskName.isEmpty(),
+ "Task name should not be empty");
+ synchronized (mLock) {
+ mActiveRun.mTaskName = taskName;
+ }
+ return getExampleSelector(result);
+ },
+ getLightweightExecutor())
+ .transformAsync(
+ selector ->
+ getExampleStoreIterator(
+ run,
+ run.mTask.appPackageName(),
+ run.mTaskName,
+ selector),
+ getBackgroundExecutor());
+
+ // 3. Run federated learning or federated analytic depends on task type. Federated
+ // learning job will start a new isolated process to run TFLite training.
+ ListenableFuture<ComputationResult> computationResultFuture =
+ Futures.whenAllSucceed(checkinResultFuture, iteratorFuture)
+ .callAsync(
+ () ->
+ runFederatedComputation(
+ Futures.getDone(checkinResultFuture),
+ run,
+ Futures.getDone(iteratorFuture)),
+ getBackgroundExecutor());
+
+ // 4. Report computation result to federated compute server.
+ ListenableFuture<Void> reportToServerFuture =
+ FluentFuture.from(computationResultFuture)
+ .transformAsync(
+ result -> mHttpFederatedProtocol.reportResult(result),
+ getLightweightExecutor());
+ return FluentFuture.from(
+ Futures.whenAllSucceed(reportToServerFuture, computationResultFuture)
+ .call(
+ () -> {
+ ComputationResult result =
+ Futures.getDone(computationResultFuture);
+ var reportToServer = Futures.getDone(reportToServerFuture);
+ // 5. Publish computation result and consumed examples to
+ // client implemented ResultHandlingService.
+ var unused =
+ mResultCallbackHelper.callHandleResult(
+ run.mTaskName, run.mTask, result);
+ return result.getFlRunnerResult();
+ },
+ getBackgroundExecutor()));
+ } catch (Exception e) {
+ return Futures.immediateFailedFuture(e);
+ }
+ }
+
+ /**
+ * Completes the running job , schedule recurrent job, and unbind from ExampleStoreService and
+ * ResultHandlingService etc.
+ */
+ public void finish(FLRunnerResult flRunnerResult) {
+ TaskRetry taskRetry = null;
+ if (flRunnerResult != null) {
+ if (flRunnerResult.hasRetryInfo()) {
+ RetryInfo retryInfo = flRunnerResult.getRetryInfo();
+ long delay = retryInfo.getMinimumDelay().getSeconds() * 1000L;
+ taskRetry =
+ TaskRetry.newBuilder()
+ .setRetryToken(retryInfo.getRetryToken())
+ .setDelayMin(delay)
+ .setDelayMax(delay)
+ .build();
+ LogUtil.i(TAG, "Finished with task retry= %s", taskRetry);
+ }
+ }
+ finish(taskRetry, flRunnerResult.getContributionResult(), true);
+ }
+
+ /**
+ * Cancel the current running job, schedule recurrent job, unbind from ExampleStoreService and
+ * ResultHandlingService etc.
+ */
+ public void finish(
+ TaskRetry taskRetry, ContributionResult contributionResult, boolean cancelFuture) {
+ TrainingRun runToFinish;
+ synchronized (mLock) {
if (mActiveRun == null) {
return;
}
- finish(mActiveRun, /* taskRetry= */ null, TrainingResult.FAIL);
- }
- }
- private void finish(
- TrainingRun runToFinish, TaskRetry taskRetry, @TrainingResult int trainingResult) {
- synchronized (LOCK) {
- if (mActiveRun != runToFinish) {
- return;
- }
+ runToFinish = mActiveRun;
mActiveRun = null;
- mJobManager.onTrainingCompleted(
- runToFinish.mJobId,
- runToFinish.mTask.populationName(),
- runToFinish.mTask.getTrainingIntervalOptions(),
- taskRetry,
- trainingResult);
+ if (cancelFuture) {
+ runToFinish.mFuture.cancel(true);
+ }
+ }
+
+ unBindServicesIfNecessary(runToFinish);
+ mJobManager.onTrainingCompleted(
+ runToFinish.mJobId,
+ runToFinish.mTask.populationName(),
+ runToFinish.mTask.getTrainingIntervalOptions(),
+ taskRetry,
+ contributionResult);
+ }
+
+ private void unBindServicesIfNecessary(TrainingRun runToFinish) {
+ if (runToFinish.mIsolatedTrainingService != null) {
+ LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
+ unbindFromIsolatedTrainingService();
+ runToFinish.mIsolatedTrainingService = null;
+ }
+ if (runToFinish.mExampleStoreService != null) {
+ LogUtil.i(TAG, "Unbinding from ExampleStoreService");
+ unbindFromExampleStoreService();
+ runToFinish.mExampleStoreService = null;
}
}
- private void doTraining(TrainingRun run) {
- // TODO: add training logic.
- LogUtil.d(TAG, "Start run training job %d ", run.mJobId);
+ @VisibleForTesting
+ HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) {
+ return HttpFederatedProtocol.create(serverAddress, "1.0", populationName);
+ }
+
+ private ExampleSelector getExampleSelector(CheckinResult checkinResult) {
+ ClientOnlyPlan clientPlan = checkinResult.getPlanData();
+ switch (clientPlan.getPhase().getSpecCase()) {
+ case EXAMPLE_QUERY_SPEC:
+ // Only support one FA query for now.
+ return clientPlan
+ .getPhase()
+ .getExampleQuerySpec()
+ .getExampleQueries(0)
+ .getExampleSelector();
+ case TENSORFLOW_SPEC:
+ return clientPlan.getPhase().getTensorflowSpec().getExampleSelector();
+ default:
+ throw new IllegalArgumentException(
+ String.format(
+ "Client plan spec is not supported %s",
+ clientPlan.getPhase().getSpecCase().toString()));
+ }
+ }
+
+ private boolean checkTrainingConditions(TrainingConstraints constraints) {
+ Set<Condition> conditions =
+ mTrainingConditionsChecker.checkAllConditionsForFlTraining(constraints);
+ for (Condition condition : conditions) {
+ switch (condition) {
+ case THERMALS_NOT_OK:
+ LogUtil.e(TAG, "training job service interrupt thermals not ok");
+ break;
+ case BATTERY_NOT_OK:
+ LogUtil.e(TAG, "training job service interrupt battery not ok");
+ break;
+ }
+ }
+ return conditions.isEmpty();
+ }
+
+ @VisibleForTesting
+ ListenableFuture<ComputationResult> runFlComputation(
+ TrainingRun run,
+ CheckinResult checkinResult,
+ String outputCheckpointFile,
+ IExampleStoreIterator iterator) {
+ ParcelFileDescriptor outputCheckpointFd = createTempFileDescriptor(outputCheckpointFile);
+ ParcelFileDescriptor inputCheckpointFd =
+ createTempFileDescriptor(checkinResult.getInputCheckpointFile());
+ try {
+ IIsolatedTrainingService trainingService = getIsolatedTrainingService();
+ if (trainingService == null) {
+ LogUtil.w(TAG, "Could not bind to IsolatedTrainingService");
+ throw new IllegalStateException("Could not bind to IsolatedTrainingService");
+ }
+ run.mIsolatedTrainingService = trainingService;
+
+ Bundle bundle = new Bundle();
+ bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName());
+ bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, inputCheckpointFd);
+ bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, outputCheckpointFd);
+ bundle.putInt(Constants.EXTRA_JOB_ID, run.mJobId);
+ bundle.putBinder(Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, iterator.asBinder());
+ return FluentFuture.from(runIsolatedTrainingProcess(run, bundle))
+ .transform(
+ result -> {
+ ComputationResult computationResult =
+ processIsolatedTrainingResult(outputCheckpointFile, result);
+ // Close opened file descriptor.
+ try {
+ if (outputCheckpointFd != null) {
+ outputCheckpointFd.close();
+ }
+ if (inputCheckpointFd != null) {
+ inputCheckpointFd.close();
+ }
+ } catch (IOException e) {
+ LogUtil.e(TAG, "Failed to close file descriptor", e);
+ } finally {
+ // Unbind from IsolatedTrainingService.
+ LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
+ unbindFromIsolatedTrainingService();
+ run.mIsolatedTrainingService = null;
+ }
+ return computationResult;
+ },
+ getLightweightExecutor());
+ } catch (Exception e) {
+ // Close opened file descriptor.
+ try {
+ if (outputCheckpointFd != null) {
+ outputCheckpointFd.close();
+ }
+ if (inputCheckpointFd != null) {
+ inputCheckpointFd.close();
+ }
+ } catch (IOException t) {
+ LogUtil.e(TAG, t, "Failed to close file descriptor");
+ } finally {
+ // Unbind from IsolatedTrainingService.
+ LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
+ unbindFromIsolatedTrainingService();
+ run.mIsolatedTrainingService = null;
+ }
+ return Futures.immediateFailedFuture(e);
+ }
+ }
+
+ private ComputationResult processIsolatedTrainingResult(
+ String outputCheckpoint, Bundle result) {
+ byte[] resultBytes =
+ Objects.requireNonNull(result.getByteArray(Constants.EXTRA_FL_RUNNER_RESULT));
+ FLRunnerResult flRunnerResult;
+ try {
+ flRunnerResult = FLRunnerResult.parseFrom(resultBytes);
+ } catch (InvalidProtocolBufferException e) {
+ throw new IllegalArgumentException(e);
+ }
+ ArrayList<ExampleConsumption> exampleList =
+ result.getParcelableArrayList(
+ ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, ExampleConsumption.class);
+ if (exampleList == null || exampleList.isEmpty()) {
+ throw new IllegalArgumentException("example consumption list should not be empty");
+ }
+
+ return new ComputationResult(outputCheckpoint, flRunnerResult, exampleList);
+ }
+
+ private ListenableFuture<Bundle> runIsolatedTrainingProcess(TrainingRun run, Bundle input) {
+ return CallbackToFutureAdapter.getFuture(
+ completer -> {
+ try {
+ run.mIsolatedTrainingService.runFlTraining(
+ input,
+ new ITrainingResultCallback.Stub() {
+ @Override
+ public void onResult(Bundle result) {
+ completer.set(result);
+ }
+ });
+ } catch (Exception e) {
+ completer.setException(e);
+ }
+ return "runIsolatedTrainingProcess";
+ });
+ }
+
+ private ListenableFuture<ComputationResult> runFederatedComputation(
+ CheckinResult checkinResult, TrainingRun run, IExampleStoreIterator iterator) {
+ ClientOnlyPlan clientPlan = checkinResult.getPlanData();
+ String outputCheckpointFile = createTempFile("output", ".ckp");
+
+ ListenableFuture<ComputationResult> computationResultFuture;
+ switch (clientPlan.getPhase().getSpecCase()) {
+ case EXAMPLE_QUERY_SPEC:
+ computationResultFuture =
+ runFAComputation(run, checkinResult, outputCheckpointFile, iterator);
+ break;
+ case TENSORFLOW_SPEC:
+ computationResultFuture =
+ runFlComputation(run, checkinResult, outputCheckpointFile, iterator);
+ break;
+ default:
+ return Futures.immediateFailedFuture(
+ new IllegalArgumentException(
+ String.format(
+ "Client plan spec is not supported %s",
+ clientPlan.getPhase().getSpecCase().toString())));
+ }
+ return computationResultFuture;
+ }
+
+ private ListenableFuture<ComputationResult> runFAComputation(
+ TrainingRun run,
+ CheckinResult checkinResult,
+ String outputCheckpointFile,
+ IExampleStoreIterator exampleStoreIterator) {
+ ExampleSelector exampleSelector = getExampleSelector(checkinResult);
+ ClientOnlyPlan clientPlan = checkinResult.getPlanData();
+ // The federated analytic runs in main process which has permission to file system.
+ ExampleConsumptionRecorder recorder = mInjector.getExampleConsumptionRecorder();
+ FLRunnerResult runResult =
+ mComputationRunner.runTaskWithNativeRunner(
+ run.mJobId,
+ run.mTask.populationName(),
+ checkinResult.getInputCheckpointFile(),
+ outputCheckpointFile,
+ clientPlan,
+ exampleSelector,
+ recorder,
+ exampleStoreIterator,
+ mInterruptSupplier);
+ ArrayList<ExampleConsumption> exampleConsumptions = recorder.finishRecordingAndGet();
+ return Futures.immediateFuture(
+ new ComputationResult(outputCheckpointFile, runResult, exampleConsumptions));
+ }
+
+ @VisibleForTesting
+ IExampleStoreService getExampleStoreService(String packageName) {
+ mExampleStoreServiceBinder =
+ AbstractServiceBinder.getServiceBinder(
+ mContext,
+ ClientConstants.EXAMPLE_STORE_ACTION,
+ IExampleStoreService.Stub::asInterface);
+ Intent intent = new Intent(ClientConstants.EXAMPLE_STORE_ACTION).setPackage(packageName);
+ return mExampleStoreServiceBinder.getService(Runnable::run, intent);
+ }
+
+ @VisibleForTesting
+ void unbindFromExampleStoreService() {
+ mExampleStoreServiceBinder.unbindFromService();
+ }
+
+ private ListenableFuture<IExampleStoreIterator> runExampleStoreStartQuery(
+ TrainingRun run, Bundle input) {
+ return CallbackToFutureAdapter.getFuture(
+ completer -> {
+ try {
+ run.mExampleStoreService.startQuery(
+ input,
+ new IExampleStoreCallback.Stub() {
+ @Override
+ public void onStartQuerySuccess(
+ IExampleStoreIterator iterator) {
+ LogUtil.d(TAG, "Acquire iterator");
+ completer.set(iterator);
+ }
+
+ @Override
+ public void onStartQueryFailure(int errorCode) {
+ LogUtil.e(TAG, "Could not acquire iterator: " + errorCode);
+ completer.setException(
+ new IllegalStateException(
+ "StartQuery failed: " + errorCode));
+ }
+ });
+ } catch (Exception e) {
+ completer.setException(e);
+ }
+ return "runExampleStoreStartQuery";
+ });
+ }
+
+ private ListenableFuture<IExampleStoreIterator> getExampleStoreIterator(
+ TrainingRun run, String packageName, String taskName, ExampleSelector exampleSelector) {
+ try {
+ run.mTaskName = taskName;
+
+ IExampleStoreService exampleStoreService = getExampleStoreService(packageName);
+ if (exampleStoreService == null) {
+ return Futures.immediateFailedFuture(
+ new IllegalStateException(
+ "Could not bind to ExampleStoreService " + packageName));
+ }
+ run.mExampleStoreService = exampleStoreService;
+
+ String collection = exampleSelector.getCollectionUri();
+ byte[] criteria = exampleSelector.getCriteria().toByteArray();
+ byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray();
+ Bundle bundle = new Bundle();
+ bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName());
+ bundle.putString(ClientConstants.EXTRA_COLLECTION_NAME, collection);
+ bundle.putString(ClientConstants.EXTRA_TASK_NAME, taskName);
+ bundle.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, run.mTask.contextData());
+ bundle.putByteArray(
+ ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
+ bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
+
+ return runExampleStoreStartQuery(run, bundle);
+ } catch (Exception e) {
+ LogUtil.e(TAG, "StartQuery failure: " + e.getMessage());
+ return Futures.immediateFailedFuture(e);
+ }
+ }
+
+ @VisibleForTesting
+ @Nullable
+ IIsolatedTrainingService getIsolatedTrainingService() {
+ mIsolatedTrainingServiceBinder =
+ AbstractServiceBinder.getServiceBinder(
+ mContext,
+ ISOLATED_TRAINING_SERVICE_NAME,
+ IIsolatedTrainingService.Stub::asInterface);
+ Intent intent = new Intent();
+ ComponentName serviceComponent =
+ new ComponentName(mContext.getPackageName(), ISOLATED_TRAINING_SERVICE_NAME);
+ intent.setComponent(serviceComponent);
+ return mIsolatedTrainingServiceBinder.getService(Runnable::run, intent);
+ }
+
+ @VisibleForTesting
+ void unbindFromIsolatedTrainingService() {
+ mIsolatedTrainingServiceBinder.unbindFromService();
+ }
+
+ @VisibleForTesting
+ static class Injector {
+ ExampleConsumptionRecorder getExampleConsumptionRecorder() {
+ return new ExampleConsumptionRecorder();
+ }
}
private static final class TrainingRun {
private final int mJobId;
+
+ private String mTaskName;
private final FederatedTrainingTask mTask;
+ @Nullable private ListenableFuture<?> mFuture;
+
+ @Nullable private IIsolatedTrainingService mIsolatedTrainingService = null;
+
+ @Nullable private IExampleStoreService mExampleStoreService = null;
+
private TrainingRun(int jobId, FederatedTrainingTask task) {
this.mJobId = jobId;
this.mTask = task;
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java
index 053e51a..f5c3220 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java
@@ -22,17 +22,17 @@
import android.app.job.JobService;
import com.android.federatedcompute.internal.util.LogUtil;
-import com.android.federatedcompute.services.common.FederatedComputeExecutors;
import com.android.federatedcompute.services.common.FlagsFactory;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
/** Main service for the scheduled federated computation jobs. */
public class FederatedJobService extends JobService {
private static final String TAG = FederatedJobService.class.getSimpleName();
- private ListenableFuture<Boolean> mRunCompleteFuture;
@Override
public boolean onStartJob(JobParameters params) {
@@ -42,19 +42,19 @@
jobFinished(params, /* wantsReschedule= */ false);
return true;
}
- mRunCompleteFuture =
- Futures.submit(
- () ->
- FederatedComputeWorker.getInstance(this)
- .startTrainingRun(params.getJobId()),
- FederatedComputeExecutors.getBackgroundExecutor());
+ FederatedComputeWorker worker = FederatedComputeWorker.getInstance(this);
+ ListenableFuture<FLRunnerResult> runCompleteFuture =
+ worker.startTrainingRun(params.getJobId());
Futures.addCallback(
- mRunCompleteFuture,
- new FutureCallback<Boolean>() {
+ runCompleteFuture,
+ new FutureCallback<FLRunnerResult>() {
@Override
- public void onSuccess(Boolean result) {
- LogUtil.d(TAG, "federated computation job is done!");
+ public void onSuccess(FLRunnerResult flRunnerResult) {
+ LogUtil.d(TAG, "Federated computation job %d is done!", params.getJobId());
+ if (flRunnerResult != null) {
+ worker.finish(flRunnerResult);
+ }
jobFinished(params, /* wantsReschedule= */ false);
}
@@ -62,6 +62,7 @@
public void onFailure(Throwable t) {
LogUtil.e(
TAG, t, "Failed to handle computation job: %d", params.getJobId());
+ worker.finish(null, ContributionResult.FAIL, false);
jobFinished(params, /* wantsReschedule= */ false);
}
},
@@ -71,12 +72,8 @@
@Override
public boolean onStopJob(JobParameters params) {
- if (mRunCompleteFuture != null) {
- mRunCompleteFuture.cancel(true);
- }
- FederatedComputeWorker.getInstance(this).cancelActiveRun();
- // Reschedule the job since it's not done. TODO: we should implement specify reschedule
- // logic instead.
- return true;
+ LogUtil.d(TAG, "FederatedJobService.onStopJob %d", params.getJobId());
+ FederatedComputeWorker.getInstance(this).finish(null, ContributionResult.FAIL, true);
+ return false;
}
}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java
index 0fd9996..62035e8 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java
@@ -18,6 +18,7 @@
import android.content.Context;
import android.federatedcompute.aidl.IExampleStoreIterator;
+import android.federatedcompute.common.ClientConstants;
import android.os.Bundle;
import android.os.ParcelFileDescriptor;
import android.os.RemoteException;
@@ -83,7 +84,7 @@
throw new IllegalArgumentException("ExampleSelector proto is invalid", e);
}
String populationName =
- Objects.requireNonNull(params.getString(Constants.EXTRA_POPULATION_NAME));
+ Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME));
ParcelFileDescriptor inputCheckpointFd =
Objects.requireNonNull(
params.getParcelable(
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java b/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java
index 63a448d..d49aadd 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java
@@ -25,18 +25,17 @@
import android.federatedcompute.aidl.IFederatedComputeCallback;
import android.federatedcompute.aidl.IResultHandlingService;
import android.federatedcompute.common.ClientConstants;
-import android.federatedcompute.common.ExampleConsumption;
import android.os.Bundle;
import com.android.federatedcompute.internal.util.AbstractServiceBinder;
import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
+import com.android.federatedcompute.services.training.util.ComputationResult;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
-import java.util.ArrayList;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
@@ -69,19 +68,16 @@
* Publishes the training result and example list to client implemented ResultHandlingService.
*/
public ListenableFuture<CallbackResult> callHandleResult(
- String taskName,
- FederatedTrainingTask task,
- ArrayList<ExampleConsumption> exampleConsumptions,
- boolean success) {
+ String taskName, FederatedTrainingTask task, ComputationResult result) {
Bundle input = new Bundle();
input.putString(ClientConstants.EXTRA_POPULATION_NAME, task.populationName());
input.putString(ClientConstants.EXTRA_TASK_NAME, taskName);
input.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, task.contextData());
input.putInt(
ClientConstants.EXTRA_COMPUTATION_RESULT,
- success ? STATUS_SUCCESS : STATUS_TRAINING_FAILED);
+ result.isResultSuccess() ? STATUS_SUCCESS : STATUS_TRAINING_FAILED);
input.putParcelableArrayList(
- ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, exampleConsumptions);
+ ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, result.getExampleConsumptionList());
try {
IResultHandlingService resultHandlingService =
@@ -117,10 +113,9 @@
} catch (Exception e) {
LogUtil.e(
TAG,
- String.format(
- "ResultHandlingService binding died. population name: %s",
- task.populationName()),
- e);
+ e,
+ "ResultHandlingService binding died. population name: %s",
+ task.populationName());
// We publish result to client app with best effort and should not crash flow.
return Futures.immediateFuture(CallbackResult.FAIL);
} finally {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java b/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java
index 72a58ae..89d9847 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java
@@ -21,24 +21,24 @@
import com.google.intelligence.fcp.client.FLRunnerResult;
import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
-import java.util.List;
+import java.util.ArrayList;
/** The result of federated computation. */
public class ComputationResult {
private String mOutputCheckpointFile = "";
private FLRunnerResult mFlRunnerResult = null;
- private List<ExampleConsumption> mExampleConsumptionList = null;
+ private ArrayList<ExampleConsumption> mExampleConsumptionList = null;
public ComputationResult(
String outputCheckpointFile,
FLRunnerResult flRunnerResult,
- List<ExampleConsumption> exampleConsumptionList) {
+ ArrayList<ExampleConsumption> exampleConsumptionList) {
this.mOutputCheckpointFile = outputCheckpointFile;
this.mFlRunnerResult = flRunnerResult;
this.mExampleConsumptionList = exampleConsumptionList;
}
- public List<ExampleConsumption> getExampleConsumptionList() {
+ public ArrayList<ExampleConsumption> getExampleConsumptionList() {
return mExampleConsumptionList;
}
diff --git a/framework/java/android/federatedcompute/common/ClientConstants.java b/framework/java/android/federatedcompute/common/ClientConstants.java
index 6d12749..0886506 100644
--- a/framework/java/android/federatedcompute/common/ClientConstants.java
+++ b/framework/java/android/federatedcompute/common/ClientConstants.java
@@ -26,8 +26,8 @@
public static final int STATUS_SUCCESS = 0;
public static final int STATUS_INTERNAL_ERROR = 1;
public static final int STATUS_TRAINING_FAILED = 2;
-
public static final String EXTRA_POPULATION_NAME = "android.federatedcompute.population_name";
+
public static final String EXTRA_COLLECTION_NAME = "android.federatedcompute.collection_name";
public static final String EXTRA_TASK_NAME = "android.federatedcompute.task_name";
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java
index 326fb89..9c0c8ec 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java
@@ -37,7 +37,6 @@
import com.android.federatedcompute.services.common.Clock;
import com.android.federatedcompute.services.common.Flags;
-import com.android.federatedcompute.services.common.TrainingResult;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.FederatedTrainingTaskDao;
import com.android.federatedcompute.services.data.FederatedTrainingTaskDbHelper;
@@ -47,6 +46,7 @@
import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
import com.google.intelligence.fcp.client.engine.TaskRetry;
import org.junit.After;
@@ -92,13 +92,11 @@
.build();
private static final TaskRetry TASK_RETRY =
TaskRetry.newBuilder().setDelayMin(5000000).setDelayMax(6000000).build();
-
+ private final CountDownLatch mLatch = new CountDownLatch(1);
private FederatedComputeJobManager mJobManager;
private Context mContext;
private FederatedTrainingTaskDao mTrainingTaskDao;
private boolean mSuccess = false;
- private final CountDownLatch mLatch = new CountDownLatch(1);
-
@Mock private Clock mClock;
@Mock private Flags mMockFlags;
@Mock private FederatedJobIdGenerator mMockJobIdGenerator;
@@ -724,7 +722,7 @@
POPULATION_NAME1,
createTrainingIntervalOptionsAsRoot(SchedulingMode.RECURRENT, 0),
TASK_RETRY,
- TrainingResult.SUCCESS);
+ ContributionResult.SUCCESS);
assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull();
assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1);
@@ -748,7 +746,7 @@
POPULATION_NAME1,
createTrainingIntervalOptionsAsRoot(SchedulingMode.ONE_TIME, 0),
TASK_RETRY,
- TrainingResult.SUCCESS);
+ ContributionResult.SUCCESS);
assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull();
assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty();
@@ -787,7 +785,7 @@
.setDelayMin(serverRetryDelayMillis)
.setDelayMax(serverRetryDelayMillis)
.build(),
- TrainingResult.FAIL);
+ ContributionResult.FAIL);
List<FederatedTrainingTask> taskList =
mTrainingTaskDao.getFederatedTrainingTask(null, null);
@@ -847,7 +845,7 @@
.setDelayMin(minRetryDelayMillis)
.setDelayMax(maxRetryDelayMillis)
.build(),
- TrainingResult.SUCCESS);
+ ContributionResult.SUCCESS);
List<FederatedTrainingTask> taskList =
mTrainingTaskDao.getFederatedTrainingTask(null, null);
@@ -907,7 +905,7 @@
.setDelayMin(serverDefinedIntervalMillis)
.setDelayMax(serverDefinedIntervalMillis)
.build(),
- TrainingResult.SUCCESS);
+ ContributionResult.SUCCESS);
List<FederatedTrainingTask> taskList =
mTrainingTaskDao.getFederatedTrainingTask(null, null);
@@ -967,7 +965,7 @@
.setDelayMin(serverDefinedIntervalMillis)
.setDelayMax(serverDefinedIntervalMillis)
.build(),
- TrainingResult.FAIL);
+ ContributionResult.FAIL);
List<FederatedTrainingTask> taskList =
mTrainingTaskDao.getFederatedTrainingTask(null, null);
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java
index b08c1c4..4f52e41 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java
@@ -26,14 +26,10 @@
import com.google.internal.federated.plan.ExampleSelector;
import com.google.internal.federated.plan.FederatedExampleQueryIORouter;
import com.google.internal.federated.plan.TFV1CheckpointAggregation;
+import com.google.internal.federated.plan.TensorflowSpec;
/** The utility class for federated learning related tests. */
public class TrainingTestUtil {
- public static final String CLIENT_PACKAGE_NAME = "de.myselph";
- public static final long RUN_ID = 12345L;
- public static final String SESSION_NAME = "session_name";
- public static final String TASK_NAME = "task_name";
- public static final String POPULATION_NAME = "population_name";
public static final String STRING_VECTOR_NAME = "vector1";
public static final String INT_VECTOR_NAME = "vector2";
public static final String STRING_TENSOR_NAME = "tensor1";
@@ -85,4 +81,18 @@
.build();
return clientOnlyPlan;
}
+
+ public static ClientOnlyPlan createFakeFederatedLearningClientPlan() {
+ TensorflowSpec tensorflowSpec =
+ TensorflowSpec.newBuilder()
+ .setDatasetTokenTensorName("dataset")
+ .addTargetNodeNames("target")
+ .build();
+ ClientOnlyPlan clientOnlyPlan =
+ ClientOnlyPlan.newBuilder()
+ .setPhase(
+ ClientPhase.newBuilder().setTensorflowSpec(tensorflowSpec).build())
+ .build();
+ return clientOnlyPlan;
+ }
}
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
index 4bbfca3..9a5a3df 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
@@ -16,108 +16,172 @@
package com.android.federatedcompute.services.training;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFile;
+
+import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.util.concurrent.Futures.immediateFailedFuture;
+import static com.google.common.util.concurrent.Futures.immediateFuture;
+
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
-import com.android.federatedcompute.services.common.Flags;
-import com.android.federatedcompute.services.common.TrainingResult;
+import android.content.Context;
+import android.federatedcompute.aidl.IExampleStoreCallback;
+import android.federatedcompute.aidl.IExampleStoreService;
+import android.federatedcompute.common.ClientConstants;
+import android.federatedcompute.common.ExampleConsumption;
+import android.os.Bundle;
+import android.os.RemoteException;
+
+import androidx.test.core.app.ApplicationProvider;
+
+import com.android.federatedcompute.services.common.Constants;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.fbs.SchedulingMode;
import com.android.federatedcompute.services.data.fbs.SchedulingReason;
import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
+import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
+import com.android.federatedcompute.services.http.CheckinResult;
+import com.android.federatedcompute.services.http.HttpFederatedProtocol;
import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager;
+import com.android.federatedcompute.services.testutils.FakeExampleStoreIterator;
+import com.android.federatedcompute.services.testutils.TrainingTestUtil;
+import com.android.federatedcompute.services.training.ResultCallbackHelper.CallbackResult;
+import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService;
+import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker.Condition;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
+import com.google.intelligence.fcp.client.RetryInfo;
+import com.google.intelligence.fcp.client.engine.TaskRetry;
+import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+import org.tensorflow.example.BytesList;
+import org.tensorflow.example.Example;
+import org.tensorflow.example.Feature;
+import org.tensorflow.example.Features;
+
+import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.concurrent.ExecutionException;
@RunWith(JUnit4.class)
public final class FederatedComputeWorkerTest {
private static final int JOB_ID = 1234;
private static final String POPULATION_NAME = "barPopulation";
- private static final String SERVER_ADDRESS = "https://server.uri/";
+ private static final String TASK_NAME = "task-id";
private static final long CREATION_TIME_MS = 10000L;
private static final long TASK_EARLIEST_NEXT_RUN_TIME_MS = 1234567L;
private static final String PACKAGE_NAME = "com.android.federatedcompute.services.training";
+ private static final String SERVER_ADDRESS = "https://server.com/";
private static final byte[] DEFAULT_TRAINING_CONSTRAINTS =
createTrainingConstraints(true, true, true);
- private static final long FEDERATED_TRANSIENT_ERROR_RETRY_PERIOD_SECS = 50000;
+ private static final TaskRetry TASK_RETRY =
+ TaskRetry.newBuilder().setRetryToken("foobar").build();
+ private static final CheckinResult FL_CHECKIN_RESULT =
+ new CheckinResult(
+ createTempFile("input", ".ckp"),
+ TrainingTestUtil.createFakeFederatedLearningClientPlan(),
+ TaskAssignment.newBuilder().setTaskName(TASK_NAME).build());
+ private static final CheckinResult FA_CHECKIN_RESULT =
+ new CheckinResult(
+ createTempFile("input", ".ckp"),
+ TrainingTestUtil.createFederatedAnalyticClientPlan(),
+ TaskAssignment.newBuilder().setTaskName(TASK_NAME).build());
+ private static final FLRunnerResult FL_RUNNER_FAILURE_RESULT =
+ FLRunnerResult.newBuilder().setContributionResult(ContributionResult.FAIL).build();
+
+ private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
+ FLRunnerResult.newBuilder()
+ .setContributionResult(ContributionResult.SUCCESS)
+ .setRetryInfo(
+ RetryInfo.newBuilder()
+ .setRetryToken(TASK_RETRY.getRetryToken())
+ .build())
+ .build();
private static final byte[] INTERVAL_OPTIONS = createDefaultTrainingIntervalOptions();
private static final FederatedTrainingTask FEDERATED_TRAINING_TASK_1 =
FederatedTrainingTask.builder()
.appPackageName(PACKAGE_NAME)
.creationTime(CREATION_TIME_MS)
.lastScheduledTime(TASK_EARLIEST_NEXT_RUN_TIME_MS)
+ .serverAddress(SERVER_ADDRESS)
.populationName(POPULATION_NAME)
.jobId(JOB_ID)
- .serverAddress(SERVER_ADDRESS)
.intervalOptions(INTERVAL_OPTIONS)
.constraints(DEFAULT_TRAINING_CONSTRAINTS)
.earliestNextRunTime(TASK_EARLIEST_NEXT_RUN_TIME_MS)
.schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
.build();
-
+ private static final Example EXAMPLE_PROTO_1 =
+ Example.newBuilder()
+ .setFeatures(
+ Features.newBuilder()
+ .putFeature(
+ "feature1",
+ Feature.newBuilder()
+ .setBytesList(
+ BytesList.newBuilder()
+ .addValue(
+ ByteString.copyFromUtf8(
+ "f1_value1")))
+ .build()))
+ .build();
+ private static final Any FAKE_CRITERIA = Any.newBuilder().setTypeUrl("baz.com").build();
+ private static final String COLLECTION_URI = "app://com.foo.bar//inapp/collection1";
+ private static final ExampleConsumption EXAMPLE_CONSUMPTION_1 =
+ new ExampleConsumption.Builder()
+ .setCollectionName(COLLECTION_URI)
+ .setSelectionCriteria(FAKE_CRITERIA.toByteArray())
+ .setExampleCount(100)
+ .build();
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+ @Mock TrainingConditionsChecker mTrainingConditionsChecker;
@Mock FederatedComputeJobManager mMockJobManager;
- @Mock private Flags mMockFlags;
- private FederatedComputeWorker mFcpWorker;
-
- @Before
- public void doBeforeEachTest() {
- MockitoAnnotations.initMocks(this);
- mFcpWorker = new FederatedComputeWorker(mMockJobManager, mMockFlags);
- doNothing()
- .when(mMockJobManager)
- .onTrainingCompleted(anyInt(), anyString(), any(), any(), anyInt());
- when(mMockFlags.getTransientErrorRetryDelayJitterPercent()).thenReturn(0.1f);
- when(mMockFlags.getTransientErrorRetryDelaySecs())
- .thenReturn(FEDERATED_TRANSIENT_ERROR_RETRY_PERIOD_SECS);
- }
-
- @Test
- public void testTrainingSuccess() {
- when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1);
- boolean result = mFcpWorker.startTrainingRun(JOB_ID);
-
- assertTrue(result);
- verify(mMockJobManager, times(1))
- .onTrainingCompleted(
- eq(JOB_ID), eq(POPULATION_NAME), any(), any(), eq(TrainingResult.SUCCESS));
- }
-
- @Test
- public void testTrainingFailure_nonExist() {
- when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(null);
- boolean result = mFcpWorker.startTrainingRun(JOB_ID);
-
- assertFalse(result);
- verify(mMockJobManager, times(0))
- .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), anyInt());
- }
+ private Context mContext;
+ private FederatedComputeWorker mSpyWorker;
+ @Mock private HttpFederatedProtocol mMockHttpFederatedProtocol;
+ @Mock private ComputationRunner mMockComputationRunner;
+ @Mock private ResultCallbackHelper mMockResultCallbackHelper;
private static byte[] createTrainingConstraints(
boolean requiresSchedulerIdle,
- boolean requiresSchedulerCharging,
+ boolean requiresSchedulerBatteryNotLow,
boolean requiresSchedulerUnmeteredNetwork) {
FlatBufferBuilder builder = new FlatBufferBuilder();
builder.finish(
TrainingConstraints.createTrainingConstraints(
builder,
requiresSchedulerIdle,
- requiresSchedulerCharging,
+ requiresSchedulerBatteryNotLow,
requiresSchedulerUnmeteredNetwork));
return builder.sizedByteArray();
}
@@ -129,4 +193,310 @@
builder, SchedulingMode.ONE_TIME, 0));
return builder.sizedByteArray();
}
+
+ @Before
+ public void doBeforeEachTest() throws Exception {
+ mContext = ApplicationProvider.getApplicationContext();
+ mSpyWorker =
+ Mockito.spy(
+ new FederatedComputeWorker(
+ mContext,
+ mMockJobManager,
+ mTrainingConditionsChecker,
+ mMockComputationRunner,
+ mMockResultCallbackHelper,
+ new TestInjector()));
+ when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any()))
+ .thenReturn(EnumSet.noneOf(Condition.class));
+ when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any()))
+ .thenReturn(Futures.immediateFuture(CallbackResult.SUCCESS));
+ when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1);
+ doReturn(mMockHttpFederatedProtocol)
+ .when(mSpyWorker)
+ .getHttpFederatedProtocol(anyString(), anyString());
+ when(mMockComputationRunner.runTaskWithNativeRunner(
+ anyInt(),
+ anyString(),
+ anyString(),
+ anyString(),
+ any(),
+ any(),
+ any(),
+ any(),
+ any()))
+ .thenReturn(FL_RUNNER_SUCCESS_RESULT);
+ }
+
+ @Test
+ public void testJobNonExist_returnsFail() throws Exception {
+ when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(null);
+
+ FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+
+ assertNull(result);
+ verify(mMockJobManager, times(0))
+ .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), any());
+ }
+
+ @Test
+ public void testTrainingConditionsCheckFailed_returnsFail() throws Exception {
+ when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any()))
+ .thenReturn(ImmutableSet.of(Condition.BATTERY_NOT_OK));
+
+ FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+
+ assertNull(result);
+ verify(mMockJobManager)
+ .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), any());
+ }
+
+ @Test
+ public void testCheckinFails_throwsException() throws Exception {
+ setUpExampleStoreService();
+
+ doReturn(
+ immediateFailedFuture(
+ new ExecutionException(
+ "issue checkin failed",
+ new IllegalStateException("http 404"))))
+ .when(mMockHttpFederatedProtocol)
+ .issueCheckin();
+ doReturn(FluentFuture.from(immediateFuture(null)))
+ .when(mMockHttpFederatedProtocol)
+ .reportResult(any());
+
+ assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+ mSpyWorker.finish(null, ContributionResult.FAIL, false);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+ }
+
+ @Test
+ public void testReportResultFails_throwsException() throws Exception {
+ setUpExampleStoreService();
+
+ doReturn(immediateFuture(FA_CHECKIN_RESULT))
+ .when(mMockHttpFederatedProtocol)
+ .issueCheckin();
+ doReturn(
+ FluentFuture.from(
+ immediateFailedFuture(
+ new ExecutionException(
+ "report result failed",
+ new IllegalStateException("http 404")))))
+ .when(mMockHttpFederatedProtocol)
+ .reportResult(any());
+
+ assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+ mSpyWorker.finish(null, ContributionResult.FAIL, false);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+ verify(mSpyWorker).unbindFromExampleStoreService();
+ }
+
+ @Test
+ public void testBindToExampleStoreFails_throwsException() throws Exception {
+ setUpHttpFederatedProtocol(FL_CHECKIN_RESULT);
+
+ // Mock failure bind to ExampleStoreService.
+ doReturn(null).when(mSpyWorker).getExampleStoreService(anyString());
+ doNothing().when(mSpyWorker).unbindFromExampleStoreService();
+
+ assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+ mSpyWorker.finish(null, ContributionResult.FAIL, false);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+ verify(mSpyWorker, times(0)).unbindFromExampleStoreService();
+ }
+
+ @Test
+ public void testRunFAComputationReturnsFailResult() throws Exception {
+ setUpExampleStoreService();
+ setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+ // Mock return failed runner result from native fcp client.
+ when(mMockComputationRunner.runTaskWithNativeRunner(
+ anyInt(),
+ anyString(),
+ anyString(),
+ anyString(),
+ any(),
+ any(),
+ any(),
+ any(),
+ any()))
+ .thenReturn(FL_RUNNER_FAILURE_RESULT);
+
+ FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+ assertThat(result.getContributionResult()).isEqualTo(ContributionResult.FAIL);
+
+ mSpyWorker.finish(result);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+ verify(mSpyWorker).unbindFromExampleStoreService();
+ }
+
+ @Test
+ public void testPublishToResultHandlingServiceFails_returnsSuccess() throws Exception {
+ setUpExampleStoreService();
+ setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+ // Mock publish to ResultHandlingService fails which is best effort and should not affect
+ // final result.
+ when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any()))
+ .thenReturn(Futures.immediateFuture(CallbackResult.FAIL));
+
+ FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+ assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+ mSpyWorker.finish(result);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS));
+ verify(mSpyWorker).unbindFromExampleStoreService();
+ }
+
+ @Test
+ public void testPublishToResultHandlingServiceThrowsException_returnsSuccess()
+ throws Exception {
+ setUpExampleStoreService();
+ setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+ // Mock publish to ResultHandlingService throws exception which is best effort and should
+ // not affect final result.
+ when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any()))
+ .thenReturn(
+ immediateFailedFuture(
+ new ExecutionException(
+ "ResultHandlingService fail",
+ new IllegalStateException("can't bind to service"))));
+
+ FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+ assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+ mSpyWorker.finish(result);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS));
+ verify(mSpyWorker).unbindFromExampleStoreService();
+ verify(mMockResultCallbackHelper).callHandleResult(eq(TASK_NAME), any(), any());
+ }
+
+ @Test
+ public void testRunFAComputation_returnsSuccess() throws Exception {
+ setUpExampleStoreService();
+ setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+ FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+ assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+ mSpyWorker.finish(result);
+ verify(mMockJobManager).onTrainingCompleted(anyInt(), anyString(), any(), any(), any());
+ }
+
+ @Test
+ public void testBindToIsolatedTrainingServiceFail_returnsFail() throws Exception {
+ doReturn(immediateFuture(FL_CHECKIN_RESULT))
+ .when(mMockHttpFederatedProtocol)
+ .issueCheckin();
+ setUpExampleStoreService();
+
+ // Mock failure bind to IsolatedTrainingService.
+ doReturn(null).when(mSpyWorker).getIsolatedTrainingService();
+ doNothing().when(mSpyWorker).unbindFromIsolatedTrainingService();
+
+ ExecutionException exception =
+ assertThrows(
+ ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+ assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class);
+ assertThat(exception.getCause())
+ .hasMessageThat()
+ .isEqualTo("Could not bind to IsolatedTrainingService");
+
+ mSpyWorker.finish(null, ContributionResult.FAIL, false);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+ }
+
+ @Test
+ public void testRunFLComputation_returnsSuccess() throws Exception {
+ setUpExampleStoreService();
+ setUpHttpFederatedProtocol(FL_CHECKIN_RESULT);
+
+ // Mock bind to IsolatedTrainingService.
+ doReturn(new FakeIsolatedTrainingService()).when(mSpyWorker).getIsolatedTrainingService();
+ doNothing().when(mSpyWorker).unbindFromIsolatedTrainingService();
+
+ FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+ assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+ mSpyWorker.finish(result);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS));
+ verify(mSpyWorker).unbindFromIsolatedTrainingService();
+ verify(mSpyWorker).unbindFromExampleStoreService();
+ }
+
+ private void setUpExampleStoreService() {
+ TestExampleStoreService testExampleStoreService = new TestExampleStoreService();
+ doReturn(testExampleStoreService).when(mSpyWorker).getExampleStoreService(anyString());
+ doNothing().when(mSpyWorker).unbindFromExampleStoreService();
+ }
+
+ private void setUpHttpFederatedProtocol(CheckinResult checkinResult) {
+ doReturn(immediateFuture(checkinResult)).when(mMockHttpFederatedProtocol).issueCheckin();
+ doReturn(FluentFuture.from(immediateFuture(null)))
+ .when(mMockHttpFederatedProtocol)
+ .reportResult(any());
+ }
+
+ private static class TestExampleStoreService extends IExampleStoreService.Stub {
+ @Override
+ public void startQuery(Bundle params, IExampleStoreCallback callback)
+ throws RemoteException {
+ callback.onStartQuerySuccess(
+ new FakeExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1.toByteArray())));
+ }
+ }
+
+ private static class TestInjector extends FederatedComputeWorker.Injector {
+ @Override
+ ExampleConsumptionRecorder getExampleConsumptionRecorder() {
+ return new ExampleConsumptionRecorder() {
+ @Override
+ public synchronized ArrayList<ExampleConsumption> finishRecordingAndGet() {
+ ArrayList<ExampleConsumption> exampleList = new ArrayList<>();
+ exampleList.add(EXAMPLE_CONSUMPTION_1);
+ return exampleList;
+ }
+ };
+ }
+ }
+
+ private static final class FakeIsolatedTrainingService extends IIsolatedTrainingService.Stub {
+ @Override
+ public void runFlTraining(Bundle params, ITrainingResultCallback callback)
+ throws RemoteException {
+ Bundle bundle = new Bundle();
+ bundle.putByteArray(
+ Constants.EXTRA_FL_RUNNER_RESULT, FL_RUNNER_SUCCESS_RESULT.toByteArray());
+ ArrayList<ExampleConsumption> exampleConsumptionList = new ArrayList<>();
+ exampleConsumptionList.add(EXAMPLE_CONSUMPTION_1);
+ bundle.putParcelableArrayList(
+ ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, exampleConsumptionList);
+ callback.onResult(bundle);
+ }
+
+ @Override
+ public void cancelTraining(long runId) {}
+ }
}
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java
index 3210068..b9e6086 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java
@@ -16,13 +16,16 @@
package com.android.federatedcompute.services.training;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -33,20 +36,39 @@
import com.android.federatedcompute.services.common.FederatedComputeExecutors;
import com.android.federatedcompute.services.common.PhFlagsTestUtil;
+import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.MoreExecutors;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
+import com.google.intelligence.fcp.client.RetryInfo;
+import com.google.intelligence.fcp.client.engine.TaskRetry;
+import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.mockito.Mock;
import org.mockito.MockitoSession;
import org.mockito.quality.Strictness;
@RunWith(JUnit4.class)
public final class FederatedJobServiceTest {
- private static final long WAIT_IN_MILLIS = 1_000L;
+ private static final TaskRetry TASK_RETRY =
+ TaskRetry.newBuilder().setRetryToken("foobar").build();
+ private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
+ FLRunnerResult.newBuilder()
+ .setContributionResult(ContributionResult.SUCCESS)
+ .setRetryInfo(
+ RetryInfo.newBuilder()
+ .setRetryToken(TASK_RETRY.getRetryToken())
+ .build())
+ .build();
private FederatedJobService mSpyService;
+ @Mock private FederatedComputeWorker mMockWorker;
+
+ private MockitoSession mStaticMockSession;
@Before
public void setUp() throws Exception {
@@ -56,60 +78,54 @@
mSpyService = spy(new FederatedJobService());
doNothing().when(mSpyService).jobFinished(any(), anyBoolean());
doReturn(mSpyService).when(mSpyService).getApplicationContext();
+ mStaticMockSession =
+ ExtendedMockito.mockitoSession()
+ .spyStatic(FederatedComputeExecutors.class)
+ .spyStatic(FederatedComputeWorker.class)
+ .initMocks(this)
+ .strictness(Strictness.LENIENT)
+ .startMocking();
+ ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService())
+ .when(() -> FederatedComputeExecutors.getBackgroundExecutor());
+ ExtendedMockito.doReturn(mMockWorker).when(() -> FederatedComputeWorker.getInstance(any()));
+ }
+
+ @After
+ public void teardown() {
+ if (mStaticMockSession != null) {
+ mStaticMockSession.finishMocking();
+ }
}
@Test
public void testOnStartJob() throws Exception {
- MockitoSession session =
- ExtendedMockito.mockitoSession()
- .spyStatic(FederatedComputeExecutors.class)
- .strictness(Strictness.LENIENT)
- .startMocking();
- try {
- ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService())
- .when(FederatedComputeExecutors::getBackgroundExecutor);
+ doReturn(Futures.immediateFuture(FL_RUNNER_SUCCESS_RESULT))
+ .when(mMockWorker)
+ .startTrainingRun(anyInt());
+ doNothing().when(mMockWorker).finish(eq(FL_RUNNER_SUCCESS_RESULT));
- boolean result = mSpyService.onStartJob(mock(JobParameters.class));
+ boolean result = mSpyService.onStartJob(mock(JobParameters.class));
- assertTrue(result);
- Thread.sleep(WAIT_IN_MILLIS);
-
- verify(mSpyService, times(1)).jobFinished(any(), anyBoolean());
- } finally {
- session.finishMocking();
- }
+ assertTrue(result);
+ verify(mSpyService, times(1)).jobFinished(any(), anyBoolean());
}
@Test
public void testOnStartJobKillSwitch() throws Exception {
PhFlagsTestUtil.enableGlobalKillSwitch();
- MockitoSession session =
- ExtendedMockito.mockitoSession()
- .spyStatic(FederatedComputeExecutors.class)
- .strictness(Strictness.LENIENT)
- .startMocking();
- try {
- ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService())
- .when(FederatedComputeExecutors::getBackgroundExecutor);
- boolean result = mSpyService.onStartJob(mock(JobParameters.class));
+ boolean result = mSpyService.onStartJob(mock(JobParameters.class));
- assertTrue(result);
-
- verify(mSpyService, times(1)).jobFinished(any(), eq(false));
- } finally {
- session.finishMocking();
- }
+ assertTrue(result);
+ verify(mMockWorker, never()).startTrainingRun(anyInt());
+ verify(mSpyService, times(1)).jobFinished(any(), eq(false));
}
@Test
public void testOnStopJob() {
- MockitoSession session =
- ExtendedMockito.mockitoSession().strictness(Strictness.LENIENT).startMocking();
- try {
- assertTrue(mSpyService.onStopJob(mock(JobParameters.class)));
- } finally {
- session.finishMocking();
- }
+ doNothing().when(mMockWorker).finish(any(), eq(ContributionResult.FAIL), eq(true));
+
+ // Do not reschedule in JobService. FederatedComputeJobManager will handle it.
+ assertFalse(mSpyService.onStopJob(mock(JobParameters.class)));
}
}
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java
index a72eb85..57e14b6 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java
@@ -24,12 +24,10 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.when;
-import android.content.Context;
+import android.federatedcompute.common.ClientConstants;
import android.os.Bundle;
import android.os.ParcelFileDescriptor;
-import androidx.test.core.app.ApplicationProvider;
-
import com.android.dx.mockito.inline.extended.ExtendedMockito;
import com.android.federatedcompute.services.common.Constants;
import com.android.federatedcompute.services.common.FederatedComputeExecutors;
@@ -62,18 +60,11 @@
@RunWith(JUnit4.class)
public final class IsolatedTrainingServiceImplTest {
private static final String POPULATION_NAME = "population_name";
- private static final String CLIENT_PACKAGE_NAME = "de.myselph";
private static final long RUN_ID = 12345L;
- private static final String SESSION_NAME = "session_name";
- private static final String TASK_NAME = "task_name";
- private static final String INPUT_CHECKPOINT_FD = "fd:///5";
- private static final String OUTPUT_CHECKPOINT_FD = "fd:///6";
private static final FakeExampleStoreIterator FAKE_EXAMPLE_STORE_ITERATOR =
new FakeExampleStoreIterator(ImmutableList.of());
private static final ExampleSelector EXAMPLE_SELECTOR =
ExampleSelector.newBuilder().setCollectionUri("collection_uri").build();
-
- private final CountDownLatch mLatch = new CountDownLatch(1);
private static final TaskRetry TASK_RETRY =
TaskRetry.newBuilder().setRetryToken("foobar").build();
private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
@@ -84,7 +75,6 @@
.setRetryToken(TASK_RETRY.getRetryToken())
.build())
.build();
-
private static final FLRunnerResult FL_RUNNER_FAIL_RESULT =
FLRunnerResult.newBuilder()
.setContributionResult(ContributionResult.FAIL)
@@ -93,11 +83,9 @@
.setRetryToken(TASK_RETRY.getRetryToken())
.build())
.build();
-
- private final Context mContext = ApplicationProvider.getApplicationContext();
+ private final CountDownLatch mLatch = new CountDownLatch(1);
private IsolatedTrainingServiceImpl mIsolatedTrainingService;
private Bundle mCallbackResult;
- private int mCallbackErrorCode;
@Mock private ComputationRunner mComputationRunner;
private MockitoSession mStaticMockSession;
private ParcelFileDescriptor mInputCheckpointFd;
@@ -159,7 +147,7 @@
@Test
public void runFlTrainingMissingExampleSelector_returnsFailure() throws Exception {
Bundle bundle = new Bundle();
- bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+ bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
bundle.putBinder(
@@ -173,7 +161,7 @@
@Test
public void runFlTrainingInvalidExampleSelector_returnsFailure() throws Exception {
Bundle bundle = new Bundle();
- bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+ bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
bundle.putBinder(
@@ -189,7 +177,7 @@
@Test
public void runFlTrainingNullPlan_returnsFailure() throws Exception {
Bundle bundle = new Bundle();
- bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+ bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
bundle.putBinder(
@@ -211,13 +199,12 @@
private Bundle buildInputBundle() throws Exception {
Bundle bundle = new Bundle();
- bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+ bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
bundle.putByteArray(Constants.EXTRA_EXAMPLE_SELECTOR, EXAMPLE_SELECTOR.toByteArray());
bundle.putBinder(
Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, FAKE_EXAMPLE_STORE_ITERATOR);
-
ClientOnlyPlan clientOnlyPlan = TrainingTestUtil.createFederatedAnalyticClientPlan();
bundle.putByteArray(Constants.EXTRA_CLIENT_ONLY_PLAN, clientOnlyPlan.toByteArray());
return bundle;
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java
index 9a58277..725cebf 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java
@@ -41,8 +41,11 @@
import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
import com.android.federatedcompute.services.training.ResultCallbackHelper.CallbackResult;
+import com.android.federatedcompute.services.training.util.ComputationResult;
import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
import org.junit.Before;
import org.junit.Test;
@@ -51,7 +54,6 @@
import org.mockito.Mockito;
import java.util.ArrayList;
-import java.util.concurrent.CountDownLatch;
@RunWith(JUnit4.class)
public final class ResultCallbackHelperTest {
@@ -60,10 +62,10 @@
private static final int SCHEDULING_REASON = SchedulingReason.SCHEDULING_REASON_NEW_TASK;
private static final String POPULATION_NAME = "population_name";
private static final String TASK_NAME = "task_name";
- private static final int JOB_ID = 123;
private static final byte[] INTERVAL_OPTIONS = createDefaultTrainingIntervalOptions();
private static final byte[] TRAINING_CONSTRAINTS = createDefaultTrainingConstraints();
- private final CountDownLatch mLatch = new CountDownLatch(1);
+ private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
+ FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build();
private static final FederatedTrainingTask TRAINING_TASK =
FederatedTrainingTask.builder()
@@ -80,11 +82,14 @@
.build();
private static final ArrayList<ExampleConsumption> EXAMPLE_CONSUMPTIONS =
getExampleConsumptions();
+ private ComputationResult mComputationResult;
private ResultCallbackHelper mHelper;
@Before
public void setUp() {
Context context = ApplicationProvider.getApplicationContext();
+ mComputationResult =
+ new ComputationResult("output", FL_RUNNER_SUCCESS_RESULT, EXAMPLE_CONSUMPTIONS);
mHelper = Mockito.spy(new ResultCallbackHelper(context));
doNothing().when(mHelper).unbindFromResultHandlingService();
}
@@ -96,8 +101,7 @@
.getResultHandlingService(eq(PACKAGE_NAME));
CallbackResult result =
- mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
- .get();
+ mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
assertThat(result).isEqualTo(CallbackResult.SUCCESS);
}
@@ -109,8 +113,7 @@
.getResultHandlingService(eq(PACKAGE_NAME));
CallbackResult result =
- mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
- .get();
+ mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
assertThat(result).isEqualTo(CallbackResult.FAIL);
}
@@ -122,8 +125,7 @@
.getResultHandlingService(eq(PACKAGE_NAME));
CallbackResult result =
- mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
- .get();
+ mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
assertThat(result).isEqualTo(CallbackResult.FAIL);
}
@@ -135,8 +137,7 @@
.getResultHandlingService(eq(PACKAGE_NAME));
CallbackResult result =
- mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
- .get();
+ mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
assertThat(result).isEqualTo(CallbackResult.FAIL);
}