blob: fe45813f2da31931c0677363c943bbda7531e8e7 [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.federatedcompute.services.training;
import static com.android.federatedcompute.services.common.Constants.ISOLATED_TRAINING_SERVICE_NAME;
import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor;
import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getLightweightExecutor;
import static com.android.federatedcompute.services.common.FileUtils.createTempFile;
import static com.android.federatedcompute.services.common.FileUtils.createTempFileDescriptor;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
import android.federatedcompute.aidl.IExampleStoreCallback;
import android.federatedcompute.aidl.IExampleStoreIterator;
import android.federatedcompute.aidl.IExampleStoreService;
import android.federatedcompute.common.ClientConstants;
import android.federatedcompute.common.ExampleConsumption;
import android.os.Bundle;
import android.os.ParcelFileDescriptor;
import androidx.concurrent.futures.CallbackToFutureAdapter;
import com.android.federatedcompute.internal.util.AbstractServiceBinder;
import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.common.Constants;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
import com.android.federatedcompute.services.http.CheckinResult;
import com.android.federatedcompute.services.http.HttpFederatedProtocol;
import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager;
import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService;
import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback;
import com.android.federatedcompute.services.training.util.ComputationResult;
import com.android.federatedcompute.services.training.util.ListenableSupplier;
import com.android.federatedcompute.services.training.util.TrainingConditionsChecker;
import com.android.federatedcompute.services.training.util.TrainingConditionsChecker.Condition;
import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.Preconditions;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.intelligence.fcp.client.FLRunnerResult;
import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
import com.google.intelligence.fcp.client.RetryInfo;
import com.google.intelligence.fcp.client.engine.TaskRetry;
import com.google.internal.federated.plan.ClientOnlyPlan;
import com.google.internal.federated.plan.ExampleSelector;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * The worker to execute federated computation jobs.
 *
 * <p>A training run checks in with the federated compute server, binds to the client app's
 * ExampleStoreService, runs the computation (federated analytics in this process, federated
 * learning in an isolated process), reports the result to the server and finally hands it to the
 * client's result callback. Only one run may be active at a time.
 */
public class FederatedComputeWorker {
    private static final String TAG = FederatedComputeWorker.class.getSimpleName();
    private static volatile FederatedComputeWorker sWorker;

    private final Object mLock = new Object();
    // Nothing in this class sets this flag; it is exposed to ComputationRunner through
    // mInterruptSupplier (see runFAComputation).
    private final AtomicBoolean mInterruptFlag = new AtomicBoolean(false);
    private final ListenableSupplier<Boolean> mInterruptSupplier =
            new ListenableSupplier<>(mInterruptFlag::get);
    private final Context mContext;
    @Nullable private final FederatedComputeJobManager mJobManager;
    @Nullable private final TrainingConditionsChecker mTrainingConditionsChecker;
    private final ComputationRunner mComputationRunner;
    private final ResultCallbackHelper mResultCallbackHelper;
    @NonNull private final Injector mInjector;

    @GuardedBy("mLock")
    @Nullable
    private TrainingRun mActiveRun = null;

    private HttpFederatedProtocol mHttpFederatedProtocol;
    private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder;
    private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder;

    @VisibleForTesting
    public FederatedComputeWorker(
            Context context,
            FederatedComputeJobManager jobManager,
            TrainingConditionsChecker trainingConditionsChecker,
            ComputationRunner computationRunner,
            ResultCallbackHelper resultCallbackHelper,
            Injector injector) {
        this.mContext = context.getApplicationContext();
        this.mJobManager = jobManager;
        this.mTrainingConditionsChecker = trainingConditionsChecker;
        this.mComputationRunner = computationRunner;
        this.mInjector = injector;
        this.mResultCallbackHelper = resultCallbackHelper;
    }

    /** Gets the lazily created singleton instance of {@link FederatedComputeWorker}. */
    @NonNull
    public static FederatedComputeWorker getInstance(Context context) {
        if (sWorker == null) {
            synchronized (FederatedComputeWorker.class) {
                if (sWorker == null) {
                    sWorker =
                            new FederatedComputeWorker(
                                    context,
                                    FederatedComputeJobManager.getInstance(context),
                                    TrainingConditionsChecker.getInstance(context),
                                    new ComputationRunner(context),
                                    new ResultCallbackHelper(context),
                                    new Injector());
                }
            }
        }
        return sWorker;
    }

    /**
     * Starts a training run with the given job Id.
     *
     * <p>The returned future completes with {@code null} without starting a run when the job has
     * no task, when the training conditions are not satisfied, or when another run is already
     * active. In the last two cases the job is reported to the job manager as a failed
     * contribution.
     */
    public ListenableFuture<FLRunnerResult> startTrainingRun(int jobId) {
        LogUtil.d(TAG, "startTrainingRun()");
        FederatedTrainingTask trainingTask = mJobManager.onTrainingStarted(jobId);
        if (trainingTask == null) {
            LogUtil.i(TAG, "Could not find task to run for job ID %s", jobId);
            return Futures.immediateFuture(null);
        }
        if (!checkTrainingConditions(trainingTask.getTrainingConstraints())) {
            reportTrainingFailure(jobId, trainingTask);
            LogUtil.i(TAG, "Training conditions not satisfied (before bindService)!");
            return Futures.immediateFuture(null);
        }
        synchronized (mLock) {
            // Only allows one concurrent job running.
            if (mActiveRun != null) {
                LogUtil.i(
                        TAG,
                        "Delaying %d/%s another run is already active!",
                        jobId,
                        trainingTask.populationName());
                reportTrainingFailure(jobId, trainingTask);
                return Futures.immediateFuture(null);
            }
            TrainingRun run = new TrainingRun(jobId, trainingTask);
            mActiveRun = run;
            ListenableFuture<FLRunnerResult> runCompletedFuture = doTraining(run);
            run.mFuture = runCompletedFuture;
            return runCompletedFuture;
        }
    }

    /** Reports to the job manager that {@code task} completed with a failed contribution. */
    private void reportTrainingFailure(int jobId, FederatedTrainingTask task) {
        mJobManager.onTrainingCompleted(
                jobId,
                task.populationName(),
                task.getTrainingIntervalOptions(),
                /* taskRetry= */ null,
                ContributionResult.FAIL);
    }

    /**
     * Builds the asynchronous pipeline for one run: check-in, example store binding, computation,
     * result reporting and result publishing. Any synchronous failure while building the pipeline
     * is returned as a failed future.
     */
    private ListenableFuture<FLRunnerResult> doTraining(TrainingRun run) {
        try {
            // 1. Communicate with remote federated compute server to start task assignment and
            // download client plan and initial model checkpoint. Note: use blocking executors for
            // http requests.
            // Keep a local reference so that later stages of this run keep talking to the same
            // protocol instance even if a subsequent run replaces mHttpFederatedProtocol.
            HttpFederatedProtocol httpFederatedProtocol =
                    getHttpFederatedProtocol(run.mTask.serverAddress(), run.mTask.populationName());
            mHttpFederatedProtocol = httpFederatedProtocol;
            ListenableFuture<CheckinResult> checkinResultFuture =
                    httpFederatedProtocol.issueCheckin();

            // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector.
            ListenableFuture<IExampleStoreIterator> iteratorFuture =
                    FluentFuture.from(checkinResultFuture)
                            .transform(
                                    result -> {
                                        String taskName = result.getTaskAssignment().getTaskName();
                                        Preconditions.checkArgument(
                                                !taskName.isEmpty(),
                                                "Task name should not be empty");
                                        // Record the task name on this run directly: mActiveRun
                                        // may already have been cleared (or replaced) by finish().
                                        run.mTaskName = taskName;
                                        return getExampleSelector(result);
                                    },
                                    getLightweightExecutor())
                            .transformAsync(
                                    selector ->
                                            getExampleStoreIterator(
                                                    run,
                                                    run.mTask.appPackageName(),
                                                    run.mTaskName,
                                                    selector),
                                    getBackgroundExecutor());

            // 3. Run federated learning or federated analytic depends on task type. Federated
            // learning job will start a new isolated process to run TFLite training.
            ListenableFuture<ComputationResult> computationResultFuture =
                    Futures.whenAllSucceed(checkinResultFuture, iteratorFuture)
                            .callAsync(
                                    () ->
                                            runFederatedComputation(
                                                    Futures.getDone(checkinResultFuture),
                                                    run,
                                                    Futures.getDone(iteratorFuture)),
                                    getBackgroundExecutor());

            // 4. Report computation result to federated compute server.
            ListenableFuture<Void> reportToServerFuture =
                    FluentFuture.from(computationResultFuture)
                            .transformAsync(
                                    result -> httpFederatedProtocol.reportResult(result),
                                    getLightweightExecutor());

            // 5. Once reporting succeeded, publish computation result and consumed examples to
            // client implemented ResultHandlingService.
            return Futures.whenAllSucceed(reportToServerFuture, computationResultFuture)
                    .call(
                            () -> {
                                ComputationResult result =
                                        Futures.getDone(computationResultFuture);
                                var unused =
                                        mResultCallbackHelper.callHandleResult(
                                                run.mTaskName, run.mTask, result);
                                return result.getFlRunnerResult();
                            },
                            getBackgroundExecutor());
        } catch (Exception e) {
            return Futures.immediateFailedFuture(e);
        }
    }

    /**
     * Completes the running job, unbinds from services still bound for it and reports its outcome
     * (including any server-provided retry window) to the job manager.
     *
     * @param flRunnerResult the run's result; {@code null} (as produced by the early-exit paths of
     *     {@link #startTrainingRun}) is reported as {@link ContributionResult#FAIL}
     */
    public void finish(@Nullable FLRunnerResult flRunnerResult) {
        if (flRunnerResult == null) {
            // NOTE(review): if the null came from a start rejected because another run was
            // active, this finishes that other run; callers should avoid finishing in that case.
            finish(/* taskRetry= */ null, ContributionResult.FAIL, /* cancelFuture= */ true);
            return;
        }
        TaskRetry taskRetry = null;
        if (flRunnerResult.hasRetryInfo()) {
            RetryInfo retryInfo = flRunnerResult.getRetryInfo();
            long delay = retryInfo.getMinimumDelay().getSeconds() * 1000L;
            taskRetry =
                    TaskRetry.newBuilder()
                            .setRetryToken(retryInfo.getRetryToken())
                            .setDelayMin(delay)
                            .setDelayMax(delay)
                            .build();
            LogUtil.i(TAG, "Finished with task retry= %s", taskRetry);
        }
        finish(taskRetry, flRunnerResult.getContributionResult(), /* cancelFuture= */ true);
    }

    /**
     * Finishes the active run, if any: optionally cancels its future, unbinds from
     * IsolatedTrainingService and ExampleStoreService if still bound, and reports the outcome to
     * the job manager. Does nothing when no run is active.
     *
     * @param taskRetry retry information to pass to the job manager, or {@code null}
     * @param contributionResult the outcome reported to the job manager
     * @param cancelFuture whether to cancel (with interruption) the run's future
     */
    public void finish(
            TaskRetry taskRetry, ContributionResult contributionResult, boolean cancelFuture) {
        TrainingRun runToFinish;
        synchronized (mLock) {
            if (mActiveRun == null) {
                return;
            }
            runToFinish = mActiveRun;
            mActiveRun = null;
            if (cancelFuture) {
                runToFinish.mFuture.cancel(true);
            }
        }
        unBindServicesIfNecessary(runToFinish);
        mJobManager.onTrainingCompleted(
                runToFinish.mJobId,
                runToFinish.mTask.populationName(),
                runToFinish.mTask.getTrainingIntervalOptions(),
                taskRetry,
                contributionResult);
    }

    /** Unbinds from whichever of the run's services are still recorded as bound. */
    private void unBindServicesIfNecessary(TrainingRun runToFinish) {
        if (runToFinish.mIsolatedTrainingService != null) {
            LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
            unbindFromIsolatedTrainingService();
            runToFinish.mIsolatedTrainingService = null;
        }
        if (runToFinish.mExampleStoreService != null) {
            LogUtil.i(TAG, "Unbinding from ExampleStoreService");
            unbindFromExampleStoreService();
            runToFinish.mExampleStoreService = null;
        }
    }

    @VisibleForTesting
    HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) {
        return HttpFederatedProtocol.create(serverAddress, "1.0", populationName);
    }

    /**
     * Extracts the example selector from the client plan.
     *
     * @throws IllegalArgumentException if the plan's spec type is unsupported or an example query
     *     spec contains no queries
     */
    private ExampleSelector getExampleSelector(CheckinResult checkinResult) {
        ClientOnlyPlan clientPlan = checkinResult.getPlanData();
        switch (clientPlan.getPhase().getSpecCase()) {
            case EXAMPLE_QUERY_SPEC:
                // Only support one FA query for now.
                Preconditions.checkArgument(
                        clientPlan.getPhase().getExampleQuerySpec().getExampleQueriesCount() > 0,
                        "Example query spec should contain at least one example query");
                return clientPlan
                        .getPhase()
                        .getExampleQuerySpec()
                        .getExampleQueries(0)
                        .getExampleSelector();
            case TENSORFLOW_SPEC:
                return clientPlan.getPhase().getTensorflowSpec().getExampleSelector();
            default:
                throw new IllegalArgumentException(
                        String.format(
                                "Client plan spec is not supported %s",
                                clientPlan.getPhase().getSpecCase().toString()));
        }
    }

    /** Returns {@code true} when no training condition is violated; logs each violation. */
    private boolean checkTrainingConditions(TrainingConstraints constraints) {
        Set<Condition> conditions =
                mTrainingConditionsChecker.checkAllConditionsForFlTraining(constraints);
        for (Condition condition : conditions) {
            switch (condition) {
                case THERMALS_NOT_OK:
                    LogUtil.e(TAG, "training job service interrupt thermals not ok");
                    break;
                case BATTERY_NOT_OK:
                    LogUtil.e(TAG, "training job service interrupt battery not ok");
                    break;
            }
        }
        return conditions.isEmpty();
    }

    /**
     * Runs federated learning in IsolatedTrainingService.
     *
     * <p>The checkpoint file descriptors are closed and the service is unbound once training
     * completes, whether it succeeded or failed, before the result is handed to later stages.
     */
    @VisibleForTesting
    ListenableFuture<ComputationResult> runFlComputation(
            TrainingRun run,
            CheckinResult checkinResult,
            String outputCheckpointFile,
            IExampleStoreIterator iterator) {
        ParcelFileDescriptor outputCheckpointFd = createTempFileDescriptor(outputCheckpointFile);
        ParcelFileDescriptor inputCheckpointFd =
                createTempFileDescriptor(checkinResult.getInputCheckpointFile());
        try {
            IIsolatedTrainingService trainingService = getIsolatedTrainingService();
            if (trainingService == null) {
                LogUtil.w(TAG, "Could not bind to IsolatedTrainingService");
                throw new IllegalStateException("Could not bind to IsolatedTrainingService");
            }
            run.mIsolatedTrainingService = trainingService;
            Bundle bundle = new Bundle();
            bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName());
            bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, inputCheckpointFd);
            bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, outputCheckpointFd);
            bundle.putInt(Constants.EXTRA_JOB_ID, run.mJobId);
            bundle.putBinder(Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, iterator.asBinder());
            ListenableFuture<ComputationResult> computationFuture =
                    FluentFuture.from(runIsolatedTrainingProcess(run, bundle))
                            .transform(
                                    result ->
                                            processIsolatedTrainingResult(
                                                    outputCheckpointFile, result),
                                    getLightweightExecutor());
            // whenAllComplete runs the cleanup on success and on failure alike; returning the
            // completed future afterwards propagates its value or exception unchanged.
            return Futures.whenAllComplete(computationFuture)
                    .callAsync(
                            () -> {
                                releaseIsolatedTrainingResources(
                                        run, outputCheckpointFd, inputCheckpointFd);
                                return computationFuture;
                            },
                            getLightweightExecutor());
        } catch (Exception e) {
            releaseIsolatedTrainingResources(run, outputCheckpointFd, inputCheckpointFd);
            return Futures.immediateFailedFuture(e);
        }
    }

    /** Closes both checkpoint descriptors and unbinds from IsolatedTrainingService. */
    private void releaseIsolatedTrainingResources(
            TrainingRun run,
            @Nullable ParcelFileDescriptor outputCheckpointFd,
            @Nullable ParcelFileDescriptor inputCheckpointFd) {
        try {
            // Close each descriptor independently so one failure does not leak the other.
            closeFileDescriptor(outputCheckpointFd);
            closeFileDescriptor(inputCheckpointFd);
        } finally {
            LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
            unbindFromIsolatedTrainingService();
            run.mIsolatedTrainingService = null;
        }
    }

    /** Closes {@code fd} if non-null, logging (not propagating) any {@link IOException}. */
    private static void closeFileDescriptor(@Nullable ParcelFileDescriptor fd) {
        if (fd == null) {
            return;
        }
        try {
            fd.close();
        } catch (IOException e) {
            LogUtil.e(TAG, e, "Failed to close file descriptor");
        }
    }

    /**
     * Converts the bundle returned by IsolatedTrainingService into a {@link ComputationResult}.
     *
     * @throws NullPointerException if the bundle carries no serialized FL runner result
     * @throws IllegalArgumentException if the result cannot be parsed or the example consumption
     *     list is missing or empty
     */
    private ComputationResult processIsolatedTrainingResult(
            String outputCheckpoint, Bundle result) {
        byte[] resultBytes =
                Objects.requireNonNull(result.getByteArray(Constants.EXTRA_FL_RUNNER_RESULT));
        FLRunnerResult flRunnerResult;
        try {
            flRunnerResult = FLRunnerResult.parseFrom(resultBytes);
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalArgumentException(e);
        }
        ArrayList<ExampleConsumption> exampleList =
                result.getParcelableArrayList(
                        ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, ExampleConsumption.class);
        if (exampleList == null || exampleList.isEmpty()) {
            throw new IllegalArgumentException("example consumption list should not be empty");
        }
        return new ComputationResult(outputCheckpoint, flRunnerResult, exampleList);
    }

    /** Adapts the callback-based {@code runFlTraining} AIDL call to a future. */
    private ListenableFuture<Bundle> runIsolatedTrainingProcess(TrainingRun run, Bundle input) {
        return CallbackToFutureAdapter.getFuture(
                completer -> {
                    try {
                        run.mIsolatedTrainingService.runFlTraining(
                                input,
                                new ITrainingResultCallback.Stub() {
                                    @Override
                                    public void onResult(Bundle result) {
                                        completer.set(result);
                                    }
                                });
                    } catch (Exception e) {
                        completer.setException(e);
                    }
                    return "runIsolatedTrainingProcess";
                });
    }

    /** Dispatches to federated analytics or federated learning based on the plan's spec type. */
    private ListenableFuture<ComputationResult> runFederatedComputation(
            CheckinResult checkinResult, TrainingRun run, IExampleStoreIterator iterator) {
        ClientOnlyPlan clientPlan = checkinResult.getPlanData();
        String outputCheckpointFile = createTempFile("output", ".ckp");
        switch (clientPlan.getPhase().getSpecCase()) {
            case EXAMPLE_QUERY_SPEC:
                return runFAComputation(run, checkinResult, outputCheckpointFile, iterator);
            case TENSORFLOW_SPEC:
                return runFlComputation(run, checkinResult, outputCheckpointFile, iterator);
            default:
                return Futures.immediateFailedFuture(
                        new IllegalArgumentException(
                                String.format(
                                        "Client plan spec is not supported %s",
                                        clientPlan.getPhase().getSpecCase().toString())));
        }
    }

    /** Runs federated analytics synchronously with the native runner in this process. */
    private ListenableFuture<ComputationResult> runFAComputation(
            TrainingRun run,
            CheckinResult checkinResult,
            String outputCheckpointFile,
            IExampleStoreIterator exampleStoreIterator) {
        ExampleSelector exampleSelector = getExampleSelector(checkinResult);
        ClientOnlyPlan clientPlan = checkinResult.getPlanData();
        // The federated analytic runs in main process which has permission to file system.
        ExampleConsumptionRecorder recorder = mInjector.getExampleConsumptionRecorder();
        FLRunnerResult runResult =
                mComputationRunner.runTaskWithNativeRunner(
                        run.mJobId,
                        run.mTask.populationName(),
                        checkinResult.getInputCheckpointFile(),
                        outputCheckpointFile,
                        clientPlan,
                        exampleSelector,
                        recorder,
                        exampleStoreIterator,
                        mInterruptSupplier);
        ArrayList<ExampleConsumption> exampleConsumptions = recorder.finishRecordingAndGet();
        return Futures.immediateFuture(
                new ComputationResult(outputCheckpointFile, runResult, exampleConsumptions));
    }

    /** Binds to the ExampleStoreService of {@code packageName}; may return {@code null}. */
    @VisibleForTesting
    IExampleStoreService getExampleStoreService(String packageName) {
        mExampleStoreServiceBinder =
                AbstractServiceBinder.getServiceBinder(
                        mContext,
                        ClientConstants.EXAMPLE_STORE_ACTION,
                        IExampleStoreService.Stub::asInterface);
        Intent intent = new Intent(ClientConstants.EXAMPLE_STORE_ACTION).setPackage(packageName);
        return mExampleStoreServiceBinder.getService(Runnable::run, intent);
    }

    /** Unbinds from ExampleStoreService; a no-op if no binder was ever created. */
    @VisibleForTesting
    void unbindFromExampleStoreService() {
        if (mExampleStoreServiceBinder != null) {
            mExampleStoreServiceBinder.unbindFromService();
        }
    }

    /** Adapts the callback-based {@code startQuery} AIDL call to a future of the iterator. */
    private ListenableFuture<IExampleStoreIterator> runExampleStoreStartQuery(
            TrainingRun run, Bundle input) {
        return CallbackToFutureAdapter.getFuture(
                completer -> {
                    try {
                        run.mExampleStoreService.startQuery(
                                input,
                                new IExampleStoreCallback.Stub() {
                                    @Override
                                    public void onStartQuerySuccess(
                                            IExampleStoreIterator iterator) {
                                        LogUtil.d(TAG, "Acquire iterator");
                                        completer.set(iterator);
                                    }

                                    @Override
                                    public void onStartQueryFailure(int errorCode) {
                                        LogUtil.e(TAG, "Could not acquire iterator: " + errorCode);
                                        completer.setException(
                                                new IllegalStateException(
                                                        "StartQuery failed: " + errorCode));
                                    }
                                });
                    } catch (Exception e) {
                        completer.setException(e);
                    }
                    return "runExampleStoreStartQuery";
                });
    }

    /**
     * Binds to the client's ExampleStoreService and starts a query described by {@code
     * exampleSelector}. Failures are returned as a failed future.
     */
    private ListenableFuture<IExampleStoreIterator> getExampleStoreIterator(
            TrainingRun run, String packageName, String taskName, ExampleSelector exampleSelector) {
        try {
            run.mTaskName = taskName;
            IExampleStoreService exampleStoreService = getExampleStoreService(packageName);
            if (exampleStoreService == null) {
                return Futures.immediateFailedFuture(
                        new IllegalStateException(
                                "Could not bind to ExampleStoreService " + packageName));
            }
            run.mExampleStoreService = exampleStoreService;
            String collection = exampleSelector.getCollectionUri();
            byte[] criteria = exampleSelector.getCriteria().toByteArray();
            byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray();
            Bundle bundle = new Bundle();
            bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName());
            bundle.putString(ClientConstants.EXTRA_COLLECTION_NAME, collection);
            bundle.putString(ClientConstants.EXTRA_TASK_NAME, taskName);
            bundle.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, run.mTask.contextData());
            bundle.putByteArray(
                    ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
            bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
            return runExampleStoreStartQuery(run, bundle);
        } catch (Exception e) {
            LogUtil.e(TAG, e, "StartQuery failure");
            return Futures.immediateFailedFuture(e);
        }
    }

    /** Binds to this package's IsolatedTrainingService; returns {@code null} on failure. */
    @VisibleForTesting
    @Nullable
    IIsolatedTrainingService getIsolatedTrainingService() {
        mIsolatedTrainingServiceBinder =
                AbstractServiceBinder.getServiceBinder(
                        mContext,
                        ISOLATED_TRAINING_SERVICE_NAME,
                        IIsolatedTrainingService.Stub::asInterface);
        Intent intent = new Intent();
        ComponentName serviceComponent =
                new ComponentName(mContext.getPackageName(), ISOLATED_TRAINING_SERVICE_NAME);
        intent.setComponent(serviceComponent);
        return mIsolatedTrainingServiceBinder.getService(Runnable::run, intent);
    }

    /** Unbinds from IsolatedTrainingService; a no-op if no binder was ever created. */
    @VisibleForTesting
    void unbindFromIsolatedTrainingService() {
        if (mIsolatedTrainingServiceBinder != null) {
            mIsolatedTrainingServiceBinder.unbindFromService();
        }
    }

    /** Factory for collaborators that tests need to replace. */
    @VisibleForTesting
    static class Injector {
        ExampleConsumptionRecorder getExampleConsumptionRecorder() {
            return new ExampleConsumptionRecorder();
        }
    }

    /** Mutable state of a single training run. */
    private static final class TrainingRun {
        private final int mJobId;
        // Set once the check-in response assigns a task.
        private String mTaskName;
        private final FederatedTrainingTask mTask;
        @Nullable private ListenableFuture<?> mFuture;
        @Nullable private IIsolatedTrainingService mIsolatedTrainingService = null;
        @Nullable private IExampleStoreService mExampleStoreService = null;

        private TrainingRun(int jobId, FederatedTrainingTask task) {
            this.mJobId = jobId;
            this.mTask = task;
        }
    }
}