blob: 54e70cb1502a83569d65ec70d9325663931bc284 [file] [log] [blame]
/**
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "InputChannelTest"
#include "../includes/common.h"
#include <android-base/stringprintf.h>
#include <input/InputTransport.h>
using namespace android;
using android::base::StringPrintf;
/**
 * Formats numBytes bytes starting at address as uppercase two-digit hex values, each
 * followed by a single space (e.g. {0x0A, 0xFF} -> "0A FF "). Used to dump raw message
 * bytes to the log.
 * Input: address of the first byte (may be anything when numBytes is 0), number of bytes
 * Output: the formatted string; empty when numBytes is 0
 */
static std::string memoryAsHexString(const void* const address, size_t numBytes) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    const uint8_t* const bytes = static_cast<const uint8_t*>(address);
    std::string str;
    // Each byte becomes exactly three characters (two hex digits and a separator), so
    // reserve once instead of growing the string and formatting a temporary per byte.
    str.reserve(numBytes * 3);
    for (size_t i = 0; i < numBytes; i++) {
        const uint8_t byte = bytes[i];
        str += kHexDigits[byte >> 4];
        str += kHexDigits[byte & 0x0F];
        str += ' ';
    }
    return str;
}
/**
* There could be non-zero bytes in-between InputMessage fields. Force-initialize the entire
* memory to zero, then only copy the valid bytes on a per-field basis.
* Input: message msg
* Output: cleaned message outMsg
*/
static void sanitizeMessage(const InputMessage& msg, InputMessage* outMsg) {
memset(outMsg, 0, sizeof(*outMsg));
// Write the header
outMsg->header.type = msg.header.type;
// Write the body
switch(msg.header.type) {
case InputMessage::TYPE_KEY: {
// uint32_t seq
outMsg->body.key.seq = msg.body.key.seq;
// nsecs_t eventTime
outMsg->body.key.eventTime = msg.body.key.eventTime;
// int32_t deviceId
outMsg->body.key.deviceId = msg.body.key.deviceId;
// int32_t source
outMsg->body.key.source = msg.body.key.source;
// int32_t displayId
outMsg->body.key.displayId = msg.body.key.displayId;
// int32_t action
outMsg->body.key.action = msg.body.key.action;
// int32_t flags
outMsg->body.key.flags = msg.body.key.flags;
// int32_t keyCode
outMsg->body.key.keyCode = msg.body.key.keyCode;
// int32_t scanCode
outMsg->body.key.scanCode = msg.body.key.scanCode;
// int32_t metaState
outMsg->body.key.metaState = msg.body.key.metaState;
// int32_t repeatCount
outMsg->body.key.repeatCount = msg.body.key.repeatCount;
// nsecs_t downTime
outMsg->body.key.downTime = msg.body.key.downTime;
break;
}
case InputMessage::TYPE_MOTION: {
// uint32_t seq
outMsg->body.motion.seq = msg.body.motion.seq;
// nsecs_t eventTime
outMsg->body.motion.eventTime = msg.body.motion.eventTime;
// int32_t deviceId
outMsg->body.motion.deviceId = msg.body.motion.deviceId;
// int32_t source
outMsg->body.motion.source = msg.body.motion.source;
// int32_t displayId
outMsg->body.motion.displayId = msg.body.motion.displayId;
// int32_t action
outMsg->body.motion.action = msg.body.motion.action;
// int32_t actionButton
outMsg->body.motion.actionButton = msg.body.motion.actionButton;
// int32_t flags
outMsg->body.motion.flags = msg.body.motion.flags;
// int32_t metaState
outMsg->body.motion.metaState = msg.body.motion.metaState;
// int32_t buttonState
outMsg->body.motion.buttonState = msg.body.motion.buttonState;
// int32_t edgeFlags
outMsg->body.motion.edgeFlags = msg.body.motion.edgeFlags;
// nsecs_t downTime
outMsg->body.motion.downTime = msg.body.motion.downTime;
// float xOffset
outMsg->body.motion.xOffset = msg.body.motion.xOffset;
// float yOffset
outMsg->body.motion.yOffset = msg.body.motion.yOffset;
// float xPrecision
outMsg->body.motion.xPrecision = msg.body.motion.xPrecision;
// float yPrecision
outMsg->body.motion.yPrecision = msg.body.motion.yPrecision;
// uint32_t pointerCount
outMsg->body.motion.pointerCount = msg.body.motion.pointerCount;
//struct Pointer pointers[MAX_POINTERS]
for (size_t i = 0; i < msg.body.motion.pointerCount; i++) {
// PointerProperties properties
outMsg->body.motion.pointers[i].properties.id =
msg.body.motion.pointers[i].properties.id;
outMsg->body.motion.pointers[i].properties.toolType =
msg.body.motion.pointers[i].properties.toolType;
// PointerCoords coords
outMsg->body.motion.pointers[i].coords.bits =
msg.body.motion.pointers[i].coords.bits;
const uint32_t count = BitSet64::count(msg.body.motion.pointers[i].coords.bits);
memcpy(&outMsg->body.motion.pointers[i].coords.values[0],
&msg.body.motion.pointers[i].coords.values[0],
count * sizeof(msg.body.motion.pointers[i].coords.values[0]));
}
break;
}
case InputMessage::TYPE_FINISHED: {
outMsg->body.finished.seq = msg.body.finished.seq;
outMsg->body.finished.handled = msg.body.finished.handled;
break;
}
}
}
/**
 * Sends one deliberately "dirty" message of the given type from server to client and checks
 * that the bytes the client receives match a sanitized copy of that message.
 * Returns true when the received message is clean. Returns false on a send/receive failure,
 * on a type mismatch, or when non-zero bytes leak into padding or unused fields (i.e. a
 * vulnerability is found for this message type).
 */
static bool checkMessage(sp<InputChannel> server, sp<InputChannel> client, int type) {
    // Fill every byte with 1 so that anything the sender fails to scrub stands out.
    InputMessage sent;
    memset(&sent, 1, sizeof(sent));
    sent.header.type = type;
    if (type == InputMessage::TYPE_MOTION) {
        sent.body.motion.pointerCount = MAX_POINTERS;
    }

    if (server->sendMessage(&sent) != OK) {
        ALOGE("Could not send message to the input channel");
        return false;
    }

    InputMessage received;
    if (client->receiveMessage(&received) != OK) {
        ALOGE("Could not receive message from the input channel");
        return false;
    }

    if (received.header.type != sent.header.type) {
        ALOGE("Types do not match");
        return false;
    }
    if (received.header.padding != 0) {
        ALOGE("Found padding to be uninitialized");
        return false;
    }

    // Build what a clean message with the same field values would look like and compare the
    // meaningful prefix of both byte-for-byte.
    InputMessage expected;
    sanitizeMessage(received, &expected);
    const size_t receivedSize = received.size();
    const bool matches = memcmp(&received, &expected, receivedSize) == 0;
    if (!matches) {
        ALOGE("Client received un-sanitized message");
        ALOGE("Received message: %s", memoryAsHexString(&received, receivedSize).c_str());
        ALOGE("Expected message: %s",
              memoryAsHexString(&expected, receivedSize).c_str());
        return false;
    }
    return true;
}
/**
* Create an unsanitized message
* Send
* Receive
* Compare the received message to a sanitized expected message
* Do this for all message types
*/
int main() {
sp<InputChannel> server, client;
status_t result = InputChannel::openInputChannelPair(String8("channel name").string(), server, client);
if (result != OK) {
ALOGE("Could not open input channel pair");
return 0;
}
int types[] = {InputMessage::TYPE_KEY, InputMessage::TYPE_MOTION, InputMessage::TYPE_FINISHED};
for (int type : types) {
bool success = checkMessage(server, client, type);
if (!success) {
ALOGE("Check message failed for type %i", type);
return EXIT_VULNERABLE;
}
}
return 0;
}