| /* |
| * Copyright (C) 2023 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package com.android.federatedcompute.services.scheduling; |
| |
| import static com.google.common.truth.Truth.assertThat; |
| import static com.google.common.truth.Truth.assertWithMessage; |
| |
| import static org.mockito.ArgumentMatchers.any; |
| import static org.mockito.ArgumentMatchers.anyString; |
| import static org.mockito.Mockito.when; |
| |
| import static java.lang.Math.min; |
| |
| import android.app.job.JobInfo; |
| import android.app.job.JobScheduler; |
| import android.content.ComponentName; |
| import android.content.Context; |
| import android.federatedcompute.aidl.IFederatedComputeCallback; |
| import android.federatedcompute.common.TrainingInterval; |
| import android.federatedcompute.common.TrainingOptions; |
| |
| import androidx.test.core.app.ApplicationProvider; |
| |
| import com.android.federatedcompute.services.common.Clock; |
| import com.android.federatedcompute.services.common.Flags; |
| import com.android.federatedcompute.services.data.FederatedTrainingTask; |
| import com.android.federatedcompute.services.data.FederatedTrainingTaskDao; |
| import com.android.federatedcompute.services.data.FederatedTrainingTaskDbHelper; |
| import com.android.federatedcompute.services.data.fbs.SchedulingMode; |
| import com.android.federatedcompute.services.data.fbs.SchedulingReason; |
| import com.android.federatedcompute.services.data.fbs.TrainingConstraints; |
| import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; |
| |
| import com.google.flatbuffers.FlatBufferBuilder; |
| import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; |
| import com.google.intelligence.fcp.client.engine.TaskRetry; |
| |
| import org.junit.After; |
| import org.junit.Before; |
| import org.junit.Test; |
| import org.junit.runner.RunWith; |
| import org.mockito.Mock; |
| import org.mockito.junit.MockitoJUnitRunner; |
| |
| import java.nio.ByteBuffer; |
| import java.util.List; |
| import java.util.concurrent.CountDownLatch; |
| |
| import javax.annotation.Nullable; |
| |
| @RunWith(MockitoJUnitRunner.class) |
| public final class FederatedComputeJobManagerTest { |
| private static final String CALLING_PACKAGE_NAME = "callingPkg"; |
| private static final String POPULATION_NAME1 = "population1"; |
| private static final String POPULATION_NAME2 = "population2"; |
| private static final String SERVER_ADDRESS = "https://server.uri/"; |
| private static final int JOB_ID1 = 700000001; |
| private static final int JOB_ID2 = 700000002; |
| private static final long DEFAULT_SCHEDULING_PERIOD_SECS = 1234; |
| private static final long DEFAULT_SCHEDULING_PERIOD_MILLIS = |
| DEFAULT_SCHEDULING_PERIOD_SECS * 1000; |
| private static final long MAX_SCHEDULING_PERIOD_SECS = 912000; |
| private static final long MAX_SCHEDULING_INTERVAL_SECS_FOR_FEDERATED_COMPUTATION = 604800L; |
| |
| private static final String TRAINING_JOB_SERVICE = |
| "com.android.federatedcompute.services.training.FederatedJobService"; |
| private static final long CURRENT_TIME_MILLIS = 1000L; |
| private static final byte[] DEFAULT_CONSTRAINTS = createDefaultTrainingConstraints(); |
| private static final TrainingOptions OPTIONS1 = |
| new TrainingOptions.Builder() |
| .setPopulationName(POPULATION_NAME1) |
| .setServerAddress(SERVER_ADDRESS) |
| .build(); |
| private static final TrainingOptions OPTIONS2 = |
| new TrainingOptions.Builder() |
| .setPopulationName(POPULATION_NAME2) |
| .setServerAddress(SERVER_ADDRESS) |
| .build(); |
| private static final TaskRetry TASK_RETRY = |
| TaskRetry.newBuilder().setDelayMin(5000000).setDelayMax(6000000).build(); |
| private final CountDownLatch mLatch = new CountDownLatch(1); |
| private FederatedComputeJobManager mJobManager; |
| private Context mContext; |
| private FederatedTrainingTaskDao mTrainingTaskDao; |
| private boolean mSuccess = false; |
| @Mock private Clock mClock; |
| @Mock private Flags mMockFlags; |
| @Mock private FederatedJobIdGenerator mMockJobIdGenerator; |
| private JobScheduler mJobScheduler; |
| |
| @Before |
| public void setUp() { |
| mContext = ApplicationProvider.getApplicationContext(); |
| mJobScheduler = mContext.getSystemService(JobScheduler.class); |
| mJobScheduler.cancelAll(); |
| mTrainingTaskDao = FederatedTrainingTaskDao.getInstanceForTest(mContext); |
| mJobManager = |
| new FederatedComputeJobManager( |
| mContext, |
| mTrainingTaskDao, |
| mMockJobIdGenerator, |
| new JobSchedulerHelper(mClock), |
| mClock, |
| mMockFlags); |
| when(mClock.currentTimeMillis()).thenReturn(CURRENT_TIME_MILLIS); |
| when(mMockFlags.getDefaultSchedulingPeriodSecs()) |
| .thenReturn(DEFAULT_SCHEDULING_PERIOD_SECS); |
| when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation()) |
| .thenReturn(MAX_SCHEDULING_INTERVAL_SECS_FOR_FEDERATED_COMPUTATION); |
| when(mMockFlags.getMinSchedulingIntervalSecsForFederatedComputation()).thenReturn(1L); |
| when(mMockFlags.getMaxSchedulingPeriodSecs()).thenReturn(MAX_SCHEDULING_PERIOD_SECS); |
| when(mMockJobIdGenerator.generateJobId(any(), anyString())) |
| .thenReturn(JOB_ID1) |
| .thenReturn(JOB_ID2); |
| } |
| |
| @After |
| public void tearDown() { |
| // Manually clean up the database. |
| mTrainingTaskDao.clearDatabase(); |
| FederatedTrainingTaskDbHelper dbHelper = |
| FederatedTrainingTaskDbHelper.getInstanceForTest(mContext); |
| dbHelper.getWritableDatabase().close(); |
| dbHelper.getReadableDatabase().close(); |
| dbHelper.close(); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalledSuccess() throws Exception { |
| when(mClock.currentTimeMillis()).thenReturn(1000L).thenReturn(2000L); |
| |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| assertThat(mSuccess).isTrue(); |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| assertThat(taskList) |
| .containsExactly( |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) |
| .creationTime(1000L) |
| .lastScheduledTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .intervalOptions(createDefaultTrainingInterval()) |
| .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) |
| .build()); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalled_firstTime() throws Exception { |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| // Make three onTrainerStart calls, each with different job ID and session name. |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| when(mClock.currentTimeMillis()).thenReturn(2000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS2, new TestFederatedComputeCallback()); |
| mLatch.await(); |
| |
| assertThat(mSuccess).isTrue(); |
| // verify training tasks in database. |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| assertThat(taskList) |
| .containsExactly( |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) |
| .creationTime(1000L) |
| .lastScheduledTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) |
| .intervalOptions(createDefaultTrainingInterval()) |
| .build(), |
| basicFLTrainingTaskBuilder(JOB_ID2, POPULATION_NAME2, null) |
| .creationTime(2000L) |
| .lastScheduledTime(2000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .earliestNextRunTime(2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) |
| .intervalOptions(createDefaultTrainingInterval()) |
| .build()); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(2); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS)); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID2), |
| buildExpectedJobInfo(JOB_ID2, DEFAULT_SCHEDULING_PERIOD_MILLIS)); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalledFL_withIntervalSmallerThanDefaultInterval() |
| throws Exception { |
| testOnTrainerStartCalledFLWithInterval( |
| /* userDefinedIntervalMillis= */ 1000000, /* defaultIntervalMillis= */ 2000000); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalledFL_withIntervalLargerThanDefaultInterval() |
| throws Exception { |
| testOnTrainerStartCalledFLWithInterval( |
| /* userDefinedIntervalMillis= */ 2000000, /* defaultIntervalMillis= */ 1000000); |
| } |
| |
| private void testOnTrainerStartCalledFLWithInterval( |
| long userDefinedIntervalMillis, long defaultIntervalMillis) throws Exception { |
| when(mMockFlags.getDefaultSchedulingPeriodSecs()).thenReturn(defaultIntervalMillis / 1000); |
| TrainingOptions trainerOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(userDefinedIntervalMillis) |
| .build()) |
| .build(); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); |
| |
| byte[] trainingIntervalOptions = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); |
| |
| long expectedInterval = min(userDefinedIntervalMillis, defaultIntervalMillis); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, trainingIntervalOptions) |
| .earliestNextRunTime(CURRENT_TIME_MILLIS + expectedInterval) |
| .lastScheduledTime(CURRENT_TIME_MILLIS) |
| .creationTime(CURRENT_TIME_MILLIS) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .build(); |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| |
| assertThat(taskList).containsExactly(expectedTask); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, expectedInterval)); |
| } |
| |
| /** |
| * Tests onTrainerStart being called multiple times with the same parameters (the common |
| * expected use case). |
| * |
| * <p>After the first call, most fields in the task (like creation time, earliest next run time, |
| * etc.) must be preserved, and only certain fields (like last scheduled time) should be |
| * updated. |
| */ |
| @Test |
| public void testOnTrainerStartCalled_multipleTimes_sameParams() throws Exception { |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| when(mClock.currentTimeMillis()).thenReturn(2000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| when(mClock.currentTimeMillis()).thenReturn(3000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) |
| .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) |
| .lastScheduledTime(3000L) |
| .creationTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .intervalOptions(createDefaultTrainingInterval()) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS)); |
| } |
| |
| /** |
| * Tests when the user specified interval is larger than the maximum server specified interval, |
| * multiple scheduling with same user specified interval will not be incorrectly capped at the |
| * maximum server specified interval. |
| */ |
| @Test |
| public void testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalLargerThanServerMax() |
| throws Exception { |
| long minIntervalMills = 10000L; // 10 seconds |
| // Maximum server specified interval is 5 seconds |
| when(mMockFlags.getMaxSchedulingPeriodSecs()).thenReturn(5L); |
| TrainingOptions trainingOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(minIntervalMills) |
| .build()) |
| .build(); |
| |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); |
| |
| when(mClock.currentTimeMillis()).thenReturn(2000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); |
| |
| when(mClock.currentTimeMillis()).thenReturn(3000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| byte[] expectedInterval = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, minIntervalMills); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) |
| .earliestNextRunTime(1000 + minIntervalMills) |
| .lastScheduledTime(3000L) |
| .creationTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, minIntervalMills)); |
| } |
| |
| /** |
| * Tests when a task got scheduled with the same set of parameters multiple times, the brella |
| * defined max for user specified interval has been lowered between the multiple scheduling |
| * events, the user specified interval should be always guarded with the latest max. |
| */ |
| @Test |
| public void testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalDifferentMax() |
| throws Exception { |
| // Initial max 20 seconds is larger than the user specified interval. |
| when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation()).thenReturn(20L); |
| long minIntervalMills = 10000L; // 10 seconds |
| TrainingOptions trainingOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(minIntervalMills) |
| .build()) |
| .build(); |
| |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| byte[] expectedInterval = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, minIntervalMills); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) |
| .earliestNextRunTime(1000L + minIntervalMills) |
| .lastScheduledTime(1000L) |
| .creationTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, minIntervalMills)); |
| |
| // Now lower allowed max for the user specified interval |
| long newMaxSec = 5L; |
| long newMinIntervalMills = newMaxSec * 1000; |
| when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation()) |
| .thenReturn(newMaxSec); |
| TrainingOptions newTrainingOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(newMinIntervalMills) |
| .build()) |
| .build(); |
| |
| when(mClock.currentTimeMillis()).thenReturn(2000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, newTrainingOptions, new TestFederatedComputeCallback()); |
| |
| taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| expectedInterval = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, newMinIntervalMills); |
| expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) |
| .earliestNextRunTime(2000L + newMinIntervalMills) |
| .lastScheduledTime(2000L) |
| .creationTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, newMinIntervalMills)); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalled_fLCustomerSpecifiedIntervalSmallerThanDefinedMin() |
| throws Exception { |
| when(mMockFlags.getDefaultSchedulingPeriodSecs()).thenReturn(2000L); |
| long minTrainingIntervalSecByFederatedCompute = 1800L; |
| long minTrainingIntervalMillsByFederatedCompute = |
| minTrainingIntervalSecByFederatedCompute * 1000; |
| when(mMockFlags.getMinSchedulingIntervalSecsForFederatedComputation()) |
| .thenReturn(minTrainingIntervalSecByFederatedCompute); |
| |
| TrainingOptions trainingOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(1000L) |
| .build()) |
| .build(); |
| |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| byte[] expectedInterval = createTrainingIntervalOptions(SchedulingMode.RECURRENT, 1000L); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) |
| .earliestNextRunTime(1000L + minTrainingIntervalMillsByFederatedCompute) |
| .lastScheduledTime(1000L) |
| .creationTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, minTrainingIntervalMillsByFederatedCompute)); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalled_trainingIntervalChange_FL() throws Exception { |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build(), |
| new TestFederatedComputeCallback()); |
| |
| long minTrainingIntervalMillis = 60000L; |
| when(mClock.currentTimeMillis()).thenReturn(2000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(minTrainingIntervalMillis) |
| .build()) |
| .build(), |
| new TestFederatedComputeCallback()); |
| byte[] trainingInterval = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, minTrainingIntervalMillis); |
| verifyTaskAndJobAfterIntervalChange( |
| trainingInterval, 1000, 2000, minTrainingIntervalMillis); |
| |
| long newInterval = 70000L; |
| when(mClock.currentTimeMillis()).thenReturn(3000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(newInterval) |
| .build()) |
| .build(), |
| new TestFederatedComputeCallback()); |
| byte[] trainingIntervalOption2 = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, newInterval); |
| // Verify the creation time not changed, modified time is set to now, and the min interval |
| // is set to the new interval. |
| verifyTaskAndJobAfterIntervalChange(trainingIntervalOption2, 1000, 3000, newInterval); |
| |
| // Change to default training interval {one_time, interval 0}. |
| when(mClock.currentTimeMillis()).thenReturn(4000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_ONE_TIME) |
| .build()) |
| .build(), |
| new TestFederatedComputeCallback()); |
| byte[] trainingIntervalOption3 = createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0L); |
| // Verify the creation time not changed, modified time is set to now, and the min interval |
| // is set to the new interval. |
| verifyTaskAndJobAfterIntervalChange( |
| trainingIntervalOption3, 1000, 4000, DEFAULT_SCHEDULING_PERIOD_MILLIS); |
| } |
| |
| private void verifyTaskAndJobAfterIntervalChange( |
| @Nullable byte[] trainingIntervalOptions, |
| long createTimeMillis, |
| long modifyTimeMillis, |
| long expectedIntervalMillis) |
| throws Exception { |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, trainingIntervalOptions) |
| .earliestNextRunTime(modifyTimeMillis + expectedIntervalMillis) |
| .lastScheduledTime(modifyTimeMillis) |
| .creationTime(createTimeMillis) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, expectedIntervalMillis)); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalled_multipleTimes_changingPopulationName() throws Exception { |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| TrainingOptions options1 = |
| new TrainingOptions.Builder() |
| .setPopulationName(POPULATION_NAME1) |
| .setServerAddress(SERVER_ADDRESS) |
| .build(); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, options1, new TestFederatedComputeCallback()); |
| |
| // Pass in a new population name and We will assign new job id since population name |
| // changes. |
| when(mClock.currentTimeMillis()).thenReturn(2000L); |
| TrainingOptions options2 = |
| new TrainingOptions.Builder() |
| .setPopulationName(POPULATION_NAME2) |
| .setServerAddress(SERVER_ADDRESS) |
| .build(); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, options2, new TestFederatedComputeCallback()); |
| |
| // Verify two training tasks in database. |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| assertThat(taskList) |
| .containsExactly( |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) |
| .creationTime(1000L) |
| .lastScheduledTime(1000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) |
| .intervalOptions(createDefaultTrainingInterval()) |
| .build(), |
| basicFLTrainingTaskBuilder(JOB_ID2, POPULATION_NAME2, null) |
| .creationTime(2000L) |
| .lastScheduledTime(2000L) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .earliestNextRunTime(2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) |
| .intervalOptions(createDefaultTrainingInterval()) |
| .build()); |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(2); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS)); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID2), |
| buildExpectedJobInfo(JOB_ID2, DEFAULT_SCHEDULING_PERIOD_MILLIS)); |
| } |
| |
| @Test |
| public void testOnTrainerStartCalled_twoJobsWithSamePopulationName() throws Exception { |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| TrainingOptions options1 = |
| new TrainingOptions.Builder() |
| .setPopulationName(POPULATION_NAME1) |
| .setServerAddress(SERVER_ADDRESS) |
| .build(); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, options1, new TestFederatedComputeCallback()); |
| |
| // For same population, we will reuse the job id assigned to the previous task. |
| when(mClock.currentTimeMillis()).thenReturn(2000L); |
| TrainingOptions options2 = |
| new TrainingOptions.Builder() |
| .setPopulationName(POPULATION_NAME1) |
| .setServerAddress(SERVER_ADDRESS) |
| .build(); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, options2, new TestFederatedComputeCallback()); |
| |
| // Verify only task in database. |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| long earliestNextRunTimeMillis = 1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS; |
| |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) |
| .earliestNextRunTime(earliestNextRunTimeMillis) |
| .lastScheduledTime(2000L) |
| .creationTime(1000L) |
| .constraints(DEFAULT_CONSTRAINTS) |
| .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) |
| .intervalOptions(createDefaultTrainingInterval()) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS)); |
| } |
| |
| @Test |
| public void testOnTrainingStarted_doesNotExist() throws Exception { |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| FederatedTrainingTask taskToRun = mJobManager.onTrainingStarted(JOB_ID1); |
| |
| // No task should be found. |
| assertThat(taskToRun).isNull(); |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| assertThat(taskList).isEmpty(); |
| } |
| |
| @Test |
| public void testOnTrainingStarted_taskTtling_noTtlSet() throws Exception { |
| // Set task TTL to 0, which should disable TTLing. |
| when(mMockFlags.getTrainingTimeForLiveSeconds()).thenReturn(0L); |
| |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| // Simulate attempting to run a task a lot later. This should not fail, b/c we're not yet |
| // past the TTL threshold. |
| assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); |
| assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); |
| } |
| |
| @Test |
| public void testOnTrainingStarted_taskTtling() throws Exception { |
| // Set task TTL to 1 second. |
| when(mMockFlags.getTrainingTimeForLiveSeconds()).thenReturn(1L); |
| |
| when(mClock.currentTimeMillis()).thenReturn(1000L); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| // Simulate attempting to run a task one second later. This should not fail, b/c we're not |
| // yet |
| // past the TTL threshold. |
| long nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); |
| assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); |
| |
| // Now reschedule again, should keep the task alive for another second. |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| // The task should again still be alive a second later. |
| nowMillis = 3000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); |
| |
| // Now move forward one millisecond. The task should now get TTLd. |
| nowMillis = 3001; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull(); |
| assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); |
| } |
| |
| @Test |
| public void testRescheduleFLTask_success() throws Exception { |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingStarted(JOB_ID1); |
| |
| nowMillis = 3000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingCompleted( |
| JOB_ID1, |
| POPULATION_NAME1, |
| createTrainingIntervalOptionsAsRoot(SchedulingMode.RECURRENT, 0), |
| TASK_RETRY, |
| ContributionResult.SUCCESS); |
| |
| assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); |
| assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); |
| } |
| |
| @Test |
| public void testRescheduleFLTask_oneoff_success() throws Exception { |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingStarted(JOB_ID1); |
| |
| nowMillis = 3000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingCompleted( |
| JOB_ID1, |
| POPULATION_NAME1, |
| createTrainingIntervalOptionsAsRoot(SchedulingMode.ONE_TIME, 0), |
| TASK_RETRY, |
| ContributionResult.SUCCESS); |
| |
| assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull(); |
| assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); |
| } |
| |
| @Test |
| public void testRescheduleFLTask_didnotContribute_oneOff() throws Exception { |
| long serverRetryDelayMillis = 5000_000; |
| |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| TrainingOptions trainerOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_ONE_TIME) |
| .build()) |
| .build(); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); |
| |
| nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingStarted(JOB_ID1); |
| |
| nowMillis = 3000; |
| byte[] intervalOptions = createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0); |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingCompleted( |
| JOB_ID1, |
| POPULATION_NAME1, |
| TrainingIntervalOptions.getRootAsTrainingIntervalOptions( |
| ByteBuffer.wrap(intervalOptions)), |
| TaskRetry.newBuilder() |
| .setDelayMin(serverRetryDelayMillis) |
| .setDelayMax(serverRetryDelayMillis) |
| .build(), |
| ContributionResult.FAIL); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions) |
| .creationTime(1000L) |
| .lastScheduledTime(1000L) |
| .lastRunStartTime(2000L) |
| .lastRunEndTime(3000L) |
| .schedulingReason( |
| SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) |
| .earliestNextRunTime(3000 + serverRetryDelayMillis) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, serverRetryDelayMillis)); |
| } |
| |
| /** Reschedule a recurrent fl task with the user defined interval. */ |
| @Test |
| public void testRescheduleFLTask_success_recurrent_userDefinedInterval() throws Exception { |
| // The user defined interval is larger than the server specified interval. |
| long minRetryDelayMillis = 3000_000; |
| long maxRetryDelayMillis = 3000_000; |
| long userDefinedIntervalMillis = 4000_000; |
| TrainingOptions trainerOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(userDefinedIntervalMillis) |
| .build()) |
| .build(); |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); |
| |
| nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingStarted(JOB_ID1); |
| |
| nowMillis = 3000; |
| byte[] intervalOptions = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingCompleted( |
| JOB_ID1, |
| POPULATION_NAME1, |
| TrainingIntervalOptions.getRootAsTrainingIntervalOptions( |
| ByteBuffer.wrap(intervalOptions)), |
| TaskRetry.newBuilder() |
| .setDelayMin(minRetryDelayMillis) |
| .setDelayMax(maxRetryDelayMillis) |
| .build(), |
| ContributionResult.SUCCESS); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions) |
| .creationTime(1000L) |
| .lastScheduledTime(1000L) |
| .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted() |
| .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted() |
| .schedulingReason( |
| SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) |
| .earliestNextRunTime(3000 + userDefinedIntervalMillis) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, userDefinedIntervalMillis)); |
| } |
| |
| @Test |
| public void testRescheduleFLTask_recurrent_serverDefinedInterval() throws Exception { |
| // Define a server returned interval which is larger than the user defined interval |
| long serverDefinedIntervalMillis = 4000_000; |
| long userDefinedIntervalMillis = 3000_000; |
| |
| TrainingOptions trainerOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(userDefinedIntervalMillis) |
| .build()) |
| .build(); |
| |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); |
| |
| nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingStarted(JOB_ID1); |
| |
| nowMillis = 3000; |
| byte[] intervalOptions = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingCompleted( |
| JOB_ID1, |
| POPULATION_NAME1, |
| TrainingIntervalOptions.getRootAsTrainingIntervalOptions( |
| ByteBuffer.wrap(intervalOptions)), |
| TaskRetry.newBuilder() |
| .setDelayMin(serverDefinedIntervalMillis) |
| .setDelayMax(serverDefinedIntervalMillis) |
| .build(), |
| ContributionResult.SUCCESS); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions) |
| .creationTime(1000L) |
| .lastScheduledTime(1000L) |
| .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted() |
| .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted() |
| .schedulingReason( |
| SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) |
| .earliestNextRunTime(3000 + serverDefinedIntervalMillis) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, serverDefinedIntervalMillis)); |
| } |
| |
| @Test |
| public void testRescheduleFLTask_recurrent_didnotContribute() throws Exception { |
| // Define a server returned interval which is larger than the user defined interval |
| long serverDefinedIntervalMillis = 4000_000; |
| long userDefinedIntervalMillis = 3000_000; |
| |
| TrainingOptions trainerOptions = |
| basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) |
| .setTrainingInterval( |
| new TrainingInterval.Builder() |
| .setSchedulingMode( |
| TrainingInterval.SCHEDULING_MODE_RECURRENT) |
| .setMinimumIntervalMillis(userDefinedIntervalMillis) |
| .build()) |
| .build(); |
| |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); |
| |
| nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingStarted(JOB_ID1); |
| |
| nowMillis = 3000; |
| byte[] intervalOptions = |
| createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainingCompleted( |
| JOB_ID1, |
| POPULATION_NAME1, |
| TrainingIntervalOptions.getRootAsTrainingIntervalOptions( |
| ByteBuffer.wrap(intervalOptions)), |
| TaskRetry.newBuilder() |
| .setDelayMin(serverDefinedIntervalMillis) |
| .setDelayMax(serverDefinedIntervalMillis) |
| .build(), |
| ContributionResult.FAIL); |
| |
| List<FederatedTrainingTask> taskList = |
| mTrainingTaskDao.getFederatedTrainingTask(null, null); |
| FederatedTrainingTask expectedTask = |
| basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions) |
| .creationTime(1000L) |
| .lastScheduledTime(1000L) |
| .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted() |
| .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted() |
| .schedulingReason( |
| SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) |
| .earliestNextRunTime(3000 + serverDefinedIntervalMillis) |
| .build(); |
| assertThat(taskList).containsExactly(expectedTask); |
| |
| assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); |
| assertJobInfosMatch( |
| mJobScheduler.getPendingJob(JOB_ID1), |
| buildExpectedJobInfo(JOB_ID1, serverDefinedIntervalMillis)); |
| } |
| |
| @Test |
| public void testOnTrainerStopCalled_withoutOnTrainerStartCalled() throws Exception { |
| // Should not fail, even if onTrainerStartCalled was never called. |
| mJobManager.onTrainerStopCalled( |
| CALLING_PACKAGE_NAME, POPULATION_NAME1, new TestFederatedComputeCallback()); |
| |
| // No task should exist, nor should a job have been scheduled. |
| assertThat(mSuccess).isTrue(); |
| assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); |
| assertThat(mJobScheduler.getAllPendingJobs()).isEmpty(); |
| } |
| |
| @Test |
| public void testOnTrainerStopCalled_afterOnTrainerStartCalled() throws Exception { |
| // After a cycle of onTrainerStartCalled -> onTrainerStopCalled there should be no pending |
| // jobs. |
| long nowMillis = 1000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStartCalled( |
| CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); |
| |
| nowMillis = 2000; |
| when(mClock.currentTimeMillis()).thenReturn(nowMillis); |
| mJobManager.onTrainerStopCalled( |
| CALLING_PACKAGE_NAME, POPULATION_NAME1, new TestFederatedComputeCallback()); |
| |
| // No task should exist, nor should a job be scheduled anymore |
| assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); |
| assertThat(mJobScheduler.getAllPendingJobs()).isEmpty(); |
| } |
| |
| /** |
| * Helper for checking that two JobInfos match, since JobInfos unfortunately can't be compared |
| * directly. |
| */ |
| public static void assertJobInfosMatch(JobInfo pendingJob, JobInfo expectedJobInfo) { |
| // Compare most of JobInfo's properties that may be set by our code. |
| assertWithMessage("id").that(pendingJob.getId()).isEqualTo(expectedJobInfo.getId()); |
| assertWithMessage("service") |
| .that(pendingJob.getService()) |
| .isEqualTo(expectedJobInfo.getService()); |
| assertWithMessage("persisted") |
| .that(pendingJob.isPersisted()) |
| .isEqualTo(expectedJobInfo.isPersisted()); |
| assertWithMessage("networkType") |
| .that(pendingJob.getNetworkType()) |
| .isEqualTo(expectedJobInfo.getNetworkType()); |
| assertWithMessage("requireDeviceIdle") |
| .that(pendingJob.isRequireDeviceIdle()) |
| .isEqualTo(expectedJobInfo.isRequireDeviceIdle()); |
| assertWithMessage("requireBatteryNotLow") |
| .that(pendingJob.isRequireBatteryNotLow()) |
| .isEqualTo(expectedJobInfo.isRequireBatteryNotLow()); |
| assertWithMessage("minLatencyMillis") |
| .that(pendingJob.getMinLatencyMillis()) |
| .isEqualTo(expectedJobInfo.getMinLatencyMillis()); |
| assertWithMessage("maxExecutionDelayMillis") |
| .that(pendingJob.getMaxExecutionDelayMillis()) |
| .isEqualTo(expectedJobInfo.getMaxExecutionDelayMillis()); |
| } |
| |
| private static TrainingOptions.Builder basicFLOptionsBuilder(int jobId, String population) { |
| return new TrainingOptions.Builder() |
| .setPopulationName(population) |
| .setServerAddress(SERVER_ADDRESS); |
| } |
| |
| private JobInfo buildExpectedJobInfo(int jobId, long minLatencyMillis) { |
| JobInfo.Builder jobInfo = |
| new JobInfo.Builder( |
| jobId, |
| new ComponentName(mContext.getPackageName(), TRAINING_JOB_SERVICE)) |
| .setPersisted(true) |
| .setRequiresDeviceIdle(true) |
| // the latency should be capped. |
| .setMinimumLatency(minLatencyMillis) |
| .setRequiresBatteryNotLow(true); |
| jobInfo.setRequiredNetworkType(JobInfo.NETWORK_TYPE_UNMETERED); |
| |
| return jobInfo.build(); |
| } |
| |
| private FederatedTrainingTask.Builder basicFLTrainingTaskBuilder( |
| int jobId, String population, @Nullable byte[] trainingIntervalOptions) { |
| FederatedTrainingTask.Builder builder = |
| FederatedTrainingTask.builder() |
| .jobId(jobId) |
| .populationName(population) |
| .lastScheduledTime(0L) |
| .lastRunStartTime(0L) |
| .lastRunEndTime(0L) |
| .constraints(DEFAULT_CONSTRAINTS) |
| .serverAddress(SERVER_ADDRESS) |
| .appPackageName(CALLING_PACKAGE_NAME); |
| if (trainingIntervalOptions != null) { |
| builder.intervalOptions(trainingIntervalOptions); |
| } |
| return builder; |
| } |
| |
| private static TrainingIntervalOptions createTrainingIntervalOptionsAsRoot( |
| int schedulingMode, long intervalMillis) { |
| byte[] intervalOptions = createTrainingIntervalOptions(schedulingMode, intervalMillis); |
| return TrainingIntervalOptions.getRootAsTrainingIntervalOptions( |
| ByteBuffer.wrap(intervalOptions)); |
| } |
| |
| private static byte[] createTrainingIntervalOptions(int schedulingMode, long intervalMillis) { |
| FlatBufferBuilder builder = new FlatBufferBuilder(); |
| builder.finish( |
| TrainingIntervalOptions.createTrainingIntervalOptions( |
| builder, schedulingMode, intervalMillis)); |
| return builder.sizedByteArray(); |
| } |
| |
| private static byte[] createDefaultTrainingInterval() { |
| return createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0); |
| } |
| |
| private static byte[] createDefaultTrainingConstraints() { |
| FlatBufferBuilder builder = new FlatBufferBuilder(); |
| builder.finish(TrainingConstraints.createTrainingConstraints(builder, true, true, true)); |
| return builder.sizedByteArray(); |
| } |
| |
| class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub { |
| @Override |
| public void onSuccess() { |
| mSuccess = true; |
| mLatch.countDown(); |
| } |
| |
| @Override |
| public void onFailure(int errorCode) { |
| mLatch.countDown(); |
| } |
| } |
| } |