blob: 4d47f92f1d7e8311796ca750c184bd5cfd776c5b [file] [log] [blame]
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.voiceinteraction;
import static android.service.voice.HotwordDetectionService.AUDIO_SOURCE_EXTERNAL;
import static android.service.voice.HotwordDetectionService.AUDIO_SOURCE_MICROPHONE;
import static android.service.voice.HotwordDetectionService.INITIALIZATION_STATUS_UNKNOWN;
import static android.service.voice.HotwordDetectionService.KEY_INITIALIZATION_STATUS;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.content.ComponentName;
import android.content.ContentCaptureOptions;
import android.content.Context;
import android.content.Intent;
import android.hardware.soundtrigger.IRecognitionStatusCallback;
import android.hardware.soundtrigger.SoundTrigger;
import android.media.AudioAttributes;
import android.media.AudioFormat;
import android.media.AudioManager;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.Bundle;
import android.os.IBinder;
import android.os.IRemoteCallback;
import android.os.ParcelFileDescriptor;
import android.os.PersistableBundle;
import android.os.RemoteException;
import android.os.ServiceManager;
import android.os.SharedMemory;
import android.service.voice.HotwordDetectedResult;
import android.service.voice.HotwordDetectionService;
import android.service.voice.HotwordRejectedResult;
import android.service.voice.IDspHotwordDetectionCallback;
import android.service.voice.IHotwordDetectionService;
import android.service.voice.IMicrophoneHotwordDetectionVoiceInteractionCallback;
import android.util.Pair;
import android.util.Slog;
import android.view.contentcapture.IContentCaptureManager;
import com.android.internal.annotations.GuardedBy;
import com.android.internal.app.IHotwordRecognitionStatusCallback;
import com.android.internal.infra.AndroidFuture;
import com.android.internal.infra.ServiceConnector;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* A class that provides the communication with the HotwordDetectionService.
*/
final class HotwordDetectionConnection {
private static final String TAG = "HotwordDetectionConnection";
// TODO (b/177502877): Set the Debug flag to false before shipping.
private static final boolean DEBUG = true;
// Number of bytes per sample of audio (which is a short).
private static final int BYTES_PER_SAMPLE = 2;
// TODO: These constants need to be refined.
private static final long VALIDATION_TIMEOUT_MILLIS = 3000;
private static final long VOICE_INTERACTION_TIMEOUT_TO_OPEN_MIC_MILLIS = 2000;
private static final int MAX_STREAMING_SECONDS = 10;
private static final int MICROPHONE_BUFFER_LENGTH_SECONDS = 8;
private static final int HOTWORD_AUDIO_LENGTH_SECONDS = 3;
private static final long MAX_UPDATE_TIMEOUT_MILLIS = 6000;
private final Executor mAudioCopyExecutor = Executors.newCachedThreadPool();
// TODO: This may need to be a Handler(looper)
private final ScheduledExecutorService mScheduledExecutorService =
Executors.newSingleThreadScheduledExecutor();
private final AtomicBoolean mUpdateStateFinish = new AtomicBoolean(false);
final Object mLock;
final ComponentName mDetectionComponentName;
final int mUser;
final Context mContext;
final @NonNull ServiceConnector<IHotwordDetectionService> mRemoteHotwordDetectionService;
boolean mBound;
@GuardedBy("mLock")
private ParcelFileDescriptor mCurrentAudioSink;
HotwordDetectionConnection(Object lock, Context context, ComponentName serviceName,
int userId, boolean bindInstantServiceAllowed, @Nullable PersistableBundle options,
@Nullable SharedMemory sharedMemory, IHotwordRecognitionStatusCallback callback) {
mLock = lock;
mContext = context;
mDetectionComponentName = serviceName;
mUser = userId;
final Intent intent = new Intent(HotwordDetectionService.SERVICE_INTERFACE);
intent.setComponent(mDetectionComponentName);
mRemoteHotwordDetectionService = new ServiceConnector.Impl<IHotwordDetectionService>(
mContext, intent, bindInstantServiceAllowed ? Context.BIND_ALLOW_INSTANT : 0, mUser,
IHotwordDetectionService.Stub::asInterface) {
@Override // from ServiceConnector.Impl
protected void onServiceConnectionStatusChanged(IHotwordDetectionService service,
boolean connected) {
if (DEBUG) {
Slog.d(TAG, "onServiceConnectionStatusChanged connected = " + connected);
}
synchronized (mLock) {
mBound = connected;
}
}
@Override
protected long getAutoDisconnectTimeoutMs() {
return -1;
}
};
mRemoteHotwordDetectionService.connect();
if (callback == null) {
updateStateLocked(options, sharedMemory);
return;
}
updateStateWithCallbackLocked(options, sharedMemory, callback);
updateContentCaptureManager();
}
private void updateStateWithCallbackLocked(PersistableBundle options,
SharedMemory sharedMemory, IHotwordRecognitionStatusCallback callback) {
if (DEBUG) {
Slog.d(TAG, "updateStateWithCallbackLocked");
}
mRemoteHotwordDetectionService.postAsync(service -> {
AndroidFuture<Void> future = new AndroidFuture<>();
IRemoteCallback statusCallback = new IRemoteCallback.Stub() {
@Override
public void sendResult(Bundle bundle) throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "updateState finish");
}
future.complete(null);
try {
if (mUpdateStateFinish.getAndSet(true)) {
Slog.w(TAG, "call callback after timeout");
return;
}
int status = bundle != null ? bundle.getInt(
KEY_INITIALIZATION_STATUS,
INITIALIZATION_STATUS_UNKNOWN)
: INITIALIZATION_STATUS_UNKNOWN;
// Add the protection to avoid unexpected status
if (status > HotwordDetectionService.getMaxCustomInitializationStatus()
&& status != INITIALIZATION_STATUS_UNKNOWN) {
status = INITIALIZATION_STATUS_UNKNOWN;
}
callback.onStatusReported(status);
} catch (RemoteException e) {
Slog.w(TAG, "Failed to report initialization status: " + e);
}
}
};
try {
service.updateState(options, sharedMemory, statusCallback);
} catch (RemoteException e) {
// TODO: (b/181842909) Report an error to voice interactor
Slog.w(TAG, "Failed to updateState for HotwordDetectionService", e);
}
return future;
}).orTimeout(MAX_UPDATE_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS)
.whenComplete((res, err) -> {
if (err instanceof TimeoutException) {
Slog.w(TAG, "updateState timed out");
try {
if (mUpdateStateFinish.getAndSet(true)) {
return;
}
callback.onStatusReported(INITIALIZATION_STATUS_UNKNOWN);
} catch (RemoteException e) {
Slog.w(TAG, "Failed to report initialization status: " + e);
}
} else if (err != null) {
Slog.w(TAG, "Failed to update state: " + err);
} else {
// NOTE: so far we don't need to take any action.
}
});
}
private void updateContentCaptureManager() {
IBinder b = ServiceManager
.getService(Context.CONTENT_CAPTURE_MANAGER_SERVICE);
IContentCaptureManager binderService = IContentCaptureManager.Stub.asInterface(b);
mRemoteHotwordDetectionService.post(
service -> service.updateContentCaptureManager(binderService,
new ContentCaptureOptions(null)));
}
private boolean isBound() {
synchronized (mLock) {
return mBound;
}
}
void cancelLocked() {
if (DEBUG) {
Slog.d(TAG, "cancelLocked");
}
if (mBound) {
mRemoteHotwordDetectionService.unbind();
mBound = false;
}
}
void updateStateLocked(PersistableBundle options, SharedMemory sharedMemory) {
mRemoteHotwordDetectionService.run(
service -> service.updateState(options, sharedMemory, null /* callback */));
}
void startListeningFromMic(
AudioFormat audioFormat,
IMicrophoneHotwordDetectionVoiceInteractionCallback callback) {
if (DEBUG) {
Slog.d(TAG, "startListeningFromMic");
}
// TODO: consider making this a non-anonymous class.
IDspHotwordDetectionCallback internalCallback = new IDspHotwordDetectionCallback.Stub() {
@Override
public void onDetected(HotwordDetectedResult result) throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "onDetected");
}
callback.onDetected(result, null, null);
}
@Override
public void onRejected(HotwordRejectedResult result) throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "onRejected");
}
// onRejected isn't allowed here
}
};
mRemoteHotwordDetectionService.run(
service -> service.detectFromMicrophoneSource(
null,
AUDIO_SOURCE_MICROPHONE,
null,
null,
internalCallback));
}
public void startListeningFromExternalSource(
ParcelFileDescriptor audioStream,
AudioFormat audioFormat,
@Nullable PersistableBundle options,
IMicrophoneHotwordDetectionVoiceInteractionCallback callback) {
if (DEBUG) {
Slog.d(TAG, "startListeningFromExternalSource");
}
handleExternalSourceHotwordDetection(
audioStream,
audioFormat,
options,
callback);
}
void stopListening() {
if (DEBUG) {
Slog.d(TAG, "stopListening");
}
mRemoteHotwordDetectionService.run(service -> service.stopDetection());
synchronized (mLock) {
if (mCurrentAudioSink != null) {
Slog.i(TAG, "Closing audio stream to hotword detector: stopping requested");
bestEffortClose(mCurrentAudioSink);
}
mCurrentAudioSink = null;
}
}
void triggerHardwareRecognitionEventForTestLocked(
SoundTrigger.KeyphraseRecognitionEvent event,
IHotwordRecognitionStatusCallback callback) {
if (DEBUG) {
Slog.d(TAG, "triggerHardwareRecognitionEventForTestLocked");
}
detectFromDspSourceForTest(event, callback);
}
private void detectFromDspSourceForTest(SoundTrigger.KeyphraseRecognitionEvent recognitionEvent,
IHotwordRecognitionStatusCallback externalCallback) {
if (DEBUG) {
Slog.d(TAG, "detectFromDspSourceForTest");
}
IDspHotwordDetectionCallback internalCallback = new IDspHotwordDetectionCallback.Stub() {
@Override
public void onDetected(HotwordDetectedResult result) throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "onDetected");
}
externalCallback.onKeyphraseDetected(recognitionEvent, result);
}
@Override
public void onRejected(HotwordRejectedResult result) throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "onRejected");
}
externalCallback.onRejected(result);
}
};
mRemoteHotwordDetectionService.run(
service -> service.detectFromDspSource(
recognitionEvent,
recognitionEvent.getCaptureFormat(),
VALIDATION_TIMEOUT_MILLIS,
internalCallback));
}
private void detectFromDspSource(SoundTrigger.KeyphraseRecognitionEvent recognitionEvent,
IHotwordRecognitionStatusCallback externalCallback) {
if (DEBUG) {
Slog.d(TAG, "detectFromDspSource");
}
// TODO: consider making this a non-anonymous class.
IDspHotwordDetectionCallback internalCallback = new IDspHotwordDetectionCallback.Stub() {
@Override
public void onDetected(HotwordDetectedResult result) throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "onDetected");
}
externalCallback.onKeyphraseDetected(recognitionEvent, result);
}
@Override
public void onRejected(HotwordRejectedResult result) throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "onRejected");
}
externalCallback.onRejected(result);
}
};
mRemoteHotwordDetectionService.run(
service -> service.detectFromDspSource(
recognitionEvent,
recognitionEvent.getCaptureFormat(),
VALIDATION_TIMEOUT_MILLIS,
internalCallback));
}
static final class SoundTriggerCallback extends IRecognitionStatusCallback.Stub {
private SoundTrigger.KeyphraseRecognitionEvent mRecognitionEvent;
private final HotwordDetectionConnection mHotwordDetectionConnection;
private final IHotwordRecognitionStatusCallback mExternalCallback;
SoundTriggerCallback(IHotwordRecognitionStatusCallback callback,
HotwordDetectionConnection connection) {
mHotwordDetectionConnection = connection;
mExternalCallback = callback;
}
@Override
public void onKeyphraseDetected(SoundTrigger.KeyphraseRecognitionEvent recognitionEvent)
throws RemoteException {
if (DEBUG) {
Slog.d(TAG, "onKeyphraseDetected recognitionEvent : " + recognitionEvent);
}
final boolean useHotwordDetectionService = mHotwordDetectionConnection != null
&& mHotwordDetectionConnection.isBound();
if (useHotwordDetectionService) {
mRecognitionEvent = recognitionEvent;
mHotwordDetectionConnection.detectFromDspSource(
recognitionEvent, mExternalCallback);
} else {
mExternalCallback.onKeyphraseDetected(recognitionEvent, null);
}
}
@Override
public void onGenericSoundTriggerDetected(
SoundTrigger.GenericRecognitionEvent recognitionEvent)
throws RemoteException {
mExternalCallback.onGenericSoundTriggerDetected(recognitionEvent);
}
@Override
public void onError(int status) throws RemoteException {
mExternalCallback.onError(status);
}
@Override
public void onRecognitionPaused() throws RemoteException {
mExternalCallback.onRecognitionPaused();
}
@Override
public void onRecognitionResumed() throws RemoteException {
mExternalCallback.onRecognitionResumed();
}
}
// TODO: figure out if we need to let the client configure some of the parameters.
private static AudioRecord createAudioRecord(
@NonNull SoundTrigger.KeyphraseRecognitionEvent recognitionEvent) {
int sampleRate = recognitionEvent.getCaptureFormat().getSampleRate();
return new AudioRecord(
new AudioAttributes.Builder()
.setInternalCapturePreset(MediaRecorder.AudioSource.HOTWORD).build(),
recognitionEvent.getCaptureFormat(),
getBufferSizeInBytes(
sampleRate,
MAX_STREAMING_SECONDS,
recognitionEvent.getCaptureFormat().getChannelCount()),
recognitionEvent.getCaptureSession());
}
@Nullable
private AudioRecord createMicAudioRecord(AudioFormat audioFormat) {
if (DEBUG) {
Slog.i(TAG, "#createAudioRecord");
}
try {
AudioRecord audioRecord = new AudioRecord(
new AudioAttributes.Builder()
.setInternalCapturePreset(MediaRecorder.AudioSource.HOTWORD).build(),
audioFormat,
getBufferSizeInBytes(
audioFormat.getSampleRate(),
MICROPHONE_BUFFER_LENGTH_SECONDS,
audioFormat.getChannelCount()),
AudioManager.AUDIO_SESSION_ID_GENERATE);
if (audioRecord.getState() != AudioRecord.STATE_INITIALIZED) {
Slog.w(TAG, "Failed to initialize AudioRecord");
audioRecord.release();
return null;
}
return audioRecord;
} catch (IllegalArgumentException e) {
Slog.e(TAG, "Failed to create AudioRecord", e);
return null;
}
}
@Nullable
private AudioRecord createFakeAudioRecord() {
if (DEBUG) {
Slog.i(TAG, "#createFakeAudioRecord");
}
try {
AudioRecord audioRecord = new AudioRecord.Builder()
.setAudioFormat(new AudioFormat.Builder()
.setSampleRate(32000)
.setEncoding(AudioFormat.ENCODING_PCM_16BIT)
.setChannelMask(AudioFormat.CHANNEL_IN_MONO).build())
.setAudioAttributes(new AudioAttributes.Builder()
.setInternalCapturePreset(MediaRecorder.AudioSource.HOTWORD).build())
.setBufferSizeInBytes(
AudioRecord.getMinBufferSize(32000,
AudioFormat.CHANNEL_IN_MONO,
AudioFormat.ENCODING_PCM_16BIT) * 2)
.build();
if (audioRecord.getState() != AudioRecord.STATE_INITIALIZED) {
Slog.w(TAG, "Failed to initialize AudioRecord");
audioRecord.release();
return null;
}
return audioRecord;
} catch (IllegalArgumentException e) {
Slog.e(TAG, "Failed to create AudioRecord", e);
}
return null;
}
/**
* Returns the number of bytes required to store {@code bufferLengthSeconds} of audio sampled at
* {@code sampleRate} Hz, using the format returned by DSP audio capture.
*/
private static int getBufferSizeInBytes(
int sampleRate, int bufferLengthSeconds, int intChannelCount) {
return BYTES_PER_SAMPLE * sampleRate * bufferLengthSeconds * intChannelCount;
}
private static Pair<ParcelFileDescriptor, ParcelFileDescriptor> createPipe() {
ParcelFileDescriptor[] fileDescriptors;
try {
fileDescriptors = ParcelFileDescriptor.createPipe();
} catch (IOException e) {
Slog.e(TAG, "Failed to create audio stream pipe", e);
return null;
}
return Pair.create(fileDescriptors[0], fileDescriptors[1]);
}
public void dump(String prefix, PrintWriter pw) {
pw.print(prefix); pw.print("mBound="); pw.println(mBound);
}
private interface AudioReader extends Closeable {
int read(byte[] dest, int offset, int length) throws IOException;
static AudioReader createFromInputStream(InputStream is) {
return new AudioReader() {
@Override
public int read(byte[] dest, int offset, int length) throws IOException {
return is.read(dest, offset, length);
}
@Override
public void close() throws IOException {
is.close();
}
};
}
static AudioReader createFromAudioRecord(AudioRecord record) {
record.startRecording();
return new AudioReader() {
@Override
public int read(byte[] dest, int offset, int length) throws IOException {
return record.read(dest, offset, length);
}
@Override
public void close() throws IOException {
record.stop();
record.release();
}
};
}
}
private void handleExternalSourceHotwordDetection(
ParcelFileDescriptor audioStream,
AudioFormat audioFormat,
@Nullable PersistableBundle options,
IMicrophoneHotwordDetectionVoiceInteractionCallback callback) {
if (DEBUG) {
Slog.d(TAG, "#handleExternalSourceHotwordDetection");
}
AudioReader audioSource = AudioReader.createFromInputStream(
new ParcelFileDescriptor.AutoCloseInputStream(audioStream));
Pair<ParcelFileDescriptor, ParcelFileDescriptor> clientPipe = createPipe();
if (clientPipe == null) {
// TODO: Need to propagate as unknown error or something?
return;
}
ParcelFileDescriptor serviceAudioSink = clientPipe.second;
ParcelFileDescriptor serviceAudioSource = clientPipe.first;
synchronized (mLock) {
mCurrentAudioSink = serviceAudioSink;
}
mAudioCopyExecutor.execute(() -> {
try (AudioReader source = audioSource;
OutputStream fos =
new ParcelFileDescriptor.AutoCloseOutputStream(serviceAudioSink)) {
byte[] buffer = new byte[1024];
while (true) {
int bytesRead = source.read(buffer, 0, 1024);
if (bytesRead < 0) {
break;
}
// TODO: First write to ring buffer to make sure we don't lose data if the next
// statement fails.
// ringBuffer.append(buffer, bytesRead);
fos.write(buffer, 0, bytesRead);
}
} catch (IOException e) {
Slog.w(TAG, "Failed supplying audio data to validator", e);
} finally {
synchronized (mLock) {
mCurrentAudioSink = null;
}
}
});
// TODO: handle cancellations well
// TODO: what if we cancelled and started a new one?
mRemoteHotwordDetectionService.run(
service -> service.detectFromMicrophoneSource(
serviceAudioSource,
// TODO: consider making a proxy callback + copy of audio format
AUDIO_SOURCE_EXTERNAL,
audioFormat,
options,
new IDspHotwordDetectionCallback.Stub() {
@Override
public void onRejected(HotwordRejectedResult result)
throws RemoteException {
bestEffortClose(serviceAudioSink);
bestEffortClose(serviceAudioSource);
bestEffortClose(audioSource);
// TODO: Propagate the HotwordRejectedResult.
}
@Override
public void onDetected(HotwordDetectedResult triggerResult)
throws RemoteException {
bestEffortClose(serviceAudioSink);
bestEffortClose(serviceAudioSource);
// TODO: noteOp here.
callback.onDetected(triggerResult, null /* audioFormat */,
null /* audioStream */);
// TODO: Add a delay before closing.
bestEffortClose(audioSource);
}
}));
}
private static void bestEffortClose(Closeable closeable) {
try {
closeable.close();
} catch (IOException e) {
if (DEBUG) {
Slog.w(TAG, "Failed closing", e);
}
}
}
};