// blob: f0e002ca119e737e3d3df9393d07a58cddbfc194
#pragma once
#include <bitset>
#include <cstdint>
#include <initializer_list>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace monitor {
// NUM_AGGREGATIONS is the number of bit positions needed to store a set of
// Aggregations in a std::bitset. It must stay one greater than the largest
// Aggregation enumerator; the static_assert after the enum enforces this.
constexpr int NUM_AGGREGATIONS = 7;
// Aggregation is the list of possible aggregations for Stats.
// Each enumerator is a bit *position* (not a bit mask), so a set of
// aggregations can be stored compactly in a std::bitset<NUM_AGGREGATIONS>.
enum Aggregation {
  // NONE means no aggregations are set.
  NONE = 0,
  // VALUE exports the most recently set value.
  VALUE = 1,
  // MEAN computes the mean of the set values within the window. Zero if no
  // values.
  MEAN = 2,
  // COUNT tracks the number of times a value is set within the window.
  COUNT = 3,
  // SUM computes the sum of the values set within the window.
  SUM = 4,
  // MAX computes the maximum of the values set within the window. Zero if no
  // values.
  MAX = 5,
  // MIN computes the minimum of the values set within the window. Zero if no
  // values.
  MIN = 6,
};
// Adding an enumerator past MIN without bumping NUM_AGGREGATIONS would make
// std::bitset::set throw std::out_of_range at runtime; catch it at compile time.
static_assert(
    MIN + 1 == NUM_AGGREGATIONS,
    "NUM_AGGREGATIONS must be one greater than the largest Aggregation");
// aggregationName returns a C string for `agg`. It is defined out of line.
// NOTE(review): presumably this is the printable name of the aggregation; the
// exact strings and the handling of NONE/out-of-range values are not visible
// here. Confirm against the implementation.
const char* aggregationName(Aggregation agg);
// Forward declaration so the detail::registerStat/unregisterStat declarations
// below can name Stat<double>* and Stat<int64_t>* before Stat is defined.
template <typename T>
class Stat;
namespace {
inline std::bitset<NUM_AGGREGATIONS> merge(
std::initializer_list<Aggregation>& list) {
std::bitset<NUM_AGGREGATIONS> a;
for (Aggregation b : list) {
a.set(b);
}
return a;
}
} // namespace
// Registration hooks called from Stat's constructor and destructor. They are
// defined out of line. Overloads exist only for Stat<double> and
// Stat<int64_t>, so instantiating Stat with any other T fails to compile when
// its constructor or destructor is instantiated.
namespace detail {
void registerStat(Stat<double>* stat);
void registerStat(Stat<int64_t>* stat);
void unregisterStat(Stat<double>* stat);
void unregisterStat(Stat<int64_t>* stat);
} // namespace detail
// Stat accumulates the values passed to add() into a window and, once the
// window is closed, reports the configured aggregations for it via get().
// All mutable state is guarded by an internal mutex.
template <typename T>
class Stat {
 private:
  // Values holds the running aggregates for a single window. Every field
  // starts at zero, which is what get() reports for a window that received no
  // samples.
  struct Values {
    T value{0};
    T sum{0};
    T min{0};
    T max{0};
    int64_t count{0};
  };

 public:
  // Constructs a Stat named `name` that tracks `aggregations`, and registers
  // it with detail::registerStat.
  //
  // windowSize: if > 0, the current window closes automatically once it has
  // received windowSize samples. Otherwise windows close only when
  // closeWindow() is called.
  Stat(
      std::string name,
      std::initializer_list<Aggregation> aggregations,
      int64_t windowSize = -1)
      : name_(std::move(name)),
        aggregations_(merge(aggregations)),
        windowSize_(windowSize) {
    detail::registerStat(this);
  }

  // Rule of five: the constructor passes `this` to detail::registerStat and
  // the destructor passes it to detail::unregisterStat, so a Stat's identity
  // is its address and copying or moving one is not meaningful. The
  // std::mutex member already makes these implicitly deleted; declaring them
  // explicitly documents the intent.
  Stat(const Stat&) = delete;
  Stat& operator=(const Stat&) = delete;
  Stat(Stat&&) = delete;
  Stat& operator=(Stat&&) = delete;

  ~Stat() {
    detail::unregisterStat(this);
  }

  // add adds the value v to the current window. If windowSize() > 0 and the
  // window now holds windowSize() samples, the window is closed so get()
  // reports it.
  void add(T v) noexcept {
    std::lock_guard<std::mutex> guard(mu_);
    if (aggregations_.test(VALUE)) {
      current_.value = v;
    }
    // MEAN is derived from the running sum when get() is called.
    if (aggregations_.test(MEAN) || aggregations_.test(SUM)) {
      current_.sum += v;
    }
    // count == 0 means this is the window's first sample. It seeds min/max
    // instead of being compared against the zero-initialized placeholder.
    if (aggregations_.test(MAX)) {
      if (current_.max < v || current_.count == 0) {
        current_.max = v;
      }
    }
    if (aggregations_.test(MIN)) {
      if (current_.min > v || current_.count == 0) {
        current_.min = v;
      }
    }
    // count is always maintained. It drives MEAN, min/max seeding, and the
    // automatic window closing below.
    current_.count += 1;
    if (windowSize_ > 0 && current_.count >= windowSize_) {
      saveCurrentLocked();
    }
  }

  const std::string& name() const noexcept {
    return name_;
  }

  int64_t windowSize() const noexcept {
    return windowSize_;
  }

  // count returns the number of items in the current open window.
  int64_t count() noexcept {
    std::lock_guard<std::mutex> guard(mu_);
    return current_.count;
  }

  // closeWindow finalizes the collected stats window so they can be accessed
  // via get().
  // If the Stat has a windowSize specified this doesn't do anything since the
  // window is automatically closed when enough samples have been logged.
  void closeWindow() noexcept {
    if (windowSize_ <= 0) {
      std::lock_guard<std::mutex> guard(mu_);
      saveCurrentLocked();
    }
  }

  // get returns the enabled aggregations of the most recently closed window,
  // in the order VALUE, MEAN, COUNT, SUM, MAX, MIN. MEAN is zero when that
  // window had no samples; for an integral T it uses integer division.
  std::vector<std::pair<Aggregation, T>> get() noexcept {
    std::vector<std::pair<Aggregation, T>> out;
    out.reserve(aggregations_.count());
    // Copy the closed window under the lock, then build the result without
    // holding it, so concurrent add() calls are not blocked on vector work.
    Values prev;
    {
      std::lock_guard<std::mutex> guard(mu_);
      prev = prev_;
    }
    if (aggregations_.test(VALUE)) {
      out.emplace_back(VALUE, prev.value);
    }
    if (aggregations_.test(MEAN)) {
      if (prev.count == 0) {
        out.emplace_back(MEAN, 0);
      } else {
        out.emplace_back(MEAN, prev.sum / prev.count);
      }
    }
    if (aggregations_.test(COUNT)) {
      out.emplace_back(COUNT, prev.count);
    }
    if (aggregations_.test(SUM)) {
      out.emplace_back(SUM, prev.sum);
    }
    if (aggregations_.test(MAX)) {
      out.emplace_back(MAX, prev.max);
    }
    if (aggregations_.test(MIN)) {
      out.emplace_back(MIN, prev.min);
    }
    return out;
  }

 private:
  // Publishes the current window as the one get() reports and starts a fresh,
  // zeroed window. The caller must hold mu_.
  void saveCurrentLocked() {
    prev_ = current_;
    current_ = Values();
  }

  const std::string name_;
  // Bit i is set when Aggregation i was requested at construction.
  const std::bitset<NUM_AGGREGATIONS> aggregations_;
  const int64_t windowSize_;
  // Guards current_ and prev_.
  std::mutex mu_;
  // Window currently receiving samples.
  Values current_;
  // Most recently closed window, reported by get().
  Values prev_;
};
// closeAndGetStats returns a pair of maps keyed by string: the first holds
// double values and the second holds int64_t values. It is defined out of
// line.
// NOTE(review): from the name, this presumably closes the window of every
// registered Stat and reports its aggregations keyed by stat (and perhaps
// aggregation) name. Confirm the key format against the implementation.
std::pair<
std::unordered_map<std::string, double>,
std::unordered_map<std::string, int64_t>>
closeAndGetStats() noexcept;
} // namespace monitor
} // namespace torch