// blob: 7dd19f82d3f1603542f5e423257b0dc2c64bf704 [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CAFFE2_UTILS_MKL_OPERATOR_H_
#define CAFFE2_UTILS_MKL_OPERATOR_H_
#include "caffe2/core/operator.h"
#include "caffe2/mkl/utils/mkl_dnn_cppwrapper.h"
#include "caffe2/mkl/utils/mkl_memory.h"
#include "caffe2/proto/caffe2.pb.h"
namespace caffe2 {
// Declares MKLOperatorRegistry, the registry that the REGISTER_MKL_OPERATOR*
// macros below add entries to. Judging by the arguments, entries are
// OperatorBase objects created from (const OperatorDef&, Workspace*).
// NOTE(review): the matching registry definition is not in this header;
// presumably it lives in a source file -- confirm.
CAFFE_DECLARE_REGISTRY(
MKLOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
// Registers a creator under `key` in MKLOperatorRegistry. The remaining
// arguments are forwarded unchanged to CAFFE_REGISTER_CREATOR.
#define REGISTER_MKL_OPERATOR_CREATOR(key, ...) \
CAFFE_REGISTER_CREATOR(MKLOperatorRegistry, key, __VA_ARGS__)
// Registers a class under the identifier `name` in MKLOperatorRegistry. The
// remaining arguments are forwarded unchanged to CAFFE_REGISTER_CLASS.
#define REGISTER_MKL_OPERATOR(name, ...) \
CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name, __VA_ARGS__)
// Same as REGISTER_MKL_OPERATOR, but goes through CAFFE_REGISTER_TYPED_CLASS
// with `str_name` as the key.
// NOTE(review): presumably `str_name` is a string expression rather than a
// bare identifier -- confirm against CAFFE_REGISTER_TYPED_CLASS.
#define REGISTER_MKL_OPERATOR_STR(str_name, ...) \
CAFFE_REGISTER_TYPED_CLASS(MKLOperatorRegistry, str_name, __VA_ARGS__)
// Registers a class under the token-pasted key `<name>_ENGINE_<engine>` in
// MKLOperatorRegistry, so one operator name can have per-engine entries.
#define REGISTER_MKL_OPERATOR_WITH_ENGINE(name, engine, ...) \
CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
namespace mkl {
// MKLOperator is the base scaffolding for operators that use MKLDNN. It
// provides typed Input()/Output() accessors for MKLMemory<T> blobs, event
// handling that is routed through an MKLContext, and a handful of members
// (primitive, size cache, scratch buffer, resource table) that MKLDNN-specific
// implementations commonly need. Subclasses implement RunOnDevice().
template <typename T>
class MKLOperator : public OperatorBase {
 public:
  // Builds the operator from its definition and workspace; the MKL context is
  // constructed from the definition's device option.
  explicit MKLOperator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        context_(operator_def.device_option()) {}
  virtual ~MKLOperator() {}

  // Returns the idx-th input, fetched as an MKLMemory<T>.
  inline const MKLMemory<T>& Input(int idx) {
    return OperatorBase::template Input<MKLMemory<T>>(idx);
  }

  // Returns the idx-th output, fetched as an MKLMemory<T>.
  inline MKLMemory<T>* Output(int idx) {
    return OperatorBase::template Output<MKLMemory<T>>(idx);
  }

  // Runs the operator. You should implement RunOnDevice() instead of Run().
  //
  // Since MKLDNN does not need to do SwitchToDevice and
  // FinishDeviceComputation, this is always just a re-route to RunOnDevice(),
  // plus event bookkeeping:
  //   - on success the operator's event (if any) is marked finished;
  //   - on a false return the event is recorded with the operator error
  //     message;
  //   - on EnforceNotMet the operator error message is appended to the
  //     exception, the event is recorded with the resulting text, and the
  //     exception is rethrown.
  // The stream id argument is unused.
  bool Run(int /* unused */ /*stream_id*/) final {
    try {
      const bool succeeded = RunOnDevice();
      if (succeeded) {
        // RecordEvent() treats event_ as possibly null, so guard the success
        // path the same way: a run that succeeded must not fail merely
        // because there is no event to mark as finished.
        if (event_) {
          event().SetFinished();
        }
      } else {
        RecordEvent(getErrorMsg().c_str());
      }
      return succeeded;
    } catch (EnforceNotMet& err) {
      err.AppendMessage(getErrorMsg());
      RecordEvent(err.what());
      throw;
    }
  }

  // Waits for a previous event. Note that to properly wait and run
  // asynchronously, WaitEvent, RunAsync and Record should all be executed
  // on the same CPU thread.
  void WaitEvent(const Event& ev, int /* unused */) final {
    context_.WaitEvent(ev);
  }

  // Waits, in order, for every event in `events`. The entries are
  // dereferenced unconditionally, so they must be non-null.
  void WaitEvents(const std::vector<const Event*>& events, int /* unused */)
      final {
    for (const Event* pending : events) {
      context_.WaitEvent(*pending);
    }
  }

  // Records this operator's event through the context, optionally attaching
  // an error message. Does nothing when the operator has no event.
  void RecordEvent(const char* err_msg = nullptr) final {
    if (event_) {
      context_.Record(event_.get(), err_msg);
    }
  }

  // The actual computation; must be provided by the concrete operator.
  // Returns true on success.
  virtual bool RunOnDevice() = 0;

  // Executes primitive_ against the resources_ table and checks the returned
  // status with MKLDNN_SAFE_CALL.
  inline void ExecutePrimitive() {
    MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
  }

 protected:
  // Builds the message used to tag failures of this operator: the debug
  // string of the operator definition when one is available, otherwise a
  // fixed placeholder.
  std::string getErrorMsg() {
    if (has_debug_def()) {
      return "Error from operator: " + ProtoDebugString(debug_def());
    } else {
      return "Error from operator: no op def";
    }
  }

  MKLContext context_;
  // The primitive used in the operator.
  PrimitiveWrapper<T> primitive_;
  // Size cache for all the input sizes.
  vector<vector<TIndex>> input_size_cache_;
  // An internal MKLMemory buffer. This is usually handy when we have a
  // single output from the operator. If your operator has multiple outputs
  // then you should allocate your own buffer.
  MKLMemory<T> buffer_;
  // The resources vector that we will need to use. Value-initialized so that
  // slots a subclass never fills are null rather than indeterminate.
  void* resources_[dnnResourceNumber] = {};
};
} // namespace mkl
// Convenience macro for classes derived from MKLOperator<T>: expands
// USE_OPERATOR_BASE_FUNCTIONS and then adds using-declarations for the
// MKLOperator<T> members listed below (Input, Output, ExecutePrimitive,
// primitive_, input_size_cache_, buffer_, resources_), so they can be named
// without qualification. The expansion has no trailing semicolon, so the
// call site supplies it.
#define USE_MKLOPERATOR_FUNCTIONS(T) \
USE_OPERATOR_BASE_FUNCTIONS; \
/* using override */ using MKLOperator<T>::Input; \
/* using override */ using MKLOperator<T>::Output; \
/* using override */ using MKLOperator<T>::ExecutePrimitive; \
/* using override */ using MKLOperator<T>::primitive_; \
/* using override */ using MKLOperator<T>::input_size_cache_; \
/* using override */ using MKLOperator<T>::buffer_; \
/* using override */ using MKLOperator<T>::resources_
// Boilerplate for a class `name` derived from MKLOperator<T>: expands to a
// constructor that only forwards (operator_def, ws) to MKLOperator<T>, and an
// empty virtual destructor. Meant to be placed inside the class body.
#define USE_SIMPLE_MKL_CTOR_DTOR(name, T) \
name(const OperatorDef& operator_def, Workspace* ws) \
: MKLOperator<T>(operator_def, ws) {} \
virtual ~name() {}
} // namespace caffe2
#endif // CAFFE2_UTILS_MKL_OPERATOR_H_