blob: 62035e8b9f37ab074382623e4bb999e7fee94782 [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.federatedcompute.services.training;
import android.content.Context;
import android.federatedcompute.aidl.IExampleStoreIterator;
import android.federatedcompute.common.ClientConstants;
import android.os.Bundle;
import android.os.ParcelFileDescriptor;
import android.os.RemoteException;
import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.common.Constants;
import com.android.federatedcompute.services.common.FederatedComputeExecutors;
import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService;
import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback;
import com.android.federatedcompute.services.training.util.ListenableSupplier;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.intelligence.fcp.client.FLRunnerResult;
import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
import com.google.internal.federated.plan.ClientOnlyPlan;
import com.google.internal.federated.plan.ExampleSelector;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnull;
/**
 * Binder implementation of {@link IIsolatedTrainingService}.
 *
 * <p>Each {@link #runFlTraining} call unpacks the training inputs from the supplied {@link
 * Bundle}, runs the task through {@link ComputationRunner} on the background executor and reports
 * the serialized {@link FLRunnerResult} through the supplied callback. {@link #cancelTraining}
 * raises an interrupt flag that is exposed to the runner through {@code mInterruptState}.
 */
public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub {
    private static final String TAG = IsolatedTrainingServiceImpl.class.getSimpleName();

    // Set to true by cancelTraining(); never reset by this class.
    private final AtomicBoolean mInterruptFlag = new AtomicBoolean(false);

    // Supplier view of mInterruptFlag handed to the computation runner; its listeners are run
    // when cancelTraining() is called.
    @VisibleForTesting
    ListenableSupplier<Boolean> mInterruptState = new ListenableSupplier<>(mInterruptFlag::get);

    private final ComputationRunner mComputationRunner;

    /** Creates a service backed by a new {@link ComputationRunner} built from {@code context}. */
    public IsolatedTrainingServiceImpl(Context context) {
        mComputationRunner = new ComputationRunner(context);
    }

    /** Creates a service backed by the given runner. */
    @VisibleForTesting
    IsolatedTrainingServiceImpl(ComputationRunner computationRunner) {
        this.mComputationRunner = computationRunner;
    }

    /**
     * Starts one training run asynchronously and reports its outcome through {@code callback}.
     *
     * <p>The following extras are read from {@code params}: the example store iterator binder,
     * the serialized {@link ExampleSelector}, the population name, the input and output
     * checkpoint file descriptors, the job id and the serialized {@link ClientOnlyPlan}. All of
     * them except the job id are required.
     *
     * <p>The computation is submitted to the background executor. When it finishes, both
     * checkpoint file descriptors are closed and the result is delivered to {@code callback}; if
     * the computation throws, a result with {@link ContributionResult#FAIL} is delivered
     * instead.
     *
     * @param params bundle carrying the training inputs listed above
     * @param callback receives the serialized {@link FLRunnerResult}
     * @throws NullPointerException if an argument or a required extra is missing
     * @throws IllegalArgumentException if the example selector or client plan bytes cannot be
     *     parsed
     */
    @Override
    public void runFlTraining(@Nonnull Bundle params, @Nonnull ITrainingResultCallback callback) {
        Objects.requireNonNull(params);
        Objects.requireNonNull(callback);
        IExampleStoreIterator exampleStoreIteratorBinder =
                IExampleStoreIterator.Stub.asInterface(
                        Objects.requireNonNull(
                                params.getBinder(Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER)));
        ExampleConsumptionRecorder recorder = new ExampleConsumptionRecorder();
        byte[] exampleSelectorBytes =
                Objects.requireNonNull(params.getByteArray(Constants.EXTRA_EXAMPLE_SELECTOR));
        ExampleSelector exampleSelector;
        try {
            exampleSelector = ExampleSelector.parseFrom(exampleSelectorBytes);
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalArgumentException("ExampleSelector proto is invalid", e);
        }
        String populationName =
                Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME));
        ParcelFileDescriptor inputCheckpointFd =
                Objects.requireNonNull(
                        params.getParcelable(
                                Constants.EXTRA_INPUT_CHECKPOINT_FD, ParcelFileDescriptor.class));
        ParcelFileDescriptor outputCheckpointFd =
                Objects.requireNonNull(
                        params.getParcelable(
                                Constants.EXTRA_OUTPUT_CHECKPOINT_FD, ParcelFileDescriptor.class));
        int jobId = params.getInt(Constants.EXTRA_JOB_ID);
        byte[] clientPlanBytes =
                Objects.requireNonNull(params.getByteArray(Constants.EXTRA_CLIENT_ONLY_PLAN));
        ClientOnlyPlan clientPlan;
        try {
            clientPlan = ClientOnlyPlan.parseFrom(clientPlanBytes);
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalArgumentException("ClientOnlyPlan proto is invalid", e);
        }

        ListenableFuture<FLRunnerResult> resultFuture =
                Futures.submit(
                        () ->
                                mComputationRunner.runTaskWithNativeRunner(
                                        jobId,
                                        populationName,
                                        getFileDescriptorForTensorflow(inputCheckpointFd),
                                        getFileDescriptorForTensorflow(outputCheckpointFd),
                                        clientPlan,
                                        exampleSelector,
                                        recorder,
                                        exampleStoreIteratorBinder,
                                        mInterruptState),
                        FederatedComputeExecutors.getBackgroundExecutor());
        Futures.addCallback(
                resultFuture,
                new FutureCallback<FLRunnerResult>() {
                    @Override
                    public void onSuccess(FLRunnerResult result) {
                        // The runner only receives raw fd numbers, so this side must close the
                        // descriptors once the computation is done; otherwise they leak.
                        closeCheckpointFds(inputCheckpointFd, outputCheckpointFd);
                        sendResult(result, callback);
                    }

                    @Override
                    public void onFailure(Throwable t) {
                        LogUtil.e(TAG, t, "Failed to runTaskWithNativeRunner");
                        closeCheckpointFds(inputCheckpointFd, outputCheckpointFd);
                        FLRunnerResult result =
                                FLRunnerResult.newBuilder()
                                        .setContributionResult(ContributionResult.FAIL)
                                        .build();
                        sendResult(result, callback);
                    }
                },
                FederatedComputeExecutors.getLightweightExecutor());
    }

    // We implement a customized tensorflow filesystem that supports file descriptors for read
    // and write. The file format is "fd:///${fd_number}".
    private static String getFileDescriptorForTensorflow(
            ParcelFileDescriptor parcelFileDescriptor) {
        return "fd:///" + parcelFileDescriptor.getFd();
    }

    /** Closes both checkpoint descriptors; a failure to close one does not skip the other. */
    private static void closeCheckpointFds(
            ParcelFileDescriptor inputCheckpointFd, ParcelFileDescriptor outputCheckpointFd) {
        closeQuietly(inputCheckpointFd);
        closeQuietly(outputCheckpointFd);
    }

    /** Closes {@code fd}, logging instead of propagating any {@link IOException}. */
    private static void closeQuietly(ParcelFileDescriptor fd) {
        try {
            fd.close();
        } catch (IOException e) {
            LogUtil.w(TAG, e, "Failed to close checkpoint file descriptor");
        }
    }

    /**
     * Serializes {@code result} into a bundle under {@link Constants#EXTRA_FL_RUNNER_RESULT} and
     * hands it to {@code callback}. A {@link RemoteException} from the callback is logged and
     * otherwise ignored.
     */
    private void sendResult(FLRunnerResult result, ITrainingResultCallback callback) {
        Bundle bundle = new Bundle();
        bundle.putByteArray(Constants.EXTRA_FL_RUNNER_RESULT, result.toByteArray());
        try {
            callback.onResult(bundle);
        } catch (RemoteException e) {
            LogUtil.w(TAG, e, ": Callback failed ");
        }
    }

    /**
     * Raises the interrupt flag and runs the listeners registered on the interrupt state.
     *
     * @param runId not used by this implementation
     */
    @Override
    public void cancelTraining(long runId) {
        mInterruptFlag.set(true);
        mInterruptState.runListeners();
    }
}