#ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#include <istream>
#include <ostream>
#include <random>
#include <string>
#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/utils/thread_pool.h>
#include <caffe2/video/video_io.h>
namespace caffe2 {
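// VideoInputOp reads serialized TensorProtos from a DBReader, decodes each
// video (either from the in-memory buffer or from a local file), applies
// cropping, mirroring and color transforms on a decode thread pool, and
// prefetches the results as clip tensors (RGB and/or optical flow) plus
// label and, optionally, video id tensors.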
template <class Context>
class VideoInputOp final : public PrefetchOperator<Context> {
public:
using OperatorBase::OutputSize;
using PrefetchOperator<Context>::context_;
using PrefetchOperator<Context>::prefetch_thread_;
explicit VideoInputOp(const OperatorDef& operator_def, Workspace* ws);
~VideoInputOp() {
PrefetchOperator<Context>::Finalize();
}
// override methods
bool Prefetch() override;
bool CopyPrefetched() override;
private:
void CheckParamsAndPrint();
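// parses one serialized TensorProtos record, decodes the video it contains
// into per-clip RGB frame buffers, and fills the label and (optional)
// video id outputs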
bool GetClipsAndLabelsFromDBValue(
const std::string& value,
int& height,
int& width,
std::vector<unsigned char*>& buffer_rgb,
int* label_data,
int* video_id_data);
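// decodes one db record and applies the cropping / mirroring / color
// transforms, writing the finished clips into the prefetch buffers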
void DecodeAndTransform(
const std::string& value,
float* clip_rgb_data,
float* clip_of_data,
int* label_data,
int* video_id_data,
std::mt19937* randgen,
std::bernoulli_distribution* mirror_this_clip);
const db::DBReader* reader_;
Tensor prefetched_clip_rgb_;
Tensor prefetched_clip_of_;
Tensor prefetched_label_;
Tensor prefetched_video_id_;
Tensor prefetched_clip_rgb_on_device_{Context::GetDeviceType()};
Tensor prefetched_clip_of_on_device_{Context::GetDeviceType()};
Tensor prefetched_label_on_device_{Context::GetDeviceType()};
Tensor prefetched_video_id_on_device_{Context::GetDeviceType()};
int batch_size_;
int clip_per_video_;
std::vector<float> mean_rgb_;
std::vector<float> inv_std_rgb_;
std::vector<float> mean_of_;
std::vector<float> inv_std_of_;
int channels_rgb_;
int channels_of_;
int crop_height_;
int crop_width_;
int scale_h_;
int scale_w_;
int height_min_;
int width_min_;
int length_rgb_;
int sampling_rate_rgb_;
bool color_jitter_;
float img_saturation_;
float img_brightness_;
float img_contrast_;
bool color_lighting_;
float color_lighting_std_;
std::vector<std::vector<float>> color_lighting_eigvecs_;
std::vector<float> color_lighting_eigvals_;
int num_of_required_frame_;
int length_of_;
int sampling_rate_of_;
int frame_gap_of_;
bool random_mirror_;
int num_of_class_;
bool use_local_file_;
bool random_crop_;
bool multi_crop_;
int multi_crop_count_;
int flow_data_type_;
int flow_alg_type_;
int decode_type_;
int video_res_type_;
bool do_flow_aggregation_;
bool get_rgb_;
bool get_optical_flow_;
bool get_video_id_;
bool do_multi_label_;
// thread pool for parse + decode
int num_decode_threads_;
std::shared_ptr<TaskThreadPool> thread_pool_;
};
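// Illustrative sketch, not part of the original header: assuming the op is
// registered under the name "VideoInput" (as in caffe2's video module), a
// minimal OperatorDef could be assembled like this; the argument names
// mirror the GetSingleArgument calls in the constructor below.
//
//   OperatorDef def;
//   def.set_type("VideoInput");
//   def.add_input("reader");     // blob holding a db::DBReader
//   def.add_output("clip_rgb");  // [N, 3, length_rgb, crop_height, crop_width]
//   def.add_output("label");
//   auto* arg = def.add_arg();
//   arg->set_name("batch_size");
//   arg->set_i(16);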
template <class Context>
void VideoInputOp<Context>::CheckParamsAndPrint() {
// check whether the input parameters are valid or not
CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size should be positive.");
CAFFE_ENFORCE_GT(
clip_per_video_, 0, "Number of clips per video should be positive.");
CAFFE_ENFORCE_GT(crop_height_, 0, "Must provide the cropping height value.");
CAFFE_ENFORCE_GT(crop_width_, 0, "Must provide the cropping width value.");
CAFFE_ENFORCE_GT(
num_of_required_frame_, 0, "Required number of frames must be positive.");
if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
CAFFE_ENFORCE_GT(height_min_, 0, "Must provide the minimal height value.");
CAFFE_ENFORCE_GT(width_min_, 0, "Must provide the minimal width value.");
CAFFE_ENFORCE_GE(
height_min_,
crop_height_,
"The minimal height must be no smaller than the cropping height.");
CAFFE_ENFORCE_GE(
width_min_,
crop_width_,
"The minimal width must be no smaller than the cropping width.");
} else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
CAFFE_ENFORCE_GT(scale_h_, 0, "Must provide the scale height value.");
CAFFE_ENFORCE_GT(scale_w_, 0, "Must provide the scale width value.");
CAFFE_ENFORCE_GE(
scale_h_,
crop_height_,
"The scaled height must be no smaller than the cropping height.");
CAFFE_ENFORCE_GE(
scale_w_,
crop_width_,
"The scaled width must be no smaller than the cropping width.");
}
if (get_rgb_) {
CAFFE_ENFORCE_GT(length_rgb_, 0, "Must provide rgb clip length.");
CAFFE_ENFORCE_GT(
sampling_rate_rgb_, 0, "Sampling rate must be positive; e.g. 4 for mc2, 2 for res3d.");
CAFFE_ENFORCE_EQ(
channels_rgb_, mean_rgb_.size(), "Number of rgb channels is wrong!");
CAFFE_ENFORCE_EQ(
channels_rgb_, inv_std_rgb_.size(), "Number of rgb channels is wrong!");
}
if (get_optical_flow_) {
CAFFE_ENFORCE_GT(length_of_, 0, "Must provide optical flow clip length.");
CAFFE_ENFORCE_GT(
sampling_rate_of_, 0, "Sampling rate must be positive; e.g. 4 for mc2, 2 for res3d.");
CAFFE_ENFORCE_EQ(
channels_of_,
mean_of_.size(),
"Number of optical flow channels is wrong!");
CAFFE_ENFORCE_EQ(
channels_of_,
inv_std_of_.size(),
"Number of optical flow channels is wrong!");
}
if (clip_per_video_ > 1) {
CAFFE_ENFORCE_EQ(
decode_type_,
DecodeType::DO_UNIFORM_SMP,
"Only uniformly sampling is supported when sampling multiple clips!");
}
if (do_multi_label_) {
CAFFE_ENFORCE_GT(
num_of_class_,
0,
"Number of classes must be set when using multiple labels.");
}
// print out the parameter settings
LOG(INFO) << "Creating a clip input op with the following setting: ";
LOG(INFO) << " Using " << num_decode_threads_ << " CPU threads;";
LOG(INFO) << " Outputting in batches of " << batch_size_ << " videos;";
LOG(INFO) << " Each video has " << clip_per_video_ << " clips;";
LOG(INFO) << " Scaling image to " << scale_h_ << "x" << scale_w_;
LOG(INFO) << " (Height, Width) is at least (" << height_min_ << ", "
<< width_min_ << ")";
LOG(INFO) << " Cropping video frame to " << crop_height_ << "x"
<< crop_width_ << (random_mirror_ ? " with " : " without ")
<< "random mirroring;";
LOG(INFO) << " Using " << (random_crop_ ? "random" : "center") << " crop";
LOG(INFO) << " Is multi-cropping enabled: " << multi_crop_;
if (get_rgb_) {
LOG(INFO) << " Using a clip of " << length_rgb_ << " rgb frames "
<< "with " << channels_rgb_ << " channels "
<< "and a sampling rate of 1:" << sampling_rate_rgb_;
LOG(INFO) << " RGB data augmentation. Color jittering: " << color_jitter_
<< ". Color lighting: " << color_lighting_;
for (int i = 0; i < channels_rgb_; i++) {
LOG(INFO) << " RGB " << i << "-th channel mean: " << mean_rgb_[i]
<< " std: " << 1.f / inv_std_rgb_[i];
}
}
if (get_optical_flow_) {
LOG(INFO) << " Using a clip of " << length_of_ << " optical flow frames "
<< "with " << channels_of_ << " channels "
<< "and a sampling rate of 1:" << sampling_rate_of_
<< " flow_data_type_: " << flow_data_type_
<< " flow_alg_type_: " << flow_alg_type_;
for (int i = 0; i < channels_of_; i++) {
LOG(INFO) << " Optical flow" << i
<< "-th channel mean: " << mean_of_[i]
<< " std: " << 1.f / inv_std_of_[i];
}
}
if (video_res_type_ == VideoResType::ORIGINAL_RES) {
LOG(INFO) << " Use original resolution";
} else if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
LOG(INFO) << " Resize with minimal size and keep aspect ratio";
} else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
LOG(INFO) << " Resize and ignore aspect ratio";
} else {
LOG(ERROR) << " Unknown video resolution type";
}
if (decode_type_ == DecodeType::DO_TMP_JITTER) {
LOG(INFO) << " Do temporal jittering";
} else if (decode_type_ == DecodeType::USE_START_FRM) {
LOG(INFO) << " Use start_frm for decoding";
} else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
LOG(INFO) << " Do uniformly sampling";
} else {
LOG(ERROR) << " Unknown video decoding type";
}
}
template <class Context>
VideoInputOp<Context>::VideoInputOp(
const OperatorDef& operator_def,
Workspace* ws)
: PrefetchOperator<Context>(operator_def, ws),
reader_(nullptr),
batch_size_(
OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
clip_per_video_(
OperatorBase::template GetSingleArgument<int>("clip_per_video", 1)),
mean_rgb_(OperatorBase::template GetRepeatedArgument<float>(
"mean_rgb_per_channel",
{OperatorBase::template GetSingleArgument<float>("mean_rgb", 128.)})),
inv_std_rgb_(OperatorBase::template GetRepeatedArgument<float>(
"std_rgb_per_channel",
{OperatorBase::template GetSingleArgument<float>("std_rgb", 1.)})),
mean_of_(OperatorBase::template GetRepeatedArgument<float>(
"mean_of_per_channel",
{OperatorBase::template GetSingleArgument<float>("mean_of", 0.)})),
inv_std_of_(OperatorBase::template GetRepeatedArgument<float>(
"std_of_per_channel",
{OperatorBase::template GetSingleArgument<float>("std_of", 1.)})),
channels_rgb_(
OperatorBase::template GetSingleArgument<int>("channels_rgb", 3)),
channels_of_(
OperatorBase::template GetSingleArgument<int>("channels_of", 2)),
crop_height_(OperatorBase::template GetSingleArgument<int>(
"crop_height",
{OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
crop_width_(OperatorBase::template GetSingleArgument<int>(
"crop_width",
{OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
scale_h_(OperatorBase::template GetSingleArgument<int>("scale_h", 0)),
scale_w_(OperatorBase::template GetSingleArgument<int>("scale_w", 0)),
height_min_(OperatorBase::template GetSingleArgument<int>(
"height_min",
{OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
width_min_(OperatorBase::template GetSingleArgument<int>(
"width_min",
{OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
length_rgb_(
OperatorBase::template GetSingleArgument<int>("length_rgb", 0)),
sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
"sampling_rate_rgb",
1)),
color_jitter_(OperatorBase::template GetSingleArgument<bool>(
"color_jitter",
false)),
img_saturation_(OperatorBase::template GetSingleArgument<float>(
"img_saturation",
0.4)),
img_brightness_(OperatorBase::template GetSingleArgument<float>(
"img_brightness",
0.4)),
img_contrast_(
OperatorBase::template GetSingleArgument<float>("img_contrast", 0.4)),
color_lighting_(OperatorBase::template GetSingleArgument<bool>(
"color_lighting",
false)),
color_lighting_std_(OperatorBase::template GetSingleArgument<float>(
"color_lighting_std",
0.1)),
length_of_(OperatorBase::template GetSingleArgument<int>("length_of", 0)),
sampling_rate_of_(
OperatorBase::template GetSingleArgument<int>("sampling_rate_of", 1)),
frame_gap_of_(
OperatorBase::template GetSingleArgument<int>("frame_gap_of", 1)),
random_mirror_(OperatorBase::template GetSingleArgument<bool>(
"random_mirror",
true)),
num_of_class_(
OperatorBase::template GetSingleArgument<int>("num_of_class", 0)),
use_local_file_(OperatorBase::template GetSingleArgument<bool>(
"use_local_file",
false)),
random_crop_(
OperatorBase::template GetSingleArgument<bool>("random_crop", true)),
multi_crop_(
OperatorBase::template GetSingleArgument<bool>("multi_crop", false)),
flow_data_type_(
OperatorBase::template GetSingleArgument<int>("flow_data_type", 0)),
flow_alg_type_(
OperatorBase::template GetSingleArgument<int>("flow_alg_type", 0)),
decode_type_(
OperatorBase::template GetSingleArgument<int>("decode_type", 0)),
video_res_type_(
OperatorBase::template GetSingleArgument<int>("video_res_type", 0)),
do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
"do_flow_aggregation",
true)),
get_rgb_(OperatorBase::template GetSingleArgument<bool>("get_rgb", true)),
get_optical_flow_(OperatorBase::template GetSingleArgument<bool>(
"get_optical_flow",
false)),
get_video_id_(OperatorBase::template GetSingleArgument<bool>(
"get_video_id",
false)),
do_multi_label_(OperatorBase::template GetSingleArgument<bool>(
"do_multi_label",
false)),
num_decode_threads_(OperatorBase::template GetSingleArgument<int>(
"num_decode_threads",
4)),
thread_pool_(std::make_shared<TaskThreadPool>(num_decode_threads_)) {
// hard-coded PCA eigenvectors and eigenvalues, based on RGB channel order
color_lighting_eigvecs_.push_back(
std::vector<float>{-144.7125, 183.396, 102.2295});
color_lighting_eigvecs_.push_back(
std::vector<float>{-148.104, -1.1475, -207.57});
color_lighting_eigvecs_.push_back(
std::vector<float>{-148.818, -177.174, 107.1765});
color_lighting_eigvals_ = std::vector<float>{0.2175, 0.0188, 0.0045};
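// ClipTransformRGB uses these for AlexNet-style PCA color augmentation:
// each clip's RGB values are presumably shifted by
// sum_i(alpha_i * eigval_i * eigvec_i), with alpha_i ~ N(0, color_lighting_std_)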
// multi-cropping for testing
multi_crop_count_ = 1;
if (multi_crop_) {
// we take left-top, central-top, right-top, left-bottom, central-bottom,
// right-bottom and central-central croppings as well as their mirrorings
// In total, 14 croppings
multi_crop_count_ = 14;
}
num_of_required_frame_ = 0;
// mean and std for normalizing different optical flow data types;
// example statistics generated from SOA are shown below, and you may
// want to change them if you are running on a different dataset;
// 7 channels: (flow_x, flow_y, flow_magnitude, gray, Red, Green, Blue)
const std::vector<float> InputDataMean = {
0.0046635, 0.0046261, 0.963986, 102.976, 110.201, 100.64, 95.9966};
const std::vector<float> InputDataStd = {
0.972347, 0.755146, 1.43588, 55.3691, 58.1489, 56.4701, 55.3324};
// if we need RGB as an input
if (get_rgb_) {
// how many frames we need for RGB
num_of_required_frame_ = std::max(
num_of_required_frame_, (length_rgb_ - 1) * sampling_rate_rgb_ + 1);
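// e.g. length_rgb_ = 8 at sampling rate 1:2 spans (8 - 1) * 2 + 1 = 15
// decoded frames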
channels_rgb_ = 3;
if (mean_rgb_.size() != channels_rgb_ ||
inv_std_rgb_.size() != channels_rgb_) {
mean_rgb_.clear();
inv_std_rgb_.clear();
for (int i = 4; i < 7; i++) {
mean_rgb_.push_back(InputDataMean[i]);
inv_std_rgb_.push_back(1.f / InputDataStd[i]);
}
}
}
// if we need optical flow as an input
if (get_optical_flow_) {
// how many frames we need for optical flow
num_of_required_frame_ = std::max(
num_of_required_frame_,
(length_of_ - 1) * sampling_rate_of_ + frame_gap_of_ + 1);
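// the extra frame_gap_of_ term covers the second frame of the last flow
// pair; e.g. length_of_ = 8, rate 1:2, gap 2 needs (8 - 1) * 2 + 2 + 1 = 17
// decoded frames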
if (mean_of_.size() != channels_of_ || inv_std_of_.size() != channels_of_) {
mean_of_.clear();
inv_std_of_.clear();
// set the parameters for different input data types
switch (flow_data_type_) {
case FlowDataType::Flow2C:
channels_of_ = 2;
for (int i = 0; i < channels_of_; i++) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
break;
case FlowDataType::Flow3C:
channels_of_ = 3;
for (int i = 0; i < channels_of_; i++) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
break;
// early fusion with gray
case FlowDataType::FlowWithGray:
channels_of_ = 3;
for (int i = 0; i < 2; i++) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
mean_of_.push_back(InputDataMean[3]);
inv_std_of_.push_back(1.f / InputDataStd[3]);
break;
// early fusion with RGB
case FlowDataType::FlowWithRGB:
channels_of_ = 5;
for (int i = 0; i < 2; i++) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
for (int i = 4; i < 7; i++) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
break;
default:
LOG(ERROR) << "Unknown optical flow type " << flow_data_type_;
break;
}
}
}
CheckParamsAndPrint();
// Always need a DBReader, even when using local video files
CAFFE_ENFORCE_GT(
operator_def.input_size(), 0, "Need to have a DBReader blob input");
vector<int64_t> data_shape(5);
vector<int64_t> label_shape(2);
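// clip tensors are laid out as [N, C, T, H, W], where
// N = batch_size_ * clip_per_video_ * multi_crop_count_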
// for RGB data
data_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
data_shape[1] = channels_rgb_;
data_shape[2] = length_rgb_;
data_shape[3] = crop_height_;
data_shape[4] = crop_width_;
ReinitializeTensor(&prefetched_clip_rgb_, data_shape, at::dtype<float>().device(CPU));
// for optical flow data
data_shape[1] = channels_of_;
data_shape[2] = length_of_;
ReinitializeTensor(&prefetched_clip_of_, data_shape, at::dtype<float>().device(CPU));
// If do_multi_label is used, the output label is a binary vector
// of length num_of_class indicating which labels are present
if (do_multi_label_) {
label_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
label_shape[1] = num_of_class_;
ReinitializeTensor(&prefetched_label_, label_shape, at::dtype<int>().device(CPU));
} else {
prefetched_label_.Resize(
vector<int64_t>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
}
ReinitializeTensor(&prefetched_video_id_, vector<int64_t>(1, batch_size_ * clip_per_video_ * multi_crop_count_), at::dtype<int>().device(CPU));
}
template <class Context>
bool VideoInputOp<Context>::GetClipsAndLabelsFromDBValue(
const std::string& value,
int& height,
int& width,
std::vector<unsigned char*>& buffer_rgb,
int* label_data,
int* video_id_data) {
TensorProtos protos;
int curr_proto_idx = 0;
CAFFE_ENFORCE(protos.ParseFromString(value));
const TensorProto& video_proto = protos.protos(curr_proto_idx++);
const TensorProto& label_proto = protos.protos(curr_proto_idx++);
int start_frm = 0;
// start_frm is only valid when sampling 1 clip per video without
// temporal jittering
if (decode_type_ == DecodeType::USE_START_FRM) {
CAFFE_ENFORCE_GE(
protos.protos_size(),
curr_proto_idx + 1,
"Start frm proto not provided");
const TensorProto& start_frm_proto = protos.protos(curr_proto_idx++);
start_frm = start_frm_proto.int32_data(0);
}
if (get_video_id_) {
CAFFE_ENFORCE_GE(
protos.protos_size(), curr_proto_idx + 1, "Video Id not provided");
const TensorProto& video_id_proto = protos.protos(curr_proto_idx);
for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
video_id_data[i] = video_id_proto.int64_data(0);
}
}
// assign labels
if (!do_multi_label_) {
for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
label_data[i] = label_proto.int32_data(0);
}
} else {
// For the multiple label case, the output label is a binary vector
// where the present concepts are marked 1
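// e.g. with num_of_class_ = 5 and labels {0, 3}, every clip/crop gets
// the vector {1, 0, 0, 1, 0}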
memset(
label_data,
0,
sizeof(int) * num_of_class_ * multi_crop_count_ * clip_per_video_);
for (int i = 0; i < clip_per_video_; i++) {
for (int j = 0; j < multi_crop_count_; ++j) {
for (int k = 0; k < label_proto.int32_data_size(); k++) {
CAFFE_ENFORCE_LT(
label_proto.int32_data(k),
num_of_class_,
"Label should be less than the number of classes.");
label_data
[(i * multi_crop_count_ + j) * num_of_class_ +
label_proto.int32_data(k)] = 1;
}
}
}
}
if (use_local_file_) {
CAFFE_ENFORCE_EQ(
video_proto.data_type(),
TensorProto::STRING,
"Database with a file_list is expected to be string data");
}
// initializing the decoding params
Params params;
params.maximumOutputFrames_ = MAX_DECODING_FRAMES;
params.video_res_type_ = video_res_type_;
params.crop_height_ = crop_height_;
params.crop_width_ = crop_width_;
params.height_min_ = height_min_;
params.width_min_ = width_min_;
params.scale_w_ = scale_w_;
params.scale_h_ = scale_h_;
params.decode_type_ = decode_type_;
params.num_of_required_frame_ = num_of_required_frame_;
char* video_buffer = nullptr; // for decoding from buffer
std::string video_filename; // for decoding from file
int encoded_size = 0;
if (video_proto.data_type() == TensorProto::STRING) {
const string& encoded_video_str = video_proto.string_data(0);
if (!use_local_file_) {
encoded_size = encoded_video_str.size();
video_buffer = const_cast<char*>(encoded_video_str.data());
} else {
video_filename = encoded_video_str;
}
} else if (video_proto.data_type() == TensorProto::BYTE) {
if (!use_local_file_) {
encoded_size = video_proto.byte_data().size();
video_buffer = const_cast<char*>(video_proto.byte_data().data());
} else {
// TODO: does this work?
video_filename = video_proto.string_data(0);
}
} else {
LOG(FATAL) << "Unknown video data type.";
}
DecodeMultipleClipsFromVideo(
video_buffer,
video_filename,
encoded_size,
params,
start_frm,
clip_per_video_,
use_local_file_,
height,
width,
buffer_rgb);
return true;
}
template <class Context>
void VideoInputOp<Context>::DecodeAndTransform(
const std::string& value,
float* clip_rgb_data,
float* clip_of_data,
int* label_data,
int* video_id_data,
std::mt19937* randgen,
std::bernoulli_distribution* mirror_this_clip) {
std::vector<unsigned char*> buffer_rgb;
// get the video resolution after decoding
int height = 0;
int width = 0;
// Decode the video from memory or read from a local file
CHECK(GetClipsAndLabelsFromDBValue(
value, height, width, buffer_rgb, label_data, video_id_data));
int clip_offset_rgb = multi_crop_count_ * channels_rgb_ * length_rgb_ *
crop_height_ * crop_width_;
int clip_crop_offset_of =
channels_of_ * length_of_ * crop_height_ * crop_width_;
int clip_offset_of = multi_crop_count_ * clip_crop_offset_of;
for (int i = 0; i < std::min(clip_per_video_, int(buffer_rgb.size())); i++) {
// get the rectangle for cropping
int h_off = 0;
int w_off = 0;
if (random_crop_) {
// using random crop for training
h_off =
std::uniform_int_distribution<>(0, height - crop_height_)(*randgen);
w_off = std::uniform_int_distribution<>(0, width - crop_width_)(*randgen);
} else {
// using center crop for testing
h_off = (height - crop_height_) / 2;
w_off = (width - crop_width_) / 2;
}
// Multi cropping: we take left-top, central-top, right-top, left-bottom,
// central-bottom, right-bottom and central-central croppings as well as
// their mirrorings. In total, 14 croppings
int multi_crop_w_off[7] = {0,
(width - crop_width_) / 2,
width - crop_width_,
(width - crop_width_) / 2,
0,
(width - crop_width_) / 2,
width - crop_width_};
int multi_crop_h_off[7] = {0,
0,
0,
(height - crop_height_) / 2,
height - crop_height_,
height - crop_height_,
height - crop_height_};
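// k below picks one of the 7 spatial offsets; in the multi-crop branch,
// crops 0-6 are unmirrored and crops 7-13 reuse the same offsets mirrored
// (j / (multi_crop_count_ / 2) flips mirror_me for the second half)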
// randomly mirror the image or not
bool mirror_me = random_mirror_ && (*mirror_this_clip)(*randgen);
if (get_rgb_ && clip_rgb_data) {
ClipTransformRGB(
buffer_rgb[i],
multi_crop_count_,
crop_height_,
crop_width_,
length_rgb_,
channels_rgb_,
sampling_rate_rgb_,
height,
width,
h_off,
w_off,
multi_crop_h_off,
multi_crop_w_off,
mirror_me,
color_jitter_,
img_saturation_,
img_brightness_,
img_contrast_,
color_lighting_,
color_lighting_std_,
color_lighting_eigvecs_,
color_lighting_eigvals_,
mean_rgb_,
inv_std_rgb_,
randgen,
clip_rgb_data + (i * clip_offset_rgb));
}
if (get_optical_flow_ && clip_of_data) {
cv::Rect rect;
for (int j = 0; j < multi_crop_count_; ++j) {
if (multi_crop_count_ == 1) {
rect = cv::Rect(w_off, h_off, crop_width_, crop_height_);
} else {
mirror_me = j / (multi_crop_count_ / 2);
int k = j % (multi_crop_count_ / 2);
rect = cv::Rect(
multi_crop_w_off[k],
multi_crop_h_off[k],
crop_width_,
crop_height_);
}
ClipTransformOpticalFlow(
buffer_rgb[i],
crop_height_,
crop_width_,
length_of_,
channels_of_,
sampling_rate_of_,
height,
width,
rect,
channels_rgb_,
mirror_me,
flow_alg_type_,
flow_data_type_,
frame_gap_of_,
do_flow_aggregation_,
mean_of_,
inv_std_of_,
clip_of_data + (i * clip_offset_of) + j * clip_crop_offset_of);
}
}
}
// release the decoded frame buffers
for (unsigned char* buff : buffer_rgb) {
delete[] buff;
}
buffer_rgb.clear();
}
template <class Context>
bool VideoInputOp<Context>::Prefetch() {
// Get the DBReader from the input blob.
// When using local video files, the db stores the file list.
reader_ = &OperatorBase::Input<db::DBReader>(0);
// Call mutable_data() once to allocate the underlying memory.
prefetched_clip_rgb_.mutable_data<float>();
prefetched_clip_of_.mutable_data<float>();
prefetched_label_.mutable_data<int>();
prefetched_video_id_.mutable_data<int>();
// Prefetching is handled by a thread pool of num_decode_threads_ threads.
std::mt19937 meta_randgen(time(nullptr));
std::vector<std::mt19937> randgen_per_thread;
for (int i = 0; i < num_decode_threads_; ++i) {
randgen_per_thread.emplace_back(meta_randgen());
}
std::bernoulli_distribution mirror_this_clip(0.5);
for (int item_id = 0; item_id < batch_size_; ++item_id) {
std::mt19937* randgen = &randgen_per_thread[item_id % num_decode_threads_];
int frame_size = crop_height_ * crop_width_;
// get the clip data pointer for the item_id -th example
float* clip_rgb_data = prefetched_clip_rgb_.mutable_data<float>() +
frame_size * length_rgb_ * channels_rgb_ * item_id * clip_per_video_ *
multi_crop_count_;
// get the optical flow data for the current clip
float* clip_of_data = prefetched_clip_of_.mutable_data<float>() +
frame_size * length_of_ * channels_of_ * item_id * clip_per_video_ *
multi_crop_count_;
// get the label data pointer for the item_id -th example
int* label_data = prefetched_label_.mutable_data<int>() +
(do_multi_label_ ? num_of_class_ : 1) * item_id * clip_per_video_ *
multi_crop_count_;
// get the video id data pointer for the item_id -th example
int* video_id_data = prefetched_video_id_.mutable_data<int>() +
item_id * clip_per_video_ * multi_crop_count_;
std::string key, value;
// read data
reader_->Read(&key, &value);
thread_pool_->run(std::bind(
&VideoInputOp<Context>::DecodeAndTransform,
this,
std::string(value),
clip_rgb_data,
clip_of_data,
label_data,
video_id_data,
randgen,
&mirror_this_clip));
} // for over the batch
thread_pool_->waitWorkComplete();
// If the context is not CPUContext, we will need to do a copy in the
// prefetch function as well.
if (!std::is_same<Context, CPUContext>::value) {
if (get_rgb_) {
prefetched_clip_rgb_on_device_.CopyFrom(
prefetched_clip_rgb_, true /*async*/);
}
if (get_optical_flow_) {
prefetched_clip_of_on_device_.CopyFrom(
prefetched_clip_of_, true /*async*/);
}
prefetched_label_on_device_.CopyFrom(prefetched_label_, true /*async*/);
if (get_video_id_) {
prefetched_video_id_on_device_.CopyFrom(
prefetched_video_id_, true /*async*/);
}
}
return true;
}
template <class Context>
bool VideoInputOp<Context>::CopyPrefetched() {
int index = 0;
if (get_rgb_) {
auto* clip_rgb_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
clip_rgb_output->CopyFrom(prefetched_clip_rgb_, true /*async*/);
} else {
clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, true /*async*/);
}
}
if (get_optical_flow_) {
auto* clip_of_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
clip_of_output->CopyFrom(prefetched_clip_of_, true /*async*/);
} else {
clip_of_output->CopyFrom(prefetched_clip_of_on_device_, true /*async*/);
}
}
auto* label_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
label_output->CopyFrom(prefetched_label_, true /*async*/);
} else {
label_output->CopyFrom(prefetched_label_on_device_, true /*async*/);
}
if (get_video_id_) {
auto* video_id_output =
OperatorBase::Output<Tensor>(index, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
video_id_output->CopyFrom(prefetched_video_id_, true /*async*/);
} else {
video_id_output->CopyFrom(prefetched_video_id_on_device_, true /*async*/);
}
}
return true;
}
} // namespace caffe2
#endif // CAFFE2_VIDEO_VIDEO_INPUT_OP_H_