blob: 992aeedceb541bdfc23a1ea3b1c601da147727f8 [file] [log] [blame]
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.healthconnect.storage.request;
import static android.health.connect.datatypes.AggregationType.AVG;
import static android.health.connect.datatypes.AggregationType.COUNT;
import static android.health.connect.datatypes.AggregationType.MAX;
import static android.health.connect.datatypes.AggregationType.MIN;
import static android.health.connect.datatypes.AggregationType.SUM;
import static com.android.server.healthconnect.storage.datatypehelpers.RecordHelper.APP_INFO_ID_COLUMN_NAME;
import android.annotation.NonNull;
import android.database.Cursor;
import android.health.connect.AggregateResult;
import android.health.connect.Constants;
import android.health.connect.LocalTimeRangeFilter;
import android.health.connect.TimeRangeFilter;
import android.health.connect.TimeRangeFilterHelper;
import android.health.connect.datatypes.AggregationType;
import android.util.ArrayMap;
import android.util.Pair;
import android.util.Slog;
import com.android.server.healthconnect.storage.TransactionManager;
import com.android.server.healthconnect.storage.datatypehelpers.AppInfoHelper;
import com.android.server.healthconnect.storage.datatypehelpers.RecordHelper;
import com.android.server.healthconnect.storage.datatypehelpers.aggregation.PriorityRecordsAggregator;
import com.android.server.healthconnect.storage.utils.OrderByClause;
import com.android.server.healthconnect.storage.utils.SqlJoin;
import com.android.server.healthconnect.storage.utils.StorageUtils;
import com.android.server.healthconnect.storage.utils.WhereClauses;
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.Period;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
 * A request for {@link TransactionManager} to query the DB for aggregation results.
 *
 * <p>Builds the SQL used either to fetch raw rows (for priority-based and derived aggregations,
 * whose results are computed in Java after the query) or to compute SQL aggregates directly. The
 * results are optionally split into time-based groups. The fetched cursors are then converted
 * into {@link AggregateResult}s.
 *
 * @hide
 */
public class AggregateTableRequest {
    private static final String TAG = "HealthConnectAggregate";
    private static final String GROUP_BY_COLUMN_NAME = "category";
    // Initial value of mStartTime / mEndTime until setTimeFilter() receives valid bounds.
    private static final long DEFAULT_TIME = -1;
    private final String mTableName;
    private final List<String> mColumnNamesToAggregate;
    private final AggregationType<?> mAggregationType;
    private final RecordHelper<?> mRecordHelper;
    // Group number -> aggregation result for that group.
    private final Map<Integer, AggregateResult<?>> mAggregateResults = new ArrayMap<>();
    private final String mTimeColumnName;
    // Additional column used for time filtering. End time for interval records,
    // null for other records.
    private final String mEndTimeColumnName;
    private final SqlJoin mSqlJoin;
    private List<Long> mPackageFilters;
    private long mStartTime = DEFAULT_TIME;
    private long mEndTime = DEFAULT_TIME;
    private String mPackageColumnName;
    // Non-null once setGroupBy() has been called; used as the "grouping enabled" flag.
    private String mGroupByColumnName;
    private int mGroupBySize = 1;
    private final List<String> mAdditionalColumnsToFetch;
    private final AggregateParams.PriorityAggregationExtraParams mPriorityParams;
    private final boolean mUseLocalTime;
    // Ordered group boundaries: group i covers [mTimeSplits[i], mTimeSplits[i + 1]).
    private List<Long> mTimeSplits;

    public AggregateTableRequest(
            AggregateParams params,
            AggregationType<?> aggregationType,
            RecordHelper<?> recordHelper,
            boolean useLocalTime) {
        mTableName = params.getTableName();
        mColumnNamesToAggregate = params.getColumnsToFetch();
        mTimeColumnName = params.getTimeColumnName();
        mAggregationType = aggregationType;
        mRecordHelper = recordHelper;
        mSqlJoin = params.getJoin();
        mPriorityParams = params.getPriorityAggregationExtraParams();
        mEndTimeColumnName = params.getExtraTimeColumnName();
        mAdditionalColumnsToFetch = new ArrayList<>();
        mAdditionalColumnsToFetch.add(params.getTimeOffsetColumnName());
        mAdditionalColumnsToFetch.add(mTimeColumnName);
        if (mEndTimeColumnName != null) {
            mAdditionalColumnsToFetch.add(mEndTimeColumnName);
        }
        mUseLocalTime = useLocalTime;
    }

    /**
     * @return {@link AggregationType} for this request
     */
    public AggregationType<?> getAggregationType() {
        return mAggregationType;
    }

    /**
     * @return {@link RecordHelper} for this request
     */
    public RecordHelper<?> getRecordHelper() {
        return mRecordHelper;
    }

    /**
     * @return results fetched after performing aggregate operation for this class, one entry per
     *     group in group order. An entry is null if no result was stored for that group.
     *     <p>Note: Only available after the call to {@link
     *     TransactionManager#populateWithAggregation} has been made
     */
    public List<AggregateResult<?>> getAggregateResults() {
        List<AggregateResult<?>> aggregateResults = new ArrayList<>(mGroupBySize);
        for (int i = 0; i < mGroupBySize; i++) {
            aggregateResults.add(mAggregateResults.get(i));
        }
        return aggregateResults;
    }

    /** Returns SQL statement to get data origins for the aggregation operation */
    public String getCommandToFetchAggregateMetadata() {
        final StringBuilder builder = new StringBuilder("SELECT DISTINCT ");
        builder.append(APP_INFO_ID_COLUMN_NAME).append(", ");
        return appendAggregateCommand(builder, /* isMetadata= */ true);
    }

    /**
     * Returns SQL statement to perform aggregation operation.
     *
     * <p>For priority-supported or derived record types the raw column values are selected, and
     * aggregation is done later in {@link #onResultsFetched}. Otherwise each column is wrapped in
     * the SQL aggregate function for this request's operation type.
     */
    @NonNull
    public String getAggregationCommand() {
        final StringBuilder builder = new StringBuilder("SELECT ");
        boolean usingPriority =
                StorageUtils.supportsPriority(
                                mRecordHelper.getRecordIdentifier(),
                                mAggregationType.getAggregateOperationType())
                        || StorageUtils.isDerivedType(mRecordHelper.getRecordIdentifier());
        if (usingPriority) {
            for (String columnName : mColumnNamesToAggregate) {
                builder.append(columnName).append(", ");
            }
        } else {
            String aggCommand = getSqlCommandFor(mAggregationType.getAggregateOperationType());
            for (String columnName : mColumnNamesToAggregate) {
                builder.append(aggCommand)
                        .append("(")
                        .append(columnName)
                        .append(")")
                        .append(" as ")
                        .append(columnName)
                        .append(", ");
            }
        }
        // Always non-null: assigned once in the constructor.
        for (String additionalColumnToFetch : mAdditionalColumnsToFetch) {
            builder.append(additionalColumnToFetch).append(", ");
        }
        // Priority/derived rows are grouped by the Java-side aggregation, so the SQL
        // CASE/GROUP BY is skipped for them (same path as the metadata query).
        return appendAggregateCommand(builder, usingPriority);
    }

    /**
     * Restricts the query to rows whose {@code packageColumnName} value is in {@code
     * packageFilters}.
     *
     * @return this request, for chaining
     */
    public AggregateTableRequest setPackageFilter(
            List<Long> packageFilters, String packageColumnName) {
        mPackageFilters = packageFilters;
        mPackageColumnName = packageColumnName;
        return this;
    }

    /**
     * Sets time filter for table request. Invalid bounds ({@code startTime < 0} or {@code endTime
     * < startTime}) are ignored and leave the request unchanged.
     *
     * @return this request, for chaining
     */
    public AggregateTableRequest setTimeFilter(long startTime, long endTime) {
        // Return if the params will result in no impact on the query
        if (startTime < 0 || endTime < startTime) {
            return this;
        }
        mStartTime = startTime;
        mEndTime = endTime;
        mTimeSplits = List.of(mStartTime, mEndTime);
        return this;
    }

    /**
     * Sets group by fields.
     *
     * <p>Exactly one of {@code period} (checked first) or {@code duration} is used to split
     * {@code timeRangeFilter} into consecutive groups. A period requires a {@link
     * LocalTimeRangeFilter}.
     *
     * @throws IllegalArgumentException if both {@code period} and {@code duration} are null, or
     *     if the chosen period/duration is not positive
     */
    public void setGroupBy(
            String columnName, Period period, Duration duration, TimeRangeFilter timeRangeFilter) {
        mGroupByColumnName = columnName;
        if (period != null) {
            mTimeSplits = getGroupSplitsForPeriod(timeRangeFilter, period);
        } else if (duration != null) {
            mTimeSplits = getGroupSplitsForDuration(timeRangeFilter, duration);
        } else {
            throw new IllegalArgumentException(
                    "Either aggregation period or duration should be not null");
        }
        mGroupBySize = mTimeSplits.size() - 1;
        if (Constants.DEBUG) {
            Slog.d(
                    TAG,
                    "Group aggregation splits: "
                            + mTimeSplits
                            + " number of groups: "
                            + mGroupBySize);
        }
    }

    /**
     * Converts the query results into per-group {@link AggregateResult}s and attaches the data
     * origin package names read from {@code metaDataCursor} to every result.
     */
    public void onResultsFetched(Cursor cursor, Cursor metaDataCursor) {
        if (StorageUtils.isDerivedType(mRecordHelper.getRecordIdentifier())) {
            deriveAggregate(cursor);
        } else if (StorageUtils.supportsPriority(
                mRecordHelper.getRecordIdentifier(),
                mAggregationType.getAggregateOperationType())) {
            processPriorityRequest(cursor);
        } else {
            processNoPrioritiesRequest(cursor);
        }
        updateResultWithDataOriginPackageNames(metaDataCursor);
    }

    /** Aggregates raw rows in Java using the app priority list for this record type. */
    private void processPriorityRequest(Cursor cursor) {
        List<Long> priorityList =
                StorageUtils.getAppIdPriorityList(mRecordHelper.getRecordIdentifier());
        PriorityRecordsAggregator aggregator =
                new PriorityRecordsAggregator(
                        mTimeSplits,
                        priorityList,
                        mAggregationType.getAggregationTypeIdentifier(),
                        mPriorityParams,
                        mUseLocalTime);
        aggregator.calculateAggregation(cursor);
        for (int groupNumber = 0; groupNumber < mGroupBySize; groupNumber++) {
            // Fetch once per group instead of re-querying the aggregator for each use.
            var groupResult = aggregator.getResultForGroup(groupNumber);
            if (groupResult == null) {
                continue;
            }
            AggregateResult<?> result;
            if (mAggregationType.getAggregateResultClass() == Long.class) {
                result = new AggregateResult<>(groupResult.longValue());
            } else {
                result = new AggregateResult<>(groupResult);
            }
            mAggregateResults.put(
                    groupNumber,
                    result.setZoneOffset(aggregator.getZoneOffsetForGroup(groupNumber)));
        }
        if (Constants.DEBUG) {
            Slog.d(TAG, "Priority aggregation result: " + mAggregateResults);
        }
    }

    /** Reads one SQL-aggregated row per group, keyed by the group-by column. */
    private void processNoPrioritiesRequest(Cursor cursor) {
        while (cursor.moveToNext()) {
            mAggregateResults.put(
                    StorageUtils.getCursorInt(cursor, GROUP_BY_COLUMN_NAME),
                    mRecordHelper.getAggregateResult(cursor, mAggregationType));
        }
    }

    /** Maps an aggregate operation type to its SQL function name, or null if unsupported. */
    private static String getSqlCommandFor(@AggregationType.AggregateOperationType int type) {
        return switch (type) {
            case MAX -> "MAX";
            case MIN -> "MIN";
            case AVG -> "AVG";
            case SUM -> "SUM";
            case COUNT -> "COUNT";
            default -> null;
        };
    }

    /**
     * Completes a SELECT statement whose column list in {@code builder} ends with ", ".
     *
     * <p>When grouping is enabled and {@code isMetadata} is false, a CASE expression assigns each
     * row's start time to a group index, and the statement is grouped by it. Otherwise the
     * trailing ", " is removed. The FROM, optional JOIN, WHERE and ORDER BY clauses are then
     * appended.
     */
    private String appendAggregateCommand(StringBuilder builder, boolean isMetadata) {
        boolean useGroupBy = mGroupByColumnName != null && !isMetadata;
        if (useGroupBy) {
            builder.append(" CASE ");
            for (int i = 0; i < mTimeSplits.size() - 1; i++) {
                // Group i covers [split_i, split_{i+1}).
                builder.append(" WHEN ")
                        .append(mTimeColumnName)
                        .append(" >= ")
                        .append(mTimeSplits.get(i))
                        .append(" AND ")
                        .append(mTimeColumnName)
                        .append(" < ")
                        .append(mTimeSplits.get(i + 1))
                        .append(" THEN ")
                        .append(i);
            }
            builder.append(" END " + GROUP_BY_COLUMN_NAME + " ");
        } else {
            builder.setLength(builder.length() - 2); // Remove the last 2 char i.e. ", "
        }
        builder.append(" FROM ").append(mTableName);
        if (mSqlJoin != null) {
            builder.append(mSqlJoin.getJoinCommand());
        }
        builder.append(buildAggregationWhereCondition());
        if (useGroupBy) {
            builder.append(" GROUP BY " + GROUP_BY_COLUMN_NAME);
        }
        OrderByClause orderByClause = new OrderByClause();
        orderByClause.addOrderByClause(mTimeColumnName, true);
        builder.append(orderByClause.getOrderBy());
        if (Constants.DEBUG) {
            Slog.d(TAG, "Aggregation origin query: " + builder);
        }
        return builder.toString();
    }

    /** Builds the WHERE clause for the package filter and the [mStartTime, mEndTime) range. */
    private String buildAggregationWhereCondition() {
        WhereClauses whereClauses = new WhereClauses();
        whereClauses.addWhereInLongsClause(mPackageColumnName, mPackageFilters);
        if (mEndTimeColumnName != null) {
            // Filter all records which overlap with time filter interval:
            // recordStartTime < filterEndTime and recordEndTime >= filterStartTime
            whereClauses.addWhereGreaterThanOrEqualClause(mEndTimeColumnName, mStartTime);
        } else {
            whereClauses.addWhereGreaterThanOrEqualClause(mTimeColumnName, mStartTime);
        }
        whereClauses.addWhereLessThanClause(mTimeColumnName, mEndTime);
        return whereClauses.get(/* withWhereKeyword= */ true);
    }

    /** Sets the distinct app package names from {@code metaDataCursor} on every result. */
    private void updateResultWithDataOriginPackageNames(Cursor metaDataCursor) {
        List<Long> packageIds = new ArrayList<>();
        while (metaDataCursor.moveToNext()) {
            packageIds.add(StorageUtils.getCursorLong(metaDataCursor, APP_INFO_ID_COLUMN_NAME));
        }
        List<String> packageNames = AppInfoHelper.getInstance().getPackageNames(packageIds);
        mAggregateResults.replaceAll(
                (groupNumber, result) -> result.setDataOrigins(packageNames));
    }

    /** Returns consecutive (start, end) pairs built from the current time splits. */
    public List<Pair<Long, Long>> getGroupSplitIntervals() {
        List<Pair<Long, Long>> groupIntervals = new ArrayList<>();
        long previous = mTimeSplits.get(0);
        for (int i = 1; i < mTimeSplits.size(); i++) {
            Pair<Long, Long> pair = new Pair<>(previous, mTimeSplits.get(i));
            groupIntervals.add(pair);
            previous = mTimeSplits.get(i);
        }
        return groupIntervals;
    }

    /**
     * Splits a local time range into boundaries {@code period} apart, with the last group
     * truncated at the filter end.
     *
     * @throws IllegalArgumentException if {@code period} does not move time forward
     */
    private List<Long> getGroupSplitsForPeriod(TimeRangeFilter timeFilter, Period period) {
        LocalDateTime filterStart = ((LocalTimeRangeFilter) timeFilter).getStartTime();
        LocalDateTime filterEnd = ((LocalTimeRangeFilter) timeFilter).getEndTime();
        long filterEndMillis = TimeRangeFilterHelper.getMillisOfLocalTime(filterEnd);
        List<Long> splits = new ArrayList<>();
        splits.add(TimeRangeFilterHelper.getMillisOfLocalTime(filterStart));
        LocalDateTime currentStart = filterStart;
        LocalDateTime currentEnd = filterStart.plus(period);
        while (!currentEnd.isAfter(filterEnd)) {
            // A zero or negative period never reaches filterEnd and would loop forever.
            if (!currentEnd.isAfter(currentStart)) {
                throw new IllegalArgumentException(
                        "Aggregation period must be positive: " + period);
            }
            splits.add(TimeRangeFilterHelper.getMillisOfLocalTime(currentEnd));
            currentStart = currentEnd;
            currentEnd = currentEnd.plus(period);
        }
        // If the last group doesn't fit the rest of the window, we cut it up to filterEnd
        if (splits.get(splits.size() - 1) < filterEndMillis) {
            splits.add(filterEndMillis);
        }
        return splits;
    }

    /**
     * Splits a time range into boundaries {@code duration} apart, with the last group truncated
     * at the filter end.
     *
     * @throws IllegalArgumentException if {@code duration} is shorter than one millisecond
     */
    private List<Long> getGroupSplitsForDuration(
            TimeRangeFilter timeRangeFilter, Duration duration) {
        long groupByStart = TimeRangeFilterHelper.getFilterStartTimeMillis(timeRangeFilter);
        long groupByEnd = TimeRangeFilterHelper.getFilterEndTimeMillis(timeRangeFilter);
        long groupDurationMillis = duration.toMillis();
        // A non-positive step would never advance past groupByEnd and would loop forever.
        if (groupDurationMillis <= 0) {
            throw new IllegalArgumentException(
                    "Aggregation duration must be at least 1 millisecond: " + duration);
        }
        List<Long> splits = new ArrayList<>();
        splits.add(groupByStart);
        long currentEnd = groupByStart + groupDurationMillis;
        while (currentEnd <= groupByEnd) {
            splits.add(currentEnd);
            currentEnd += groupDurationMillis;
        }
        // If the last group doesn't fit the rest of the window, we cut it up to filterEnd
        if (splits.get(splits.size() - 1) < groupByEnd) {
            splits.add(groupByEnd);
        }
        return splits;
    }

    /**
     * Stores one result per element of the array returned by {@link
     * RecordHelper#deriveAggregate}, keyed by array index.
     */
    private void deriveAggregate(Cursor cursor) {
        double[] derivedAggregateArray = mRecordHelper.deriveAggregate(cursor, this);
        int index = 0;
        // NOTE(review): the return value of moveToFirst() is ignored; presumably
        // getAggregateResult only reads row-independent data when the cursor is empty — confirm.
        cursor.moveToFirst();
        for (double aggregate : derivedAggregateArray) {
            mAggregateResults.put(
                    index, mRecordHelper.getAggregateResult(cursor, mAggregationType, aggregate));
            index++;
        }
    }
}