blob: bfa56b6a58126f4d38307c972893a1511d968738 [file] [log] [blame]
#pragma once
#include "caffe2/core/logging.h"
#include "caffe2/opt/shape_info.h"
#include "caffe2/proto/caffe2_pb.h"
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
namespace caffe2 {
// This struct stores the max bound size for batch in the general sense. We have
// the conventioal batch size and the look-up sequence, which is also batch in a
// sense.
struct CAFFE2_API BoundShapeSpec {
explicit BoundShapeSpec(int64_t b, int64_t q)
: max_batch_size(b), max_seq_size(q) {}
int64_t max_batch_size;
int64_t max_seq_size;
};
/// \class A class that does bound shape inference given a C2 net. Depending on
/// its type, each op have a maximum shape that it accepts. We define some
/// initial bound for certain dimension, for example max batch size or max
/// sequnce lookup size. And the inference will first infer the input size and
/// then propagates the bound shape down the network. For now the variable part
/// (bound part) is the first dimension of the shape, which usually corresponds
/// to the batch size or sequence lookup size.
class BoundShapeInferencerBase {
public:
explicit BoundShapeInferencerBase(const BoundShapeSpec& spec) : spec_(spec) {
CAFFE_ENFORCE_GE(spec_.max_batch_size, 0);
CAFFE_ENFORCE_GE(spec_.max_seq_size, 0);
}
virtual ~BoundShapeInferencerBase() {}
virtual void InferBoundShapeAndType(
const NetDef& net,
const std::unordered_map<std::string, ShapeInfo>& info,
caffe2::Workspace* ws) = 0;
const ShapeInfoMap& shape_info() const {
return shape_info_;
}
/// Print out all the shape info
std::string PrintShapeInfo() const {
std::stringstream ss;
for (const auto& kv : shape_info_) {
const auto& s = kv.second;
ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
for (const auto d : s.shape.dims()) {
ss << d << ", ";
}
ss << "], dtype: " << s.shape.data_type() << "\n";
}
return ss.str();
}
protected:
const BoundShapeSpec spec_;
std::unordered_map<std::string, ShapeInfo> shape_info_;
};
class CAFFE2_API BoundShapeInferencer : public BoundShapeInferencerBase {
public:
explicit BoundShapeInferencer(const BoundShapeSpec& spec)
: BoundShapeInferencerBase(spec) {}
virtual ~BoundShapeInferencer() override {}
void InferBoundShapeAndType(
const NetDef& net,
const std::unordered_map<std::string, ShapeInfo>& info,
caffe2::Workspace* ws) override;
protected:
TensorShape& CheckAndSetTensorShapeAndType(
const std::string& name,
ShapeInfo::DimType t,
std::vector<int64_t> bound_dims,
TensorProto::DataType type,
bool is_quantized,
bool allow_existing_shape = false);
TensorShape& SetTensorShapeAndTypeIfNotExist(
const std::string& name,
ShapeInfo::DimType t,
std::vector<int64_t> bound_dims,
TensorProto::DataType type,
bool is_quantized);
virtual void InferOps(const OperatorDef& op, caffe2::Workspace* ws);
void InferConcatInputs(const OperatorDef& op);
void InferGivenTensorFill(const OperatorDef& op);
void InferSparseLengthsSum(const OperatorDef& op);
void InferFC(const OperatorDef& op);
void InferConcat(const OperatorDef& op);
void InferShape(const OperatorDef& op);
void InferReshape(const OperatorDef& op);
void InferLengthsRangeFill(const OperatorDef& op);
// Standard shape/type inference using op schema registered shape inference
// function
void InferCommonOp(const OperatorDef& op);
void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) const;
ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::BATCH};
int64_t current_max_batch_size_{0};
};
CAFFE2_API std::shared_ptr<BoundShapeInferencerBase> getBoundShapeInferencer(
const BoundShapeSpec& spec);
C10_DECLARE_SHARED_REGISTRY(
BoundShapeInferencerRegistry,
BoundShapeInferencerBase,
const BoundShapeSpec&);
} // namespace caffe2