// blob: d1813efa437618a91c07188553e12495947fc859 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
#include <algorithm>
namespace ui {
#if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
// Builds a PalmFilterDeviceInfo from the multi-touch axis data exposed by
// |devinfo|: the maximum of the X/Y position axes and the resolution of the
// position and touch-major/minor axes. Every resolution that reads back as 0
// is replaced by 1.
PalmFilterDeviceInfo CreatePalmFilterDeviceInfo(
    const EventDeviceInfo& devinfo) {
  PalmFilterDeviceInfo info;

  // X position axis.
  info.max_x = devinfo.GetAbsMaximum(ABS_MT_POSITION_X);
  info.x_res = devinfo.GetAbsResolution(ABS_MT_POSITION_X);
  if (info.x_res == 0) {
    info.x_res = 1;
  }

  // Y position axis.
  info.max_y = devinfo.GetAbsMaximum(ABS_MT_POSITION_Y);
  info.y_res = devinfo.GetAbsResolution(ABS_MT_POSITION_Y);
  if (info.y_res == 0) {
    info.y_res = 1;
  }

  // Touch-major axis. A resolution of 0 is treated as "not reported".
  info.major_radius_res = devinfo.GetAbsResolution(ABS_MT_TOUCH_MAJOR);
  if (info.major_radius_res == 0) {
    info.major_radius_res = 1;
  }

  // Touch-minor axis. When the device has no such axis, the major-axis
  // resolution is reused.
  if (!devinfo.HasAbsEvent(ABS_MT_TOUCH_MINOR)) {
    info.minor_radius_supported = false;
    info.minor_radius_res = info.major_radius_res;
  } else {
    info.minor_radius_supported = true;
    info.minor_radius_res = devinfo.GetAbsResolution(ABS_MT_TOUCH_MINOR);
  }
  if (info.minor_radius_res == 0) {
    info.minor_radius_res = 1;
  }

  return info;
}
#endif
namespace {
// Applies the radius-resize polynomial from |model_config| to |radius|.
// radius_polynomial_resize is read as coefficients ordered from the highest
// power down to the constant term. An empty coefficient list returns |radius|
// unchanged.
float ScaledRadius(
    float radius,
    const NeuralStylusPalmDetectionFilterModelConfig& model_config) {
  const auto& coefficients = model_config.radius_polynomial_resize;
  if (coefficients.empty()) {
    return radius;
  }

  float scaled = 0.0f;
  // Counts down so that the k-th coefficient (0-based) out of n is paired with
  // the exponent n - 1 - k.
  auto terms_left = coefficients.size();
  for (const auto& coefficient : coefficients) {
    --terms_left;
    const float exponent = terms_left;
    scaled += coefficient * powf(radius, exponent);
  }
  return scaled;
}
// Linear interpolation: returns |start_value| when |proportion| is 0 and moves
// toward |end_value| as |proportion| approaches 1. |proportion| is not clamped,
// so values outside [0, 1] extrapolate.
float interpolate(float start_value, float end_value, float proportion) {
  const float span = end_value - start_value;
  return start_value + span * proportion;
}
/**
 * Threshold used by getSampleAtTime() when choosing which real event supplies
 * the fields that are not interpolated (such as major and minor). Normally the
 * later event is used; if the requested time is less than this delay after the
 * earlier event, the earlier event is used instead.
 */
const static auto kPreferInitialEventDelay =
base::TimeDelta::FromMicroseconds(1);
/**
 * Builds a resampled PalmFilterSample for the timestamp 'time' from the two
 * real samples 'before' and 'after'.
 *
 * Only edge, point.x and point.y are linearly interpolated. All other fields
 * are copied from 'after', or from 'before' when 'time' is less than
 * kPreferInitialEventDelay past 'before.time'. The returned sample's time is
 * set to 'time'.
 *
 * NOTE(review): the interpolation weight divides by (after.time - before.time),
 * so this presumably expects after.time > before.time -- confirm at call sites.
 */
PalmFilterSample getSampleAtTime(base::TimeTicks time,
                                 const PalmFilterSample& before,
                                 const PalmFilterSample& after) {
  const auto elapsed = time - before.time;
  const auto span = after.time - before.time;

  // Pick the sample that provides the non-interpolated fields: the newest one,
  // unless the requested time is very close to the 'before' sample.
  PalmFilterSample result =
      (elapsed < kPreferInitialEventDelay) ? before : after;

  // Fraction of the way from 'before' to 'after'. The oval size and
  // orientation could also be interpolated, but that is not a simple
  // computation and would likely not provide much value.
  const float proportion =
      static_cast<float>(elapsed.InNanoseconds()) / span.InNanoseconds();

  result.time = time;
  result.edge = interpolate(before.edge, after.edge, proportion);
  result.point.set_x(
      interpolate(before.point.x(), after.point.x(), proportion));
  result.point.set_y(
      interpolate(before.point.y(), after.point.y(), proportion));
  return result;
}
} // namespace
// Converts one in-progress evdev touch into a PalmFilterSample stamped with
// |time|. Positions and the nearest-edge distance are divided by the position
// resolutions in |dev_info|; radii are divided by the radius resolutions and
// then passed through ScaledRadius() with |model_config|.
PalmFilterSample CreatePalmFilterSample(
    const InProgressTouchEvdev& touch,
    const base::TimeTicks& time,
    const NeuralStylusPalmDetectionFilterModelConfig& model_config,
    const PalmFilterDeviceInfo& dev_info) {
  PalmFilterSample sample;
  sample.time = time;
  sample.tracking_id = touch.tracking_id;
  sample.pressure = touch.pressure;

  // Radii. Original author's note: radius_x and radius_y have been scaled by
  // resolution already (hence touch.major / touch.minor are used here).
  // The larger of the two reported axes is always treated as the major radius.
  sample.major_radius = ScaledRadius(
      std::max(touch.major, touch.minor) / dev_info.major_radius_res,
      model_config);
  if (!dev_info.minor_radius_supported) {
    // NOTE(review): unlike the other radius computations, touch.major is not
    // divided by a resolution here -- confirm this is intentional.
    sample.minor_radius = ScaledRadius(touch.major, model_config);
  } else {
    sample.minor_radius = ScaledRadius(
        std::min(touch.major, touch.minor) / dev_info.minor_radius_res,
        model_config);
  }

  // Position, divided by the per-axis resolution.
  sample.point =
      gfx::PointF(touch.x / dev_info.x_res, touch.y / dev_info.y_res);

  // Distance to the closest edge on each axis, then resolution-normalized.
  // The original comment describes the result as the nearest edge distance,
  // in mm.
  float x_gap = std::min(touch.x, dev_info.max_x - touch.x);
  float y_gap = std::min(touch.y, dev_info.max_y - touch.y);
  float x_gap_normalized = x_gap / dev_info.x_res;
  float y_gap_normalized = y_gap / dev_info.y_res;
  sample.edge = std::min(x_gap_normalized, y_gap_normalized);

  return sample;
}
// Creates an empty stroke for |tracking_id|. The sample-count limit and the
// optional resample period are copied from |model_config|.
PalmFilterStroke::PalmFilterStroke(
    const NeuralStylusPalmDetectionFilterModelConfig& model_config,
    int tracking_id)
    : tracking_id_(tracking_id),
      max_sample_count_(model_config.max_sample_count),
      resample_period_(model_config.resample_period) {
}
// Copy and move construction use the compiler-generated member-wise versions.
PalmFilterStroke::PalmFilterStroke(const PalmFilterStroke& other) = default;
PalmFilterStroke::PalmFilterStroke(PalmFilterStroke&& other) = default;

// The destructor has no work of its own to do.
PalmFilterStroke::~PalmFilterStroke() = default;
// Feeds one new sample into the stroke. |sample| must carry this stroke's
// tracking id. With a resample period configured the sample is handed to
// Resample(); otherwise it is stored as-is and the oldest samples are evicted
// until at most max_sample_count_ remain.
void PalmFilterStroke::ProcessSample(const PalmFilterSample& sample) {
  DCHECK_EQ(tracking_id_, sample.tracking_id);

  if (!resample_period_.has_value()) {
    AddSample(sample);
    while (samples_.size() > max_sample_count_) {
      // Remove the evicted sample's contribution from the centroid sum before
      // dropping it.
      AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
      samples_.pop_front();
    }
    return;
  }

  Resample(sample);
}
// Appends |sample| to samples_, adds its position to the running centroid sum
// and bumps the count of samples seen.
void PalmFilterStroke::AddSample(const PalmFilterSample& sample) {
  AddToUnscaledCentroid(sample.point.OffsetFromOrigin());
  samples_.push_back(sample);
  ++samples_seen_;
}
/**
* When resampling is enabled, we don't store all samples. Only the resampled
* values are stored into samples_. In addition, the last real event is stored
* into last_sample_, which is used to calculate the resampled values.
*/
void PalmFilterStroke::Resample(const PalmFilterSample& sample) {
if (samples_seen_ == 0) {
AddSample(sample);
last_sample_ = sample;
return;
}
// We already have a valid last sample here.
DCHECK_LE(last_sample_.time, sample.time);
// Generate resampled values
base::TimeTicks next_sample_time = samples_.back().time + *resample_period_;
while (next_sample_time <= sample.time) {
AddSample(getSampleAtTime(next_sample_time, last_sample_, sample));
next_sample_time = samples_.back().time + (*resample_period_);
}
last_sample_ = sample;
// Prune the resampled collection
while ((samples_.back().time - samples_.front().time) >=
(*resample_period_) * max_sample_count_) {
AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
samples_.pop_front();
}
}
// Adds |point| to the running (unscaled) centroid sum using compensated
// (Kahan) summation: unscaled_centroid_sum_error_ holds the rounding error of
// the previous addition and is subtracted from the next input. Callers pass a
// negated offset to remove a sample's contribution.
void PalmFilterStroke::AddToUnscaledCentroid(const gfx::Vector2dF point) {
  // Input with the previously accumulated error taken out.
  const gfx::Vector2dF compensated = point - unscaled_centroid_sum_error_;
  const gfx::PointF updated_sum = unscaled_centroid_ + compensated;
  // What was actually added minus what should have been added.
  unscaled_centroid_sum_error_ =
      (updated_sum - unscaled_centroid_) - compensated;
  unscaled_centroid_ = updated_sum;
}
// Returns the mean position of the stored samples: the running position sum
// divided by the number of samples. An empty stroke yields (0, 0).
gfx::PointF PalmFilterStroke::GetCentroid() const {
  const auto count = samples_.size();
  if (count == 0) {
    return gfx::PointF(0., 0.);
  }
  return gfx::ScalePoint(unscaled_centroid_, 1.f / count);
}
// Read-only access to the samples currently stored for this stroke.
const std::deque<PalmFilterSample>& PalmFilterStroke::samples() const {
  return samples_;
}
// Tracking id this stroke was created with.
int PalmFilterStroke::tracking_id() const {
  return tracking_id_;
}
// Number of samples added via AddSample() so far, including ones that have
// since been evicted from samples_.
uint64_t PalmFilterStroke::samples_seen() const {
  return samples_seen_;
}
float PalmFilterStroke::MaxMajorRadius() const {
float maximum = 0.0;
for (const auto& sample : samples_) {
maximum = std::max(maximum, sample.major_radius);
}
return maximum;
}
// Largest touch "size" among the stored samples, where size is
// major_radius * minor_radius. A sample whose minor_radius is not positive
// uses major_radius for both factors. The result is never below 0, which is
// also what an empty stroke returns.
float PalmFilterStroke::BiggestSize() const {
  float largest = 0;
  for (const auto& sample : samples_) {
    const float second_factor =
        (sample.minor_radius <= 0) ? sample.major_radius : sample.minor_radius;
    const float size = sample.major_radius * second_factor;
    if (largest < size) {
      largest = size;
    }
  }
  return largest;
}
// Returns a copy of |str| with |prefix| inserted at the start of every line.
// A line starts at the beginning of |str| and after each '\n'. The prefix is
// only emitted when at least one character follows, so an empty |str| stays
// empty and a trailing '\n' is not followed by a dangling prefix.
//
// |str| is taken by const reference so lvalue arguments are not copied, and
// the result is built directly in a std::string rather than through a
// stringstream.
static std::string addLinePrefix(const std::string& str,
                                 const std::string& prefix) {
  std::string result;
  // Lower bound on the final size; avoids the first few reallocations.
  result.reserve(str.size() + prefix.size());

  bool at_line_start = true;
  for (const char ch : str) {
    if (at_line_start) {
      result += prefix;
    }
    result += ch;
    at_line_start = (ch == '\n');
  }
  return result;
}
// Streams |point| as "PointF(x, y)".
std::ostream& operator<<(std::ostream& out, const gfx::PointF& point) {
  return out << "PointF(" << point.x() << ", " << point.y() << ")";
}
// Streams |vec| as "Vector2dF(x, y)".
std::ostream& operator<<(std::ostream& out, const gfx::Vector2dF& vec) {
  return out << "Vector2dF(" << vec.x() << ", " << vec.y() << ")";
}
// Streams every field of |info| on one line as
// "PalmFilterDeviceInfo(name=value, ...)".
std::ostream& operator<<(std::ostream& out, const PalmFilterDeviceInfo& info) {
  return out << "PalmFilterDeviceInfo(max_x=" << info.max_x
             << ", max_y=" << info.max_y
             << ", x_res=" << info.x_res
             << ", y_res=" << info.y_res
             << ", major_radius_res=" << info.major_radius_res
             << ", minor_radius_res=" << info.minor_radius_res
             << ", minor_radius_supported=" << info.minor_radius_supported
             << ")";
}
// Streams the fields of |sample| on one line as
// "PalmFilterSample(name=value, ...)".
std::ostream& operator<<(std::ostream& out, const PalmFilterSample& sample) {
  out << "PalmFilterSample(major=" << sample.major_radius;
  out << ", minor=" << sample.minor_radius;
  out << ", pressure=" << sample.pressure;
  out << ", edge=" << sample.edge;
  out << ", tracking_id=" << sample.tracking_id;
  out << ", point=" << sample.point;
  out << ", time=" << sample.time;
  out << ")";
  return out;
}
// Streams a multi-line debug dump of |stroke|: its derived values (centroid,
// biggest size, max major radius), the stored samples with each line
// prefixed, and the internal bookkeeping members.
std::ostream& operator<<(std::ostream& out, const PalmFilterStroke& stroke) {
  out << "PalmFilterStroke(\n"
      << " GetCentroid() = " << stroke.GetCentroid() << "\n"
      << " BiggestSize() = " << stroke.BiggestSize() << "\n"
      << " MaxMajorRadius() = " << stroke.MaxMajorRadius() << "\n";

  // Render the samples separately so every line of their output can be
  // prefixed.
  std::stringstream samples_text;
  samples_text << stroke.samples();
  out << " samples (" << stroke.samples().size() << " total): \n";
  out << addLinePrefix(samples_text.str(), " ") << "\n";

  out << " samples_seen() = " << stroke.samples_seen() << "\n"
      << " tracking_id() = " << stroke.tracking_id() << "\n"
      << " max_sample_count_ = " << stroke.max_sample_count_ << "\n";

  // last_sample_ is only printed when resampling is on.
  if (!stroke.resample_period_) {
    out << " resample_period_ = <not set>\n"
        << " last_sample_ = <not valid b/c resampling is off>\n";
  } else {
    out << " resample_period_ = " << *(stroke.resample_period_) << "\n"
        << " last_sample_ = " << stroke.last_sample_ << "\n";
  }

  out << " unscaled_centroid_ = " << stroke.unscaled_centroid_ << "\n"
      << " unscaled_centroid_sum_error_ = "
      << stroke.unscaled_centroid_sum_error_ << "\n"
      << ")\n";
  return out;
}
} // namespace ui