Implement a single training run workflow when fcp job triggered

Bug: 276931193
Test: atest
Change-Id: I913671d743082220816a213b1dea20ad78594dd1
diff --git a/federatedcompute/apk/Android.bp b/federatedcompute/apk/Android.bp
index ed5ef23..2c6a970 100644
--- a/federatedcompute/apk/Android.bp
+++ b/federatedcompute/apk/Android.bp
@@ -48,6 +48,7 @@
     plugins: ["auto_value_plugin"],
     static_libs: [
         "flatbuffers-java",
+        "androidx.concurrent_concurrent-futures",
         "federated-compute-java-proto-lite",
         "guava",
         "modules-utils-preconditions",
diff --git a/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java b/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java
index 92ed6ec..f1c830d 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java
@@ -18,25 +18,20 @@
 
 /** Constants used internally in the FederatedCompute APK. */
 public class Constants {
-    public static final String EXTRA_COLLECTION_NAME = "android.federatedcompute.collection_name";
-    public static final String EXTRA_EXAMPLE_ITERATOR_CRITERIA =
-            "android.federatedcompute.example_iterator_criteria";
-    public static final String EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN =
-            "android.federatedcompute.example_iterator_resumption_token";
 
     public static final String EXTRA_EXAMPLE_STORE_ITERATOR_BINDER =
             "android.federatedcompute.example_store_iterator_binder";
-    public static final String EXTRA_RESULT_HANDLING_SERVICE_BINDER =
-            "android.federatedcompute.result_handling_service_binder";
     public static final String EXTRA_INPUT_CHECKPOINT_FD =
             "android.federatedcompute.input_checkpoint_fd";
     public static final String EXTRA_OUTPUT_CHECKPOINT_FD =
             "android.federatedcompute.output_checkpoint_fd";
-    public static final String EXTRA_POPULATION_NAME = "android.federatedcompute.population_name";
     public static final String EXTRA_FL_RUNNER_RESULT = "android.federatedcompute.fl_runner_result";
     public static final String EXTRA_JOB_ID = "android.federatedcompute.job_id";
     public static final String EXTRA_EXAMPLE_SELECTOR = "android.federatedcompute.example_selector";
     public static final String EXTRA_CLIENT_ONLY_PLAN = "android.federatedcompute.client_only_plan";
 
+    public static final String ISOLATED_TRAINING_SERVICE_NAME =
+            "com.android.federatedcompute.service.training.IsolatedTrainingService";
+
     private Constants() {}
 }
diff --git a/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java b/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java
new file mode 100644
index 0000000..e2b7d94
--- /dev/null
+++ b/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.federatedcompute.services.common;
+
+import android.os.ParcelFileDescriptor;
+
+import com.android.federatedcompute.internal.util.LogUtil;
+
+import java.io.BufferedInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+/** Utils related to {@link File} and {@link ParcelFileDescriptor}. */
+public class FileUtils {
+    private static final String TAG = FileUtils.class.getSimpleName();
+
+    private static final int BUFFER_SIZE = 1024;
+
+    /** Create {@link ParcelFileDescriptor} based on the input file. */
+    public static ParcelFileDescriptor createTempFileDescriptor(String fileName) {
+        ParcelFileDescriptor fileDescriptor;
+        try {
+            fileDescriptor =
+                    ParcelFileDescriptor.open(
+                            new File(fileName), ParcelFileDescriptor.MODE_READ_ONLY);
+        } catch (IOException e) {
+            LogUtil.e(TAG, e, "Failed to createTempFileDescriptor %s", fileName);
+            throw new RuntimeException(e);
+        }
+        return fileDescriptor;
+    }
+
+    /** Create a temporary file based on provided name and extension. */
+    public static String createTempFile(String name, String extension) {
+        String fileName;
+        try {
+            File tempFile = File.createTempFile(name, extension);
+            fileName = tempFile.getAbsolutePath();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        return fileName;
+    }
+
+    /** Write the provided data to the file. */
+    public static void writeToFile(String fileName, byte[] data) throws IOException {
+        FileOutputStream out = new FileOutputStream(fileName);
+        out.write(data);
+        out.close();
+    }
+
+    /** Read the input file content to a byte array. */
+    public static byte[] readFileAsByteArray(String filePath) throws IOException {
+        File file = new File(filePath);
+        long fileLength = file.length();
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream((int) fileLength);
+        try (BufferedInputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
+            byte[] buffer = new byte[BUFFER_SIZE];
+            for (int len = inputStream.read(buffer); len > 0; len = inputStream.read(buffer)) {
+                outputStream.write(buffer, 0, len);
+            }
+        } catch (IOException e) {
+            LogUtil.e(TAG, e, "Failed to read the content of binary file %s", filePath);
+            throw e;
+        }
+        return outputStream.toByteArray();
+    }
+}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java
index 22dcf9b..76a1b77 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java
@@ -89,8 +89,8 @@
     }
 
     /** Returns all recorded {@link ExampleConsumption}. */
-    public synchronized List<ExampleConsumption> finishRecordingAndGet() {
-        List<ExampleConsumption> exampleConsumptions = new ArrayList<>();
+    public synchronized ArrayList<ExampleConsumption> finishRecordingAndGet() {
+        ArrayList<ExampleConsumption> exampleConsumptions = new ArrayList<>();
         for (SingleQueryRecorder recorder : mSingleQueryRecorders) {
             exampleConsumptions.add(recorder.finishRecordingAndGet());
         }
diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java
index ccb9df8..4d029b3 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java
@@ -27,8 +27,6 @@
 import android.util.Pair;
 
 import com.android.federatedcompute.internal.util.LogUtil;
-import com.android.federatedcompute.services.common.Constants;
-import com.android.federatedcompute.services.common.ErrorStatusException;
 import com.android.federatedcompute.services.common.Flags;
 
 import com.google.common.util.concurrent.SettableFuture;
@@ -53,8 +51,7 @@
 
     @Override
     public IExampleStoreIterator getExampleStoreIterator(
-            String packageName, ExampleSelector exampleSelector)
-            throws InterruptedException, ErrorStatusException {
+            String packageName, ExampleSelector exampleSelector) throws InterruptedException {
         String collection = exampleSelector.getCollectionUri();
         byte[] criteria = exampleSelector.getCriteria().toByteArray();
         byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray();
@@ -71,9 +68,10 @@
         IExampleStoreService exampleStoreService =
                 mExampleStoreServiceProvider.getExampleStoreService();
         Bundle bundle = new Bundle();
-        bundle.putString(Constants.EXTRA_COLLECTION_NAME, collection);
-        bundle.putByteArray(Constants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
-        bundle.putByteArray(Constants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
+        bundle.putString(ClientConstants.EXTRA_COLLECTION_NAME, collection);
+        bundle.putByteArray(
+                ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
+        bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
         SettableFuture<Pair<IExampleStoreIterator, Integer>> iteratorOrFailureFuture =
                 SettableFuture.create();
         try {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java b/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java
index 0e747bf..e82bfe6 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java
@@ -18,6 +18,8 @@
 
 import android.annotation.Nullable;
 
+import com.android.internal.util.Preconditions;
+
 import com.google.internal.federated.plan.ClientOnlyPlan;
 import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment;
 
@@ -25,20 +27,23 @@
  * The result after client calls TaskAssignemnt API. It includes init checkpoint data and plan data.
  */
 public class CheckinResult {
-    private byte[] mCheckpointData = null;
+    private String mInputCheckpoint = null;
     private ClientOnlyPlan mPlanData = null;
     private TaskAssignment mTaskAssignment = null;
 
     public CheckinResult(
-            byte[] checkpointData, ClientOnlyPlan planData, TaskAssignment taskAssignment) {
-        this.mCheckpointData = checkpointData;
+            String inputCheckpoint, ClientOnlyPlan planData, TaskAssignment taskAssignment) {
+        this.mInputCheckpoint = inputCheckpoint;
         this.mPlanData = planData;
         this.mTaskAssignment = taskAssignment;
     }
 
     @Nullable
-    public byte[] getCheckpointData() {
-        return mCheckpointData;
+    public String getInputCheckpointFile() {
+        Preconditions.checkArgument(
+                mInputCheckpoint != null && !mInputCheckpoint.isEmpty(),
+                "Input checkpoint file should not be none or empty");
+        return mInputCheckpoint;
     }
 
     @Nullable
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
index a77f804..f622f7f 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
@@ -18,6 +18,9 @@
 
 import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor;
 import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getLightweightExecutor;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFile;
+import static com.android.federatedcompute.services.common.FileUtils.readFileAsByteArray;
+import static com.android.federatedcompute.services.common.FileUtils.writeToFile;
 import static com.android.federatedcompute.services.http.HttpClientUtil.HTTP_OK_STATUS;
 
 import com.android.federatedcompute.internal.util.LogUtil;
@@ -42,17 +45,11 @@
 import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction;
 import com.google.protobuf.InvalidProtocolBufferException;
 
-import java.io.BufferedInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
 import java.util.HashMap;
 
 /** Implements a single session of HTTP-based federated compute protocol. */
 public final class HttpFederatedProtocol {
     public static final String TAG = "HttpFederatedProtocol";
-    private static final int BUFFER_SIZE = 1024;
 
     private final String mClientVersion;
     private final String mPopulationName;
@@ -204,9 +201,10 @@
                 return Futures.immediateFailedFuture(
                         new IllegalStateException("Could not parse ClientOnlyPlan proto", e));
             }
+            String inputCheckpointFile = createTempFile("input", ".ckp");
+            writeToFile(inputCheckpointFile, checkpointDataResponse.getPayload());
             return Futures.immediateFuture(
-                    new CheckinResult(
-                            checkpointDataResponse.getPayload(), clientOnlyPlan, taskAssignment));
+                    new CheckinResult(inputCheckpointFile, clientOnlyPlan, taskAssignment));
 
         } catch (Exception e) {
             return Futures.immediateFailedFuture(e);
@@ -282,22 +280,6 @@
         }
     }
 
-    private byte[] readFileAsByteArray(String filePath) throws IOException {
-        File file = new File(filePath);
-        long fileLength = file.length();
-        ByteArrayOutputStream outputStream = new ByteArrayOutputStream((int) fileLength);
-        try (BufferedInputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
-            byte[] buffer = new byte[BUFFER_SIZE];
-            for (int len = inputStream.read(buffer); len > 0; len = inputStream.read(buffer)) {
-                outputStream.write(buffer, 0, len);
-            }
-        } catch (IOException e) {
-            LogUtil.e(TAG, e, "Failed to read the content of binary file %s", filePath);
-            throw e;
-        }
-        return outputStream.toByteArray();
-    }
-
     private ListenableFuture<FederatedComputeHttpResponse> fetchTaskResource(Resource resource) {
         switch (resource.getResourceCase()) {
             case URI:
diff --git a/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java b/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java
index d2909d2..7d9edf2 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java
@@ -35,7 +35,6 @@
 import com.android.federatedcompute.services.common.Flags;
 import com.android.federatedcompute.services.common.MonotonicClock;
 import com.android.federatedcompute.services.common.PhFlags;
-import com.android.federatedcompute.services.common.TrainingResult;
 import com.android.federatedcompute.services.data.FederatedTrainingTask;
 import com.android.federatedcompute.services.data.FederatedTrainingTaskDao;
 import com.android.federatedcompute.services.data.fbs.SchedulingMode;
@@ -46,6 +45,7 @@
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
 import com.google.intelligence.fcp.client.engine.TaskRetry;
 
 import java.util.Arrays;
@@ -55,14 +55,13 @@
 /** Handles scheduling training tasks e.g. calling into JobScheduler, maintaining datastore. */
 public class FederatedComputeJobManager {
     private static final String TAG = FederatedComputeJobManager.class.getSimpleName();
-
+    private static volatile FederatedComputeJobManager sSingletonInstance;
     @NonNull private final Context mContext;
     private final FederatedTrainingTaskDao mFederatedTrainingTaskDao;
     private final JobSchedulerHelper mJobSchedulerHelper;
-    private static volatile FederatedComputeJobManager sSingletonInstance;
     private final FederatedJobIdGenerator mJobIdGenerator;
-    private Clock mClock;
     private final Flags mFlags;
+    private final Clock mClock;
 
     @VisibleForTesting
     FederatedComputeJobManager(
@@ -72,7 +71,7 @@
             JobSchedulerHelper jobSchedulerHelper,
             @NonNull Clock clock,
             Flags flag) {
-        this.mContext = context;
+        this.mContext = context.getApplicationContext();
         this.mFederatedTrainingTaskDao = federatedTrainingTaskDao;
         this.mJobIdGenerator = jobIdGenerator;
         this.mJobSchedulerHelper = jobSchedulerHelper;
@@ -285,7 +284,7 @@
             String populationName,
             TrainingIntervalOptions trainingIntervalOptions,
             TaskRetry taskRetry,
-            @TrainingResult int trainingResult) {
+            ContributionResult trainingResult) {
         boolean result =
                 rescheduleFederatedTaskAfterTraining(
                         jobId, populationName, trainingIntervalOptions, taskRetry, trainingResult);
@@ -300,7 +299,7 @@
             String populationName,
             TrainingIntervalOptions intervalOptions,
             TaskRetry taskRetry,
-            @TrainingResult int trainingResult) {
+            ContributionResult trainingResult) {
         FederatedTrainingTask existingTask =
                 mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationAndJobId(
                         populationName, jobId);
@@ -310,7 +309,7 @@
         if (existingTask == null) {
             return true;
         }
-        boolean hasContributed = trainingResult == TrainingResult.SUCCESS;
+        boolean hasContributed = trainingResult == ContributionResult.SUCCESS;
         if (intervalOptions != null
                 && intervalOptions.schedulingMode() == SchedulingMode.ONE_TIME
                 && hasContributed) {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
index 33de4b4..fe45813 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
@@ -16,70 +16,144 @@
 
 package com.android.federatedcompute.services.training;
 
+import static com.android.federatedcompute.services.common.Constants.ISOLATED_TRAINING_SERVICE_NAME;
+import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor;
+import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getLightweightExecutor;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFile;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFileDescriptor;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.content.ComponentName;
 import android.content.Context;
+import android.content.Intent;
+import android.federatedcompute.aidl.IExampleStoreCallback;
+import android.federatedcompute.aidl.IExampleStoreIterator;
+import android.federatedcompute.aidl.IExampleStoreService;
+import android.federatedcompute.common.ClientConstants;
+import android.federatedcompute.common.ExampleConsumption;
+import android.os.Bundle;
+import android.os.ParcelFileDescriptor;
 
+import androidx.concurrent.futures.CallbackToFutureAdapter;
+
+import com.android.federatedcompute.internal.util.AbstractServiceBinder;
 import com.android.federatedcompute.internal.util.LogUtil;
-import com.android.federatedcompute.services.common.Flags;
-import com.android.federatedcompute.services.common.PhFlags;
-import com.android.federatedcompute.services.common.TrainingResult;
+import com.android.federatedcompute.services.common.Constants;
 import com.android.federatedcompute.services.data.FederatedTrainingTask;
+import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
+import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
+import com.android.federatedcompute.services.http.CheckinResult;
+import com.android.federatedcompute.services.http.HttpFederatedProtocol;
 import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager;
-import com.android.federatedcompute.services.scheduling.SchedulingUtil;
+import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService;
+import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback;
+import com.android.federatedcompute.services.training.util.ComputationResult;
+import com.android.federatedcompute.services.training.util.ListenableSupplier;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker.Condition;
+import com.android.internal.annotations.GuardedBy;
+import com.android.internal.util.Preconditions;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
+import com.google.intelligence.fcp.client.RetryInfo;
 import com.google.intelligence.fcp.client.engine.TaskRetry;
+import com.google.internal.federated.plan.ClientOnlyPlan;
+import com.google.internal.federated.plan.ExampleSelector;
+import com.google.protobuf.InvalidProtocolBufferException;
 
-import javax.annotation.concurrent.GuardedBy;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 /** The worker to execute federated computation jobs. */
 public class FederatedComputeWorker {
     private static final String TAG = FederatedComputeWorker.class.getSimpleName();
+    private static volatile FederatedComputeWorker sWorker;
+    private final Object mLock = new Object();
+    private final AtomicBoolean mInterruptFlag = new AtomicBoolean(false);
+    private final ListenableSupplier<Boolean> mInterruptSupplier =
+            new ListenableSupplier<>(mInterruptFlag::get);
+    private final Context mContext;
+    @Nullable private final FederatedComputeJobManager mJobManager;
+    @Nullable private final TrainingConditionsChecker mTrainingConditionsChecker;
+    private final ComputationRunner mComputationRunner;
+    private final ResultCallbackHelper mResultCallbackHelper;
+    @NonNull private final Injector mInjector;
 
-    static final Object LOCK = new Object();
-
-    @GuardedBy("LOCK")
+    @GuardedBy("mLock")
     @Nullable
     private TrainingRun mActiveRun = null;
 
-    @Nullable private final FederatedComputeJobManager mJobManager;
-    private static volatile FederatedComputeWorker sFederatedComputeWorker;
-    private final Flags mFlags;
+    private HttpFederatedProtocol mHttpFederatedProtocol;
+    private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder;
+    private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder;
 
     @VisibleForTesting
-    public FederatedComputeWorker(FederatedComputeJobManager jobManager, Flags flags) {
+    public FederatedComputeWorker(
+            Context context,
+            FederatedComputeJobManager jobManager,
+            TrainingConditionsChecker trainingConditionsChecker,
+            ComputationRunner computationRunner,
+            ResultCallbackHelper resultCallbackHelper,
+            Injector injector) {
+        this.mContext = context.getApplicationContext();
         this.mJobManager = jobManager;
-        this.mFlags = flags;
+        this.mTrainingConditionsChecker = trainingConditionsChecker;
+        this.mComputationRunner = computationRunner;
+        this.mInjector = injector;
+        this.mResultCallbackHelper = resultCallbackHelper;
     }
 
     /** Gets an instance of {@link FederatedComputeWorker}. */
     @NonNull
     public static FederatedComputeWorker getInstance(Context context) {
-        if (sFederatedComputeWorker == null) {
+        if (sWorker == null) {
             synchronized (FederatedComputeWorker.class) {
-                if (sFederatedComputeWorker == null) {
-                    sFederatedComputeWorker =
+                if (sWorker == null) {
+                    sWorker =
                             new FederatedComputeWorker(
+                                    context,
                                     FederatedComputeJobManager.getInstance(context),
-                                    PhFlags.getInstance());
+                                    TrainingConditionsChecker.getInstance(context),
+                                    new ComputationRunner(context),
+                                    new ResultCallbackHelper(context),
+                                    new Injector());
                 }
             }
         }
-        return sFederatedComputeWorker;
+        return sWorker;
     }
 
     /** Starts a training run with the given job Id. */
-    public boolean startTrainingRun(int jobId) {
+    public ListenableFuture<FLRunnerResult> startTrainingRun(int jobId) {
         LogUtil.d(TAG, "startTrainingRun()");
         FederatedTrainingTask trainingTask = mJobManager.onTrainingStarted(jobId);
         if (trainingTask == null) {
             LogUtil.i(TAG, "Could not find task to run for job ID %s", jobId);
-            return false;
+            return Futures.immediateFuture(null);
         }
 
-        synchronized (LOCK) {
-            // Only allow one concurrent federated computation job.
+        if (!checkTrainingConditions(trainingTask.getTrainingConstraints())) {
+            mJobManager.onTrainingCompleted(
+                    jobId,
+                    trainingTask.populationName(),
+                    trainingTask.getTrainingIntervalOptions(),
+                    /* taskRetry= */ null,
+                    ContributionResult.FAIL);
+            LogUtil.i(TAG, "Training conditions not satisfied (before bindService)!");
+            return Futures.immediateFuture(null);
+        }
+
+        synchronized (mLock) {
+            // Only allows one concurrent job running.
             if (mActiveRun != null) {
                 LogUtil.i(
                         TAG,
@@ -91,55 +165,469 @@
                         trainingTask.populationName(),
                         trainingTask.getTrainingIntervalOptions(),
                         /* taskRetry= */ null,
-                        TrainingResult.FAIL);
-                return false;
+                        ContributionResult.FAIL);
+                return Futures.immediateFuture(null);
             }
+
             TrainingRun run = new TrainingRun(jobId, trainingTask);
-            this.mActiveRun = run;
-            doTraining(run);
-            // TODO: get retry info from federated server.
-            TaskRetry taskRetry = SchedulingUtil.generateTransientErrorTaskRetry(mFlags);
-            finish(this.mActiveRun, taskRetry, TrainingResult.SUCCESS);
+            mActiveRun = run;
+            ListenableFuture<FLRunnerResult> runCompletedFuture = doTraining(run);
+            run.mFuture = runCompletedFuture;
+            return runCompletedFuture;
         }
-        return true;
     }
 
-    /** Cancels the running job if present. */
-    public void cancelActiveRun() {
-        LogUtil.d(TAG, "cancelActiveRun()");
-        synchronized (LOCK) {
+    private ListenableFuture<FLRunnerResult> doTraining(TrainingRun run) {
+        try {
+            // 1. Communicate with remote federated compute server to start task assignment and
+            // download client plan and initial model checkpoint. Note: use bLocking executors for
+            // http requests.
+            mHttpFederatedProtocol =
+                    getHttpFederatedProtocol(run.mTask.serverAddress(), run.mTask.populationName());
+            ListenableFuture<CheckinResult> checkinResultFuture =
+                    mHttpFederatedProtocol.issueCheckin();
+
+            // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector.
+            ListenableFuture<IExampleStoreIterator> iteratorFuture =
+                    FluentFuture.from(checkinResultFuture)
+                            .transform(
+                                    result -> {
+                                        // Set active run's task name.
+                                        String taskName = result.getTaskAssignment().getTaskName();
+                                        Preconditions.checkArgument(
+                                                !taskName.isEmpty(),
+                                                "Task name should not be empty");
+                                        synchronized (mLock) {
+                                            mActiveRun.mTaskName = taskName;
+                                        }
+                                        return getExampleSelector(result);
+                                    },
+                                    getLightweightExecutor())
+                            .transformAsync(
+                                    selector ->
+                                            getExampleStoreIterator(
+                                                    run,
+                                                    run.mTask.appPackageName(),
+                                                    run.mTaskName,
+                                                    selector),
+                                    getBackgroundExecutor());
+
+            // 3. Run federated learning or federated analytic depends on task type. Federated
+            // learning job will start a new isolated process to run TFLite training.
+            ListenableFuture<ComputationResult> computationResultFuture =
+                    Futures.whenAllSucceed(checkinResultFuture, iteratorFuture)
+                            .callAsync(
+                                    () ->
+                                            runFederatedComputation(
+                                                    Futures.getDone(checkinResultFuture),
+                                                    run,
+                                                    Futures.getDone(iteratorFuture)),
+                                    getBackgroundExecutor());
+
+            // 4. Report computation result to federated compute server.
+            ListenableFuture<Void> reportToServerFuture =
+                    FluentFuture.from(computationResultFuture)
+                            .transformAsync(
+                                    result -> mHttpFederatedProtocol.reportResult(result),
+                                    getLightweightExecutor());
+            return FluentFuture.from(
+                    Futures.whenAllSucceed(reportToServerFuture, computationResultFuture)
+                            .call(
+                                    () -> {
+                                        ComputationResult result =
+                                                Futures.getDone(computationResultFuture);
+                                        var reportToServer = Futures.getDone(reportToServerFuture);
+                                        // 5. Publish computation result and consumed examples to
+                                        // client implemented ResultHandlingService.
+                                        var unused =
+                                                mResultCallbackHelper.callHandleResult(
+                                                        run.mTaskName, run.mTask, result);
+                                        return result.getFlRunnerResult();
+                                    },
+                                    getBackgroundExecutor()));
+        } catch (Exception e) {
+            return Futures.immediateFailedFuture(e);
+        }
+    }
+
+    /**
+     * Completes the running job , schedule recurrent job, and unbind from ExampleStoreService and
+     * ResultHandlingService etc.
+     */
+    public void finish(FLRunnerResult flRunnerResult) {
+        TaskRetry taskRetry = null;
+        if (flRunnerResult != null) {
+            if (flRunnerResult.hasRetryInfo()) {
+                RetryInfo retryInfo = flRunnerResult.getRetryInfo();
+                long delay = retryInfo.getMinimumDelay().getSeconds() * 1000L;
+                taskRetry =
+                        TaskRetry.newBuilder()
+                                .setRetryToken(retryInfo.getRetryToken())
+                                .setDelayMin(delay)
+                                .setDelayMax(delay)
+                                .build();
+                LogUtil.i(TAG, "Finished with task retry= %s", taskRetry);
+            }
+        }
+        finish(taskRetry, flRunnerResult.getContributionResult(), true);
+    }
+
+    /**
+     * Cancel the current running job, schedule recurrent job, unbind from ExampleStoreService and
+     * ResultHandlingService etc.
+     */
+    public void finish(
+            TaskRetry taskRetry, ContributionResult contributionResult, boolean cancelFuture) {
+        TrainingRun runToFinish;
+        synchronized (mLock) {
             if (mActiveRun == null) {
                 return;
             }
-            finish(mActiveRun, /* taskRetry= */ null, TrainingResult.FAIL);
-        }
-    }
 
-    private void finish(
-            TrainingRun runToFinish, TaskRetry taskRetry, @TrainingResult int trainingResult) {
-        synchronized (LOCK) {
-            if (mActiveRun != runToFinish) {
-                return;
-            }
+            runToFinish = mActiveRun;
             mActiveRun = null;
-            mJobManager.onTrainingCompleted(
-                    runToFinish.mJobId,
-                    runToFinish.mTask.populationName(),
-                    runToFinish.mTask.getTrainingIntervalOptions(),
-                    taskRetry,
-                    trainingResult);
+            if (cancelFuture) {
+                runToFinish.mFuture.cancel(true);
+            }
+        }
+
+        unBindServicesIfNecessary(runToFinish);
+        mJobManager.onTrainingCompleted(
+                runToFinish.mJobId,
+                runToFinish.mTask.populationName(),
+                runToFinish.mTask.getTrainingIntervalOptions(),
+                taskRetry,
+                contributionResult);
+    }
+
+    private void unBindServicesIfNecessary(TrainingRun runToFinish) {
+        if (runToFinish.mIsolatedTrainingService != null) {
+            LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
+            unbindFromIsolatedTrainingService();
+            runToFinish.mIsolatedTrainingService = null;
+        }
+        if (runToFinish.mExampleStoreService != null) {
+            LogUtil.i(TAG, "Unbinding from ExampleStoreService");
+            unbindFromExampleStoreService();
+            runToFinish.mExampleStoreService = null;
         }
     }
 
-    private void doTraining(TrainingRun run) {
-        // TODO: add training logic.
-        LogUtil.d(TAG, "Start run training job %d ", run.mJobId);
+    @VisibleForTesting
+    HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) {
+        return HttpFederatedProtocol.create(serverAddress, "1.0", populationName);
+    }
+
+    private ExampleSelector getExampleSelector(CheckinResult checkinResult) {
+        ClientOnlyPlan clientPlan = checkinResult.getPlanData();
+        switch (clientPlan.getPhase().getSpecCase()) {
+            case EXAMPLE_QUERY_SPEC:
+                // Only support one FA query for now.
+                return clientPlan
+                        .getPhase()
+                        .getExampleQuerySpec()
+                        .getExampleQueries(0)
+                        .getExampleSelector();
+            case TENSORFLOW_SPEC:
+                return clientPlan.getPhase().getTensorflowSpec().getExampleSelector();
+            default:
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Client plan spec is not supported %s",
+                                clientPlan.getPhase().getSpecCase().toString()));
+        }
+    }
+
+    private boolean checkTrainingConditions(TrainingConstraints constraints) {
+        Set<Condition> conditions =
+                mTrainingConditionsChecker.checkAllConditionsForFlTraining(constraints);
+        for (Condition condition : conditions) {
+            switch (condition) {
+                case THERMALS_NOT_OK:
+                    LogUtil.e(TAG, "training job service interrupt thermals not ok");
+                    break;
+                case BATTERY_NOT_OK:
+                    LogUtil.e(TAG, "training job service interrupt battery not ok");
+                    break;
+            }
+        }
+        return conditions.isEmpty();
+    }
+
+    @VisibleForTesting
+    ListenableFuture<ComputationResult> runFlComputation(
+            TrainingRun run,
+            CheckinResult checkinResult,
+            String outputCheckpointFile,
+            IExampleStoreIterator iterator) {
+        ParcelFileDescriptor outputCheckpointFd = createTempFileDescriptor(outputCheckpointFile);
+        ParcelFileDescriptor inputCheckpointFd =
+                createTempFileDescriptor(checkinResult.getInputCheckpointFile());
+        try {
+            IIsolatedTrainingService trainingService = getIsolatedTrainingService();
+            if (trainingService == null) {
+                LogUtil.w(TAG, "Could not bind to IsolatedTrainingService");
+                throw new IllegalStateException("Could not bind to IsolatedTrainingService");
+            }
+            run.mIsolatedTrainingService = trainingService;
+
+            Bundle bundle = new Bundle();
+            bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName());
+            bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, inputCheckpointFd);
+            bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, outputCheckpointFd);
+            bundle.putInt(Constants.EXTRA_JOB_ID, run.mJobId);
+            bundle.putBinder(Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, iterator.asBinder());
+            return FluentFuture.from(runIsolatedTrainingProcess(run, bundle))
+                    .transform(
+                            result -> {
+                                ComputationResult computationResult =
+                                        processIsolatedTrainingResult(outputCheckpointFile, result);
+                                // Close opened file descriptor.
+                                try {
+                                    if (outputCheckpointFd != null) {
+                                        outputCheckpointFd.close();
+                                    }
+                                    if (inputCheckpointFd != null) {
+                                        inputCheckpointFd.close();
+                                    }
+                                } catch (IOException e) {
+                                    LogUtil.e(TAG, "Failed to close file descriptor", e);
+                                } finally {
+                                    // Unbind from IsolatedTrainingService.
+                                    LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
+                                    unbindFromIsolatedTrainingService();
+                                    run.mIsolatedTrainingService = null;
+                                }
+                                return computationResult;
+                            },
+                            getLightweightExecutor());
+        } catch (Exception e) {
+            // Close opened file descriptor.
+            try {
+                if (outputCheckpointFd != null) {
+                    outputCheckpointFd.close();
+                }
+                if (inputCheckpointFd != null) {
+                    inputCheckpointFd.close();
+                }
+            } catch (IOException t) {
+                LogUtil.e(TAG, t, "Failed to close file descriptor");
+            } finally {
+                // Unbind from IsolatedTrainingService.
+                LogUtil.i(TAG, "Unbinding from IsolatedTrainingService");
+                unbindFromIsolatedTrainingService();
+                run.mIsolatedTrainingService = null;
+            }
+            return Futures.immediateFailedFuture(e);
+        }
+    }
+
+    private ComputationResult processIsolatedTrainingResult(
+            String outputCheckpoint, Bundle result) {
+        byte[] resultBytes =
+                Objects.requireNonNull(result.getByteArray(Constants.EXTRA_FL_RUNNER_RESULT));
+        FLRunnerResult flRunnerResult;
+        try {
+            flRunnerResult = FLRunnerResult.parseFrom(resultBytes);
+        } catch (InvalidProtocolBufferException e) {
+            throw new IllegalArgumentException(e);
+        }
+        ArrayList<ExampleConsumption> exampleList =
+                result.getParcelableArrayList(
+                        ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, ExampleConsumption.class);
+        if (exampleList == null || exampleList.isEmpty()) {
+            throw new IllegalArgumentException("example consumption list should not be empty");
+        }
+
+        return new ComputationResult(outputCheckpoint, flRunnerResult, exampleList);
+    }
+
+    private ListenableFuture<Bundle> runIsolatedTrainingProcess(TrainingRun run, Bundle input) {
+        return CallbackToFutureAdapter.getFuture(
+                completer -> {
+                    try {
+                        run.mIsolatedTrainingService.runFlTraining(
+                                input,
+                                new ITrainingResultCallback.Stub() {
+                                    @Override
+                                    public void onResult(Bundle result) {
+                                        completer.set(result);
+                                    }
+                                });
+                    } catch (Exception e) {
+                        completer.setException(e);
+                    }
+                    return "runIsolatedTrainingProcess";
+                });
+    }
+
+    private ListenableFuture<ComputationResult> runFederatedComputation(
+            CheckinResult checkinResult, TrainingRun run, IExampleStoreIterator iterator) {
+        ClientOnlyPlan clientPlan = checkinResult.getPlanData();
+        String outputCheckpointFile = createTempFile("output", ".ckp");
+
+        ListenableFuture<ComputationResult> computationResultFuture;
+        switch (clientPlan.getPhase().getSpecCase()) {
+            case EXAMPLE_QUERY_SPEC:
+                computationResultFuture =
+                        runFAComputation(run, checkinResult, outputCheckpointFile, iterator);
+                break;
+            case TENSORFLOW_SPEC:
+                computationResultFuture =
+                        runFlComputation(run, checkinResult, outputCheckpointFile, iterator);
+                break;
+            default:
+                return Futures.immediateFailedFuture(
+                        new IllegalArgumentException(
+                                String.format(
+                                        "Client plan spec is not supported %s",
+                                        clientPlan.getPhase().getSpecCase().toString())));
+        }
+        return computationResultFuture;
+    }
+
+    private ListenableFuture<ComputationResult> runFAComputation(
+            TrainingRun run,
+            CheckinResult checkinResult,
+            String outputCheckpointFile,
+            IExampleStoreIterator exampleStoreIterator) {
+        ExampleSelector exampleSelector = getExampleSelector(checkinResult);
+        ClientOnlyPlan clientPlan = checkinResult.getPlanData();
+        // The federated analytic runs in main process which has permission to file system.
+        ExampleConsumptionRecorder recorder = mInjector.getExampleConsumptionRecorder();
+        FLRunnerResult runResult =
+                mComputationRunner.runTaskWithNativeRunner(
+                        run.mJobId,
+                        run.mTask.populationName(),
+                        checkinResult.getInputCheckpointFile(),
+                        outputCheckpointFile,
+                        clientPlan,
+                        exampleSelector,
+                        recorder,
+                        exampleStoreIterator,
+                        mInterruptSupplier);
+        ArrayList<ExampleConsumption> exampleConsumptions = recorder.finishRecordingAndGet();
+        return Futures.immediateFuture(
+                new ComputationResult(outputCheckpointFile, runResult, exampleConsumptions));
+    }
+
+    @VisibleForTesting
+    IExampleStoreService getExampleStoreService(String packageName) {
+        mExampleStoreServiceBinder =
+                AbstractServiceBinder.getServiceBinder(
+                        mContext,
+                        ClientConstants.EXAMPLE_STORE_ACTION,
+                        IExampleStoreService.Stub::asInterface);
+        Intent intent = new Intent(ClientConstants.EXAMPLE_STORE_ACTION).setPackage(packageName);
+        return mExampleStoreServiceBinder.getService(Runnable::run, intent);
+    }
+
+    @VisibleForTesting
+    void unbindFromExampleStoreService() {
+        mExampleStoreServiceBinder.unbindFromService();
+    }
+
+    private ListenableFuture<IExampleStoreIterator> runExampleStoreStartQuery(
+            TrainingRun run, Bundle input) {
+        return CallbackToFutureAdapter.getFuture(
+                completer -> {
+                    try {
+                        run.mExampleStoreService.startQuery(
+                                input,
+                                new IExampleStoreCallback.Stub() {
+                                    @Override
+                                    public void onStartQuerySuccess(
+                                            IExampleStoreIterator iterator) {
+                                        LogUtil.d(TAG, "Acquire iterator");
+                                        completer.set(iterator);
+                                    }
+
+                                    @Override
+                                    public void onStartQueryFailure(int errorCode) {
+                                        LogUtil.e(TAG, "Could not acquire iterator: " + errorCode);
+                                        completer.setException(
+                                                new IllegalStateException(
+                                                        "StartQuery failed: " + errorCode));
+                                    }
+                                });
+                    } catch (Exception e) {
+                        completer.setException(e);
+                    }
+                    return "runExampleStoreStartQuery";
+                });
+    }
+
+    private ListenableFuture<IExampleStoreIterator> getExampleStoreIterator(
+            TrainingRun run, String packageName, String taskName, ExampleSelector exampleSelector) {
+        try {
+            run.mTaskName = taskName;
+
+            IExampleStoreService exampleStoreService = getExampleStoreService(packageName);
+            if (exampleStoreService == null) {
+                return Futures.immediateFailedFuture(
+                        new IllegalStateException(
+                                "Could not bind to ExampleStoreService " + packageName));
+            }
+            run.mExampleStoreService = exampleStoreService;
+
+            String collection = exampleSelector.getCollectionUri();
+            byte[] criteria = exampleSelector.getCriteria().toByteArray();
+            byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray();
+            Bundle bundle = new Bundle();
+            bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName());
+            bundle.putString(ClientConstants.EXTRA_COLLECTION_NAME, collection);
+            bundle.putString(ClientConstants.EXTRA_TASK_NAME, taskName);
+            bundle.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, run.mTask.contextData());
+            bundle.putByteArray(
+                    ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
+            bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
+
+            return runExampleStoreStartQuery(run, bundle);
+        } catch (Exception e) {
+            LogUtil.e(TAG, "StartQuery failure: " + e.getMessage());
+            return Futures.immediateFailedFuture(e);
+        }
+    }
+
+    @VisibleForTesting
+    @Nullable
+    IIsolatedTrainingService getIsolatedTrainingService() {
+        mIsolatedTrainingServiceBinder =
+                AbstractServiceBinder.getServiceBinder(
+                        mContext,
+                        ISOLATED_TRAINING_SERVICE_NAME,
+                        IIsolatedTrainingService.Stub::asInterface);
+        Intent intent = new Intent();
+        ComponentName serviceComponent =
+                new ComponentName(mContext.getPackageName(), ISOLATED_TRAINING_SERVICE_NAME);
+        intent.setComponent(serviceComponent);
+        return mIsolatedTrainingServiceBinder.getService(Runnable::run, intent);
+    }
+
+    @VisibleForTesting
+    void unbindFromIsolatedTrainingService() {
+        mIsolatedTrainingServiceBinder.unbindFromService();
+    }
+
+    @VisibleForTesting
+    static class Injector {
+        ExampleConsumptionRecorder getExampleConsumptionRecorder() {
+            return new ExampleConsumptionRecorder();
+        }
     }
 
     private static final class TrainingRun {
         private final int mJobId;
+
+        private String mTaskName;
         private final FederatedTrainingTask mTask;
 
+        @Nullable private ListenableFuture<?> mFuture;
+
+        @Nullable private IIsolatedTrainingService mIsolatedTrainingService = null;
+
+        @Nullable private IExampleStoreService mExampleStoreService = null;
+
         private TrainingRun(int jobId, FederatedTrainingTask task) {
             this.mJobId = jobId;
             this.mTask = task;
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java
index 053e51a..f5c3220 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java
@@ -22,17 +22,17 @@
 import android.app.job.JobService;
 
 import com.android.federatedcompute.internal.util.LogUtil;
-import com.android.federatedcompute.services.common.FederatedComputeExecutors;
 import com.android.federatedcompute.services.common.FlagsFactory;
 
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
 
 /** Main service for the scheduled federated computation jobs. */
 public class FederatedJobService extends JobService {
     private static final String TAG = FederatedJobService.class.getSimpleName();
-    private ListenableFuture<Boolean> mRunCompleteFuture;
 
     @Override
     public boolean onStartJob(JobParameters params) {
@@ -42,19 +42,19 @@
             jobFinished(params, /* wantsReschedule= */ false);
             return true;
         }
-        mRunCompleteFuture =
-                Futures.submit(
-                        () ->
-                                FederatedComputeWorker.getInstance(this)
-                                        .startTrainingRun(params.getJobId()),
-                        FederatedComputeExecutors.getBackgroundExecutor());
+        FederatedComputeWorker worker = FederatedComputeWorker.getInstance(this);
+        ListenableFuture<FLRunnerResult> runCompleteFuture =
+                worker.startTrainingRun(params.getJobId());
 
         Futures.addCallback(
-                mRunCompleteFuture,
-                new FutureCallback<Boolean>() {
+                runCompleteFuture,
+                new FutureCallback<FLRunnerResult>() {
                     @Override
-                    public void onSuccess(Boolean result) {
-                        LogUtil.d(TAG, "federated computation job is done!");
+                    public void onSuccess(FLRunnerResult flRunnerResult) {
+                        LogUtil.d(TAG, "Federated computation job %d is done!", params.getJobId());
+                        if (flRunnerResult != null) {
+                            worker.finish(flRunnerResult);
+                        }
                         jobFinished(params, /* wantsReschedule= */ false);
                     }
 
@@ -62,6 +62,7 @@
                     public void onFailure(Throwable t) {
                         LogUtil.e(
                                 TAG, t, "Failed to handle computation job: %d", params.getJobId());
+                        worker.finish(null, ContributionResult.FAIL, false);
                         jobFinished(params, /* wantsReschedule= */ false);
                     }
                 },
@@ -71,12 +72,8 @@
 
     @Override
     public boolean onStopJob(JobParameters params) {
-        if (mRunCompleteFuture != null) {
-            mRunCompleteFuture.cancel(true);
-        }
-        FederatedComputeWorker.getInstance(this).cancelActiveRun();
-        // Reschedule the job since it's not done. TODO: we should implement specify reschedule
-        // logic instead.
-        return true;
+        LogUtil.d(TAG, "FederatedJobService.onStopJob %d", params.getJobId());
+        FederatedComputeWorker.getInstance(this).finish(null, ContributionResult.FAIL, true);
+        return false;
     }
 }
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java
index 0fd9996..62035e8 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java
@@ -18,6 +18,7 @@
 
 import android.content.Context;
 import android.federatedcompute.aidl.IExampleStoreIterator;
+import android.federatedcompute.common.ClientConstants;
 import android.os.Bundle;
 import android.os.ParcelFileDescriptor;
 import android.os.RemoteException;
@@ -83,7 +84,7 @@
             throw new IllegalArgumentException("ExampleSelector proto is invalid", e);
         }
         String populationName =
-                Objects.requireNonNull(params.getString(Constants.EXTRA_POPULATION_NAME));
+                Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME));
         ParcelFileDescriptor inputCheckpointFd =
                 Objects.requireNonNull(
                         params.getParcelable(
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java b/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java
index 63a448d..d49aadd 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java
@@ -25,18 +25,17 @@
 import android.federatedcompute.aidl.IFederatedComputeCallback;
 import android.federatedcompute.aidl.IResultHandlingService;
 import android.federatedcompute.common.ClientConstants;
-import android.federatedcompute.common.ExampleConsumption;
 import android.os.Bundle;
 
 import com.android.federatedcompute.internal.util.AbstractServiceBinder;
 import com.android.federatedcompute.internal.util.LogUtil;
 import com.android.federatedcompute.services.data.FederatedTrainingTask;
+import com.android.federatedcompute.services.training.util.ComputationResult;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 
-import java.util.ArrayList;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.TimeUnit;
@@ -69,19 +68,16 @@
      * Publishes the training result and example list to client implemented ResultHandlingService.
      */
     public ListenableFuture<CallbackResult> callHandleResult(
-            String taskName,
-            FederatedTrainingTask task,
-            ArrayList<ExampleConsumption> exampleConsumptions,
-            boolean success) {
+            String taskName, FederatedTrainingTask task, ComputationResult result) {
         Bundle input = new Bundle();
         input.putString(ClientConstants.EXTRA_POPULATION_NAME, task.populationName());
         input.putString(ClientConstants.EXTRA_TASK_NAME, taskName);
         input.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, task.contextData());
         input.putInt(
                 ClientConstants.EXTRA_COMPUTATION_RESULT,
-                success ? STATUS_SUCCESS : STATUS_TRAINING_FAILED);
+                result.isResultSuccess() ? STATUS_SUCCESS : STATUS_TRAINING_FAILED);
         input.putParcelableArrayList(
-                ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, exampleConsumptions);
+                ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, result.getExampleConsumptionList());
 
         try {
             IResultHandlingService resultHandlingService =
@@ -117,10 +113,9 @@
         } catch (Exception e) {
             LogUtil.e(
                     TAG,
-                    String.format(
-                            "ResultHandlingService binding died. population name: %s",
-                            task.populationName()),
-                    e);
+                    e,
+                    "ResultHandlingService binding died. population name: %s",
+                    task.populationName());
             // We publish result to client app with best effort and should not crash flow.
             return Futures.immediateFuture(CallbackResult.FAIL);
         } finally {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java b/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java
index 72a58ae..89d9847 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java
@@ -21,24 +21,24 @@
 import com.google.intelligence.fcp.client.FLRunnerResult;
 import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
 
-import java.util.List;
+import java.util.ArrayList;
 
 /** The result of federated computation. */
 public class ComputationResult {
     private String mOutputCheckpointFile = "";
     private FLRunnerResult mFlRunnerResult = null;
-    private List<ExampleConsumption> mExampleConsumptionList = null;
+    private ArrayList<ExampleConsumption> mExampleConsumptionList = null;
 
     public ComputationResult(
             String outputCheckpointFile,
             FLRunnerResult flRunnerResult,
-            List<ExampleConsumption> exampleConsumptionList) {
+            ArrayList<ExampleConsumption> exampleConsumptionList) {
         this.mOutputCheckpointFile = outputCheckpointFile;
         this.mFlRunnerResult = flRunnerResult;
         this.mExampleConsumptionList = exampleConsumptionList;
     }
 
-    public List<ExampleConsumption> getExampleConsumptionList() {
+    public ArrayList<ExampleConsumption> getExampleConsumptionList() {
         return mExampleConsumptionList;
     }
 
diff --git a/framework/java/android/federatedcompute/common/ClientConstants.java b/framework/java/android/federatedcompute/common/ClientConstants.java
index 6d12749..0886506 100644
--- a/framework/java/android/federatedcompute/common/ClientConstants.java
+++ b/framework/java/android/federatedcompute/common/ClientConstants.java
@@ -26,8 +26,8 @@
     public static final int STATUS_SUCCESS = 0;
     public static final int STATUS_INTERNAL_ERROR = 1;
     public static final int STATUS_TRAINING_FAILED = 2;
-
     public static final String EXTRA_POPULATION_NAME = "android.federatedcompute.population_name";
+
     public static final String EXTRA_COLLECTION_NAME = "android.federatedcompute.collection_name";
 
     public static final String EXTRA_TASK_NAME = "android.federatedcompute.task_name";
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java
index 326fb89..9c0c8ec 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java
@@ -37,7 +37,6 @@
 
 import com.android.federatedcompute.services.common.Clock;
 import com.android.federatedcompute.services.common.Flags;
-import com.android.federatedcompute.services.common.TrainingResult;
 import com.android.federatedcompute.services.data.FederatedTrainingTask;
 import com.android.federatedcompute.services.data.FederatedTrainingTaskDao;
 import com.android.federatedcompute.services.data.FederatedTrainingTaskDbHelper;
@@ -47,6 +46,7 @@
 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
 
 import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
 import com.google.intelligence.fcp.client.engine.TaskRetry;
 
 import org.junit.After;
@@ -92,13 +92,11 @@
                     .build();
     private static final TaskRetry TASK_RETRY =
             TaskRetry.newBuilder().setDelayMin(5000000).setDelayMax(6000000).build();
-
+    private final CountDownLatch mLatch = new CountDownLatch(1);
     private FederatedComputeJobManager mJobManager;
     private Context mContext;
     private FederatedTrainingTaskDao mTrainingTaskDao;
     private boolean mSuccess = false;
-    private final CountDownLatch mLatch = new CountDownLatch(1);
-
     @Mock private Clock mClock;
     @Mock private Flags mMockFlags;
     @Mock private FederatedJobIdGenerator mMockJobIdGenerator;
@@ -724,7 +722,7 @@
                 POPULATION_NAME1,
                 createTrainingIntervalOptionsAsRoot(SchedulingMode.RECURRENT, 0),
                 TASK_RETRY,
-                TrainingResult.SUCCESS);
+                ContributionResult.SUCCESS);
 
         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull();
         assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1);
@@ -748,7 +746,7 @@
                 POPULATION_NAME1,
                 createTrainingIntervalOptionsAsRoot(SchedulingMode.ONE_TIME, 0),
                 TASK_RETRY,
-                TrainingResult.SUCCESS);
+                ContributionResult.SUCCESS);
 
         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull();
         assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty();
@@ -787,7 +785,7 @@
                         .setDelayMin(serverRetryDelayMillis)
                         .setDelayMax(serverRetryDelayMillis)
                         .build(),
-                TrainingResult.FAIL);
+                ContributionResult.FAIL);
 
         List<FederatedTrainingTask> taskList =
                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
@@ -847,7 +845,7 @@
                         .setDelayMin(minRetryDelayMillis)
                         .setDelayMax(maxRetryDelayMillis)
                         .build(),
-                TrainingResult.SUCCESS);
+                ContributionResult.SUCCESS);
 
         List<FederatedTrainingTask> taskList =
                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
@@ -907,7 +905,7 @@
                         .setDelayMin(serverDefinedIntervalMillis)
                         .setDelayMax(serverDefinedIntervalMillis)
                         .build(),
-                TrainingResult.SUCCESS);
+                ContributionResult.SUCCESS);
 
         List<FederatedTrainingTask> taskList =
                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
@@ -967,7 +965,7 @@
                         .setDelayMin(serverDefinedIntervalMillis)
                         .setDelayMax(serverDefinedIntervalMillis)
                         .build(),
-                TrainingResult.FAIL);
+                ContributionResult.FAIL);
 
         List<FederatedTrainingTask> taskList =
                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java
index b08c1c4..4f52e41 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java
@@ -26,14 +26,10 @@
 import com.google.internal.federated.plan.ExampleSelector;
 import com.google.internal.federated.plan.FederatedExampleQueryIORouter;
 import com.google.internal.federated.plan.TFV1CheckpointAggregation;
+import com.google.internal.federated.plan.TensorflowSpec;
 
 /** The utility class for federated learning related tests. */
 public class TrainingTestUtil {
-    public static final String CLIENT_PACKAGE_NAME = "de.myselph";
-    public static final long RUN_ID = 12345L;
-    public static final String SESSION_NAME = "session_name";
-    public static final String TASK_NAME = "task_name";
-    public static final String POPULATION_NAME = "population_name";
     public static final String STRING_VECTOR_NAME = "vector1";
     public static final String INT_VECTOR_NAME = "vector2";
     public static final String STRING_TENSOR_NAME = "tensor1";
@@ -85,4 +81,18 @@
                         .build();
         return clientOnlyPlan;
     }
+
+    public static ClientOnlyPlan createFakeFederatedLearningClientPlan() {
+        TensorflowSpec tensorflowSpec =
+                TensorflowSpec.newBuilder()
+                        .setDatasetTokenTensorName("dataset")
+                        .addTargetNodeNames("target")
+                        .build();
+        ClientOnlyPlan clientOnlyPlan =
+                ClientOnlyPlan.newBuilder()
+                        .setPhase(
+                                ClientPhase.newBuilder().setTensorflowSpec(tensorflowSpec).build())
+                        .build();
+        return clientOnlyPlan;
+    }
 }
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
index 4bbfca3..9a5a3df 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
@@ -16,108 +16,172 @@
 
 package com.android.federatedcompute.services.training;
 
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static com.android.federatedcompute.services.common.FileUtils.createTempFile;
+
+import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.util.concurrent.Futures.immediateFailedFuture;
+import static com.google.common.util.concurrent.Futures.immediateFuture;
+
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
-import com.android.federatedcompute.services.common.Flags;
-import com.android.federatedcompute.services.common.TrainingResult;
+import android.content.Context;
+import android.federatedcompute.aidl.IExampleStoreCallback;
+import android.federatedcompute.aidl.IExampleStoreService;
+import android.federatedcompute.common.ClientConstants;
+import android.federatedcompute.common.ExampleConsumption;
+import android.os.Bundle;
+import android.os.RemoteException;
+
+import androidx.test.core.app.ApplicationProvider;
+
+import com.android.federatedcompute.services.common.Constants;
 import com.android.federatedcompute.services.data.FederatedTrainingTask;
 import com.android.federatedcompute.services.data.fbs.SchedulingMode;
 import com.android.federatedcompute.services.data.fbs.SchedulingReason;
 import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
+import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
+import com.android.federatedcompute.services.http.CheckinResult;
+import com.android.federatedcompute.services.http.HttpFederatedProtocol;
 import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager;
+import com.android.federatedcompute.services.testutils.FakeExampleStoreIterator;
+import com.android.federatedcompute.services.testutils.TrainingTestUtil;
+import com.android.federatedcompute.services.training.ResultCallbackHelper.CallbackResult;
+import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService;
+import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker;
+import com.android.federatedcompute.services.training.util.TrainingConditionsChecker.Condition;
 
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
 import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
+import com.google.intelligence.fcp.client.RetryInfo;
+import com.google.intelligence.fcp.client.engine.TaskRetry;
+import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+import org.tensorflow.example.BytesList;
+import org.tensorflow.example.Example;
+import org.tensorflow.example.Feature;
+import org.tensorflow.example.Features;
+
+import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.concurrent.ExecutionException;
 
 @RunWith(JUnit4.class)
 public final class FederatedComputeWorkerTest {
     private static final int JOB_ID = 1234;
     private static final String POPULATION_NAME = "barPopulation";
-    private static final String SERVER_ADDRESS = "https://server.uri/";
+    private static final String TASK_NAME = "task-id";
     private static final long CREATION_TIME_MS = 10000L;
     private static final long TASK_EARLIEST_NEXT_RUN_TIME_MS = 1234567L;
     private static final String PACKAGE_NAME = "com.android.federatedcompute.services.training";
+    private static final String SERVER_ADDRESS = "https://server.com/";
     private static final byte[] DEFAULT_TRAINING_CONSTRAINTS =
             createTrainingConstraints(true, true, true);
-    private static final long FEDERATED_TRANSIENT_ERROR_RETRY_PERIOD_SECS = 50000;
+    private static final TaskRetry TASK_RETRY =
+            TaskRetry.newBuilder().setRetryToken("foobar").build();
+    private static final CheckinResult FL_CHECKIN_RESULT =
+            new CheckinResult(
+                    createTempFile("input", ".ckp"),
+                    TrainingTestUtil.createFakeFederatedLearningClientPlan(),
+                    TaskAssignment.newBuilder().setTaskName(TASK_NAME).build());
+    private static final CheckinResult FA_CHECKIN_RESULT =
+            new CheckinResult(
+                    createTempFile("input", ".ckp"),
+                    TrainingTestUtil.createFederatedAnalyticClientPlan(),
+                    TaskAssignment.newBuilder().setTaskName(TASK_NAME).build());
+    private static final FLRunnerResult FL_RUNNER_FAILURE_RESULT =
+            FLRunnerResult.newBuilder().setContributionResult(ContributionResult.FAIL).build();
+
+    private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
+            FLRunnerResult.newBuilder()
+                    .setContributionResult(ContributionResult.SUCCESS)
+                    .setRetryInfo(
+                            RetryInfo.newBuilder()
+                                    .setRetryToken(TASK_RETRY.getRetryToken())
+                                    .build())
+                    .build();
     private static final byte[] INTERVAL_OPTIONS = createDefaultTrainingIntervalOptions();
     private static final FederatedTrainingTask FEDERATED_TRAINING_TASK_1 =
             FederatedTrainingTask.builder()
                     .appPackageName(PACKAGE_NAME)
                     .creationTime(CREATION_TIME_MS)
                     .lastScheduledTime(TASK_EARLIEST_NEXT_RUN_TIME_MS)
+                    .serverAddress(SERVER_ADDRESS)
                     .populationName(POPULATION_NAME)
                     .jobId(JOB_ID)
-                    .serverAddress(SERVER_ADDRESS)
                     .intervalOptions(INTERVAL_OPTIONS)
                     .constraints(DEFAULT_TRAINING_CONSTRAINTS)
                     .earliestNextRunTime(TASK_EARLIEST_NEXT_RUN_TIME_MS)
                     .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
                     .build();
-
+    private static final Example EXAMPLE_PROTO_1 =
+            Example.newBuilder()
+                    .setFeatures(
+                            Features.newBuilder()
+                                    .putFeature(
+                                            "feature1",
+                                            Feature.newBuilder()
+                                                    .setBytesList(
+                                                            BytesList.newBuilder()
+                                                                    .addValue(
+                                                                            ByteString.copyFromUtf8(
+                                                                                    "f1_value1")))
+                                                    .build()))
+                    .build();
+    private static final Any FAKE_CRITERIA = Any.newBuilder().setTypeUrl("baz.com").build();
+    private static final String COLLECTION_URI = "app://com.foo.bar//inapp/collection1";
+    private static final ExampleConsumption EXAMPLE_CONSUMPTION_1 =
+            new ExampleConsumption.Builder()
+                    .setCollectionName(COLLECTION_URI)
+                    .setSelectionCriteria(FAKE_CRITERIA.toByteArray())
+                    .setExampleCount(100)
+                    .build();
+    @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+    @Mock TrainingConditionsChecker mTrainingConditionsChecker;
     @Mock FederatedComputeJobManager mMockJobManager;
-    @Mock private Flags mMockFlags;
-    private FederatedComputeWorker mFcpWorker;
-
-    @Before
-    public void doBeforeEachTest() {
-        MockitoAnnotations.initMocks(this);
-        mFcpWorker = new FederatedComputeWorker(mMockJobManager, mMockFlags);
-        doNothing()
-                .when(mMockJobManager)
-                .onTrainingCompleted(anyInt(), anyString(), any(), any(), anyInt());
-        when(mMockFlags.getTransientErrorRetryDelayJitterPercent()).thenReturn(0.1f);
-        when(mMockFlags.getTransientErrorRetryDelaySecs())
-                .thenReturn(FEDERATED_TRANSIENT_ERROR_RETRY_PERIOD_SECS);
-    }
-
-    @Test
-    public void testTrainingSuccess() {
-        when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1);
-        boolean result = mFcpWorker.startTrainingRun(JOB_ID);
-
-        assertTrue(result);
-        verify(mMockJobManager, times(1))
-                .onTrainingCompleted(
-                        eq(JOB_ID), eq(POPULATION_NAME), any(), any(), eq(TrainingResult.SUCCESS));
-    }
-
-    @Test
-    public void testTrainingFailure_nonExist() {
-        when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(null);
-        boolean result = mFcpWorker.startTrainingRun(JOB_ID);
-
-        assertFalse(result);
-        verify(mMockJobManager, times(0))
-                .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), anyInt());
-    }
+    private Context mContext;
+    private FederatedComputeWorker mSpyWorker;
+    @Mock private HttpFederatedProtocol mMockHttpFederatedProtocol;
+    @Mock private ComputationRunner mMockComputationRunner;
+    @Mock private ResultCallbackHelper mMockResultCallbackHelper;
 
     private static byte[] createTrainingConstraints(
             boolean requiresSchedulerIdle,
-            boolean requiresSchedulerCharging,
+            boolean requiresSchedulerBatteryNotLow,
             boolean requiresSchedulerUnmeteredNetwork) {
         FlatBufferBuilder builder = new FlatBufferBuilder();
         builder.finish(
                 TrainingConstraints.createTrainingConstraints(
                         builder,
                         requiresSchedulerIdle,
-                        requiresSchedulerCharging,
+                        requiresSchedulerBatteryNotLow,
                         requiresSchedulerUnmeteredNetwork));
         return builder.sizedByteArray();
     }
@@ -129,4 +193,310 @@
                         builder, SchedulingMode.ONE_TIME, 0));
         return builder.sizedByteArray();
     }
+
+    @Before
+    public void doBeforeEachTest() throws Exception {
+        mContext = ApplicationProvider.getApplicationContext();
+        mSpyWorker =
+                Mockito.spy(
+                        new FederatedComputeWorker(
+                                mContext,
+                                mMockJobManager,
+                                mTrainingConditionsChecker,
+                                mMockComputationRunner,
+                                mMockResultCallbackHelper,
+                                new TestInjector()));
+        when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any()))
+                .thenReturn(EnumSet.noneOf(Condition.class));
+        when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any()))
+                .thenReturn(Futures.immediateFuture(CallbackResult.SUCCESS));
+        when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1);
+        doReturn(mMockHttpFederatedProtocol)
+                .when(mSpyWorker)
+                .getHttpFederatedProtocol(anyString(), anyString());
+        when(mMockComputationRunner.runTaskWithNativeRunner(
+                        anyInt(),
+                        anyString(),
+                        anyString(),
+                        anyString(),
+                        any(),
+                        any(),
+                        any(),
+                        any(),
+                        any()))
+                .thenReturn(FL_RUNNER_SUCCESS_RESULT);
+    }
+
+    @Test
+    public void testJobNonExist_returnsFail() throws Exception {
+        when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(null);
+
+        FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+
+        assertNull(result);
+        verify(mMockJobManager, times(0))
+                .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), any());
+    }
+
+    @Test
+    public void testTrainingConditionsCheckFailed_returnsFail() throws Exception {
+        when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any()))
+                .thenReturn(ImmutableSet.of(Condition.BATTERY_NOT_OK));
+
+        FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+
+        assertNull(result);
+        verify(mMockJobManager)
+                .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), any());
+    }
+
+    @Test
+    public void testCheckinFails_throwsException() throws Exception {
+        setUpExampleStoreService();
+
+        doReturn(
+                        immediateFailedFuture(
+                                new ExecutionException(
+                                        "issue checkin failed",
+                                        new IllegalStateException("http 404"))))
+                .when(mMockHttpFederatedProtocol)
+                .issueCheckin();
+        doReturn(FluentFuture.from(immediateFuture(null)))
+                .when(mMockHttpFederatedProtocol)
+                .reportResult(any());
+
+        assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+        mSpyWorker.finish(null, ContributionResult.FAIL, false);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+    }
+
+    @Test
+    public void testReportResultFails_throwsException() throws Exception {
+        setUpExampleStoreService();
+
+        doReturn(immediateFuture(FA_CHECKIN_RESULT))
+                .when(mMockHttpFederatedProtocol)
+                .issueCheckin();
+        doReturn(
+                        FluentFuture.from(
+                                immediateFailedFuture(
+                                        new ExecutionException(
+                                                "report result failed",
+                                                new IllegalStateException("http 404")))))
+                .when(mMockHttpFederatedProtocol)
+                .reportResult(any());
+
+        assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+        mSpyWorker.finish(null, ContributionResult.FAIL, false);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+        verify(mSpyWorker).unbindFromExampleStoreService();
+    }
+
+    @Test
+    public void testBindToExampleStoreFails_throwsException() throws Exception {
+        setUpHttpFederatedProtocol(FL_CHECKIN_RESULT);
+
+        // Mock failure bind to ExampleStoreService.
+        doReturn(null).when(mSpyWorker).getExampleStoreService(anyString());
+        doNothing().when(mSpyWorker).unbindFromExampleStoreService();
+
+        assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+        mSpyWorker.finish(null, ContributionResult.FAIL, false);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+        verify(mSpyWorker, times(0)).unbindFromExampleStoreService();
+    }
+
+    @Test
+    public void testRunFAComputationReturnsFailResult() throws Exception {
+        setUpExampleStoreService();
+        setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+        // Mock return failed runner result from native fcp client.
+        when(mMockComputationRunner.runTaskWithNativeRunner(
+                        anyInt(),
+                        anyString(),
+                        anyString(),
+                        anyString(),
+                        any(),
+                        any(),
+                        any(),
+                        any(),
+                        any()))
+                .thenReturn(FL_RUNNER_FAILURE_RESULT);
+
+        FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+        assertThat(result.getContributionResult()).isEqualTo(ContributionResult.FAIL);
+
+        mSpyWorker.finish(result);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+        verify(mSpyWorker).unbindFromExampleStoreService();
+    }
+
+    @Test
+    public void testPublishToResultHandlingServiceFails_returnsSuccess() throws Exception {
+        setUpExampleStoreService();
+        setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+        // Mock publish to ResultHandlingService fails which is best effort and should not affect
+        // final result.
+        when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any()))
+                .thenReturn(Futures.immediateFuture(CallbackResult.FAIL));
+
+        FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+        assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+        mSpyWorker.finish(result);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS));
+        verify(mSpyWorker).unbindFromExampleStoreService();
+    }
+
+    @Test
+    public void testPublishToResultHandlingServiceThrowsException_returnsSuccess()
+            throws Exception {
+        setUpExampleStoreService();
+        setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+        // Mock publish to ResultHandlingService throws exception which is best effort and should
+        // not affect final result.
+        when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any()))
+                .thenReturn(
+                        immediateFailedFuture(
+                                new ExecutionException(
+                                        "ResultHandlingService fail",
+                                        new IllegalStateException("can't bind to service"))));
+
+        FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+        assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+        mSpyWorker.finish(result);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS));
+        verify(mSpyWorker).unbindFromExampleStoreService();
+        verify(mMockResultCallbackHelper).callHandleResult(eq(TASK_NAME), any(), any());
+    }
+
+    @Test
+    public void testRunFAComputation_returnsSuccess() throws Exception {
+        setUpExampleStoreService();
+        setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+        FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+        assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+        mSpyWorker.finish(result);
+        verify(mMockJobManager).onTrainingCompleted(anyInt(), anyString(), any(), any(), any());
+    }
+
+    @Test
+    public void testBindToIsolatedTrainingServiceFail_returnsFail() throws Exception {
+        doReturn(immediateFuture(FL_CHECKIN_RESULT))
+                .when(mMockHttpFederatedProtocol)
+                .issueCheckin();
+        setUpExampleStoreService();
+
+        // Mock failure bind to IsolatedTrainingService.
+        doReturn(null).when(mSpyWorker).getIsolatedTrainingService();
+        doNothing().when(mSpyWorker).unbindFromIsolatedTrainingService();
+
+        ExecutionException exception =
+                assertThrows(
+                        ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+        assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class);
+        assertThat(exception.getCause())
+                .hasMessageThat()
+                .isEqualTo("Could not bind to IsolatedTrainingService");
+
+        mSpyWorker.finish(null, ContributionResult.FAIL, false);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+    }
+
+    @Test
+    public void testRunFLComputation_returnsSuccess() throws Exception {
+        setUpExampleStoreService();
+        setUpHttpFederatedProtocol(FL_CHECKIN_RESULT);
+
+        // Mock bind to IsolatedTrainingService.
+        doReturn(new FakeIsolatedTrainingService()).when(mSpyWorker).getIsolatedTrainingService();
+        doNothing().when(mSpyWorker).unbindFromIsolatedTrainingService();
+
+        FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
+        assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS);
+
+        mSpyWorker.finish(result);
+        verify(mMockJobManager)
+                .onTrainingCompleted(
+                        anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS));
+        verify(mSpyWorker).unbindFromIsolatedTrainingService();
+        verify(mSpyWorker).unbindFromExampleStoreService();
+    }
+
+    private void setUpExampleStoreService() {
+        TestExampleStoreService testExampleStoreService = new TestExampleStoreService();
+        doReturn(testExampleStoreService).when(mSpyWorker).getExampleStoreService(anyString());
+        doNothing().when(mSpyWorker).unbindFromExampleStoreService();
+    }
+
+    private void setUpHttpFederatedProtocol(CheckinResult checkinResult) {
+        doReturn(immediateFuture(checkinResult)).when(mMockHttpFederatedProtocol).issueCheckin();
+        doReturn(FluentFuture.from(immediateFuture(null)))
+                .when(mMockHttpFederatedProtocol)
+                .reportResult(any());
+    }
+
+    private static class TestExampleStoreService extends IExampleStoreService.Stub {
+        @Override
+        public void startQuery(Bundle params, IExampleStoreCallback callback)
+                throws RemoteException {
+            callback.onStartQuerySuccess(
+                    new FakeExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1.toByteArray())));
+        }
+    }
+
+    private static class TestInjector extends FederatedComputeWorker.Injector {
+        @Override
+        ExampleConsumptionRecorder getExampleConsumptionRecorder() {
+            return new ExampleConsumptionRecorder() {
+                @Override
+                public synchronized ArrayList<ExampleConsumption> finishRecordingAndGet() {
+                    ArrayList<ExampleConsumption> exampleList = new ArrayList<>();
+                    exampleList.add(EXAMPLE_CONSUMPTION_1);
+                    return exampleList;
+                }
+            };
+        }
+    }
+
+    private static final class FakeIsolatedTrainingService extends IIsolatedTrainingService.Stub {
+        @Override
+        public void runFlTraining(Bundle params, ITrainingResultCallback callback)
+                throws RemoteException {
+            Bundle bundle = new Bundle();
+            bundle.putByteArray(
+                    Constants.EXTRA_FL_RUNNER_RESULT, FL_RUNNER_SUCCESS_RESULT.toByteArray());
+            ArrayList<ExampleConsumption> exampleConsumptionList = new ArrayList<>();
+            exampleConsumptionList.add(EXAMPLE_CONSUMPTION_1);
+            bundle.putParcelableArrayList(
+                    ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, exampleConsumptionList);
+            callback.onResult(bundle);
+        }
+
+        @Override
+        public void cancelTraining(long runId) {}
+    }
 }
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java
index 3210068..b9e6086 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java
@@ -16,13 +16,16 @@
 
 package com.android.federatedcompute.services.training;
 
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -33,20 +36,39 @@
 import com.android.federatedcompute.services.common.FederatedComputeExecutors;
 import com.android.federatedcompute.services.common.PhFlagsTestUtil;
 
+import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.MoreExecutors;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
+import com.google.intelligence.fcp.client.RetryInfo;
+import com.google.intelligence.fcp.client.engine.TaskRetry;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.Mock;
 import org.mockito.MockitoSession;
 import org.mockito.quality.Strictness;
 
 @RunWith(JUnit4.class)
 public final class FederatedJobServiceTest {
-    private static final long WAIT_IN_MILLIS = 1_000L;
+    private static final TaskRetry TASK_RETRY =
+            TaskRetry.newBuilder().setRetryToken("foobar").build();
+    private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
+            FLRunnerResult.newBuilder()
+                    .setContributionResult(ContributionResult.SUCCESS)
+                    .setRetryInfo(
+                            RetryInfo.newBuilder()
+                                    .setRetryToken(TASK_RETRY.getRetryToken())
+                                    .build())
+                    .build();
 
     private FederatedJobService mSpyService;
+    @Mock private FederatedComputeWorker mMockWorker;
+
+    private MockitoSession mStaticMockSession;
 
     @Before
     public void setUp() throws Exception {
@@ -56,60 +78,54 @@
         mSpyService = spy(new FederatedJobService());
         doNothing().when(mSpyService).jobFinished(any(), anyBoolean());
         doReturn(mSpyService).when(mSpyService).getApplicationContext();
+        mStaticMockSession =
+                ExtendedMockito.mockitoSession()
+                        .spyStatic(FederatedComputeExecutors.class)
+                        .spyStatic(FederatedComputeWorker.class)
+                        .initMocks(this)
+                        .strictness(Strictness.LENIENT)
+                        .startMocking();
+        ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService())
+                .when(() -> FederatedComputeExecutors.getBackgroundExecutor());
+        ExtendedMockito.doReturn(mMockWorker).when(() -> FederatedComputeWorker.getInstance(any()));
+    }
+
+    @After
+    public void teardown() {
+        if (mStaticMockSession != null) {
+            mStaticMockSession.finishMocking();
+        }
     }
 
     @Test
     public void testOnStartJob() throws Exception {
-        MockitoSession session =
-                ExtendedMockito.mockitoSession()
-                        .spyStatic(FederatedComputeExecutors.class)
-                        .strictness(Strictness.LENIENT)
-                        .startMocking();
-        try {
-            ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService())
-                    .when(FederatedComputeExecutors::getBackgroundExecutor);
+        doReturn(Futures.immediateFuture(FL_RUNNER_SUCCESS_RESULT))
+                .when(mMockWorker)
+                .startTrainingRun(anyInt());
+        doNothing().when(mMockWorker).finish(eq(FL_RUNNER_SUCCESS_RESULT));
 
-            boolean result = mSpyService.onStartJob(mock(JobParameters.class));
+        boolean result = mSpyService.onStartJob(mock(JobParameters.class));
 
-            assertTrue(result);
-            Thread.sleep(WAIT_IN_MILLIS);
-
-            verify(mSpyService, times(1)).jobFinished(any(), anyBoolean());
-        } finally {
-            session.finishMocking();
-        }
+        assertTrue(result);
+        verify(mSpyService, times(1)).jobFinished(any(), anyBoolean());
     }
 
     @Test
     public void testOnStartJobKillSwitch() throws Exception {
         PhFlagsTestUtil.enableGlobalKillSwitch();
-        MockitoSession session =
-                ExtendedMockito.mockitoSession()
-                        .spyStatic(FederatedComputeExecutors.class)
-                        .strictness(Strictness.LENIENT)
-                        .startMocking();
-        try {
-            ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService())
-                    .when(FederatedComputeExecutors::getBackgroundExecutor);
 
-            boolean result = mSpyService.onStartJob(mock(JobParameters.class));
+        boolean result = mSpyService.onStartJob(mock(JobParameters.class));
 
-            assertTrue(result);
-
-            verify(mSpyService, times(1)).jobFinished(any(), eq(false));
-        } finally {
-            session.finishMocking();
-        }
+        assertTrue(result);
+        verify(mMockWorker, never()).startTrainingRun(anyInt());
+        verify(mSpyService, times(1)).jobFinished(any(), eq(false));
     }
 
     @Test
     public void testOnStopJob() {
-        MockitoSession session =
-                ExtendedMockito.mockitoSession().strictness(Strictness.LENIENT).startMocking();
-        try {
-            assertTrue(mSpyService.onStopJob(mock(JobParameters.class)));
-        } finally {
-            session.finishMocking();
-        }
+        doNothing().when(mMockWorker).finish(any(), eq(ContributionResult.FAIL), eq(true));
+
+        // Do not reschedule in JobService. FederatedComputeJobManager will handle it.
+        assertFalse(mSpyService.onStopJob(mock(JobParameters.class)));
     }
 }
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java
index a72eb85..57e14b6 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java
@@ -24,12 +24,10 @@
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.when;
 
-import android.content.Context;
+import android.federatedcompute.common.ClientConstants;
 import android.os.Bundle;
 import android.os.ParcelFileDescriptor;
 
-import androidx.test.core.app.ApplicationProvider;
-
 import com.android.dx.mockito.inline.extended.ExtendedMockito;
 import com.android.federatedcompute.services.common.Constants;
 import com.android.federatedcompute.services.common.FederatedComputeExecutors;
@@ -62,18 +60,11 @@
 @RunWith(JUnit4.class)
 public final class IsolatedTrainingServiceImplTest {
     private static final String POPULATION_NAME = "population_name";
-    private static final String CLIENT_PACKAGE_NAME = "de.myselph";
     private static final long RUN_ID = 12345L;
-    private static final String SESSION_NAME = "session_name";
-    private static final String TASK_NAME = "task_name";
-    private static final String INPUT_CHECKPOINT_FD = "fd:///5";
-    private static final String OUTPUT_CHECKPOINT_FD = "fd:///6";
     private static final FakeExampleStoreIterator FAKE_EXAMPLE_STORE_ITERATOR =
             new FakeExampleStoreIterator(ImmutableList.of());
     private static final ExampleSelector EXAMPLE_SELECTOR =
             ExampleSelector.newBuilder().setCollectionUri("collection_uri").build();
-
-    private final CountDownLatch mLatch = new CountDownLatch(1);
     private static final TaskRetry TASK_RETRY =
             TaskRetry.newBuilder().setRetryToken("foobar").build();
     private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
@@ -84,7 +75,6 @@
                                     .setRetryToken(TASK_RETRY.getRetryToken())
                                     .build())
                     .build();
-
     private static final FLRunnerResult FL_RUNNER_FAIL_RESULT =
             FLRunnerResult.newBuilder()
                     .setContributionResult(ContributionResult.FAIL)
@@ -93,11 +83,9 @@
                                     .setRetryToken(TASK_RETRY.getRetryToken())
                                     .build())
                     .build();
-
-    private final Context mContext = ApplicationProvider.getApplicationContext();
+    private final CountDownLatch mLatch = new CountDownLatch(1);
     private IsolatedTrainingServiceImpl mIsolatedTrainingService;
     private Bundle mCallbackResult;
-    private int mCallbackErrorCode;
     @Mock private ComputationRunner mComputationRunner;
     private MockitoSession mStaticMockSession;
     private ParcelFileDescriptor mInputCheckpointFd;
@@ -159,7 +147,7 @@
     @Test
     public void runFlTrainingMissingExampleSelector_returnsFailure() throws Exception {
         Bundle bundle = new Bundle();
-        bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+        bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
         bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
         bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
         bundle.putBinder(
@@ -173,7 +161,7 @@
     @Test
     public void runFlTrainingInvalidExampleSelector_returnsFailure() throws Exception {
         Bundle bundle = new Bundle();
-        bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+        bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
         bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
         bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
         bundle.putBinder(
@@ -189,7 +177,7 @@
     @Test
     public void runFlTrainingNullPlan_returnsFailure() throws Exception {
         Bundle bundle = new Bundle();
-        bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+        bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
         bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
         bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
         bundle.putBinder(
@@ -211,13 +199,12 @@
 
     private Bundle buildInputBundle() throws Exception {
         Bundle bundle = new Bundle();
-        bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME);
+        bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME);
         bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd);
         bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd);
         bundle.putByteArray(Constants.EXTRA_EXAMPLE_SELECTOR, EXAMPLE_SELECTOR.toByteArray());
         bundle.putBinder(
                 Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, FAKE_EXAMPLE_STORE_ITERATOR);
-
         ClientOnlyPlan clientOnlyPlan = TrainingTestUtil.createFederatedAnalyticClientPlan();
         bundle.putByteArray(Constants.EXTRA_CLIENT_ONLY_PLAN, clientOnlyPlan.toByteArray());
         return bundle;
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java
index 9a58277..725cebf 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java
@@ -41,8 +41,11 @@
 import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
 import com.android.federatedcompute.services.training.ResultCallbackHelper.CallbackResult;
+import com.android.federatedcompute.services.training.util.ComputationResult;
 
 import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.intelligence.fcp.client.FLRunnerResult;
+import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
 
 import org.junit.Before;
 import org.junit.Test;
@@ -51,7 +54,6 @@
 import org.mockito.Mockito;
 
 import java.util.ArrayList;
-import java.util.concurrent.CountDownLatch;
 
 @RunWith(JUnit4.class)
 public final class ResultCallbackHelperTest {
@@ -60,10 +62,10 @@
     private static final int SCHEDULING_REASON = SchedulingReason.SCHEDULING_REASON_NEW_TASK;
     private static final String POPULATION_NAME = "population_name";
     private static final String TASK_NAME = "task_name";
-    private static final int JOB_ID = 123;
     private static final byte[] INTERVAL_OPTIONS = createDefaultTrainingIntervalOptions();
     private static final byte[] TRAINING_CONSTRAINTS = createDefaultTrainingConstraints();
-    private final CountDownLatch mLatch = new CountDownLatch(1);
+    private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
+            FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build();
 
     private static final FederatedTrainingTask TRAINING_TASK =
             FederatedTrainingTask.builder()
@@ -80,11 +82,14 @@
                     .build();
     private static final ArrayList<ExampleConsumption> EXAMPLE_CONSUMPTIONS =
             getExampleConsumptions();
+    private ComputationResult mComputationResult;
     private ResultCallbackHelper mHelper;
 
     @Before
     public void setUp() {
         Context context = ApplicationProvider.getApplicationContext();
+        mComputationResult =
+                new ComputationResult("output", FL_RUNNER_SUCCESS_RESULT, EXAMPLE_CONSUMPTIONS);
         mHelper = Mockito.spy(new ResultCallbackHelper(context));
         doNothing().when(mHelper).unbindFromResultHandlingService();
     }
@@ -96,8 +101,7 @@
                 .getResultHandlingService(eq(PACKAGE_NAME));
 
         CallbackResult result =
-                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
-                        .get();
+                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
 
         assertThat(result).isEqualTo(CallbackResult.SUCCESS);
     }
@@ -109,8 +113,7 @@
                 .getResultHandlingService(eq(PACKAGE_NAME));
 
         CallbackResult result =
-                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
-                        .get();
+                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
 
         assertThat(result).isEqualTo(CallbackResult.FAIL);
     }
@@ -122,8 +125,7 @@
                 .getResultHandlingService(eq(PACKAGE_NAME));
 
         CallbackResult result =
-                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
-                        .get();
+                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
 
         assertThat(result).isEqualTo(CallbackResult.FAIL);
     }
@@ -135,8 +137,7 @@
                 .getResultHandlingService(eq(PACKAGE_NAME));
 
         CallbackResult result =
-                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, EXAMPLE_CONSUMPTIONS, true)
-                        .get();
+                mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get();
 
         assertThat(result).isEqualTo(CallbackResult.FAIL);
     }