Fix some issues in MklSoftmax: include the source memory format in the primitive-cache key (previously primitives for different input formats could collide), drop a redundant null check in the primitive constructor, and clean up comments.
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 96d5945..1a18835 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -48,10 +48,7 @@
explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
-
- if (context_.softmax_fwd == nullptr) {
- Setup(fwdParams);
- }
+ Setup(fwdParams);
}
~MklSoftmaxPrimitive() {}
@@ -66,7 +63,7 @@
context_.fwd_stream->submit(context_.fwd_primitives);
- // after execution, set data handle back
+ // After execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
}
@@ -77,14 +74,14 @@
private:
struct SoftmaxFwdContext {
- // MKLDNN memory
+ // MKL-DNN memory
std::shared_ptr<memory> src_mem;
std::shared_ptr<memory> dst_mem;
- // desc & prmitive desc
+ // Primitive desc
std::shared_ptr<mkldnn::softmax_forward::desc> fwd_desc;
- // memory desc
+ // Memory desc
std::shared_ptr<memory::desc> src_md;
// Softmax primitive
@@ -106,23 +103,23 @@
// Softmax forward primitive setup
void Setup(const MklSoftmaxParams& fwdParams) {
- // create memory descriptors for softmax data with specified format
+ // Create memory descriptors for softmax data with specified format
context_.src_md.reset(new memory::desc({fwdParams.src_dims},
MklDnnType<T>(), fwdParams.src_fmt));
- // create a softmax
+ // Create a softmax
context_.fwd_desc.reset(new mkldnn::softmax_forward::desc(
prop_kind::forward_scoring, *context_.src_md, fwdParams.axis));
context_.fwd_pd.reset(new mkldnn::softmax_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
- // create memory primitive based on dummy data
+ // Create memory primitive based on dummy data
context_.src_mem.reset(
new memory({*context_.src_md, cpu_engine_}, DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
- // create softmax primitive and add it to net
+ // Create softmax primitive and add it to net
context_.softmax_fwd.reset(new mkldnn::softmax_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
@@ -137,11 +134,11 @@
class MklSoftmaxPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklSoftmaxPrimitive<T>* Get(const MklSoftmaxParams& fwdParams) {
- MklSoftmaxPrimitive<T>* softmax_forward = nullptr;
-
// Get a softmax fwd primitive from the cached pool
- softmax_forward = static_cast<MklSoftmaxPrimitive<T>*>(
- MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(fwdParams));
+ MklSoftmaxPrimitive<T>* softmax_forward =
+ static_cast<MklSoftmaxPrimitive<T>*>(
+ MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(
+ fwdParams));
if (softmax_forward == nullptr) {
softmax_forward = new MklSoftmaxPrimitive<T>(fwdParams);
MklSoftmaxPrimitiveFactory<T>::GetInstance().SetSoftmaxFwd(
@@ -164,7 +161,9 @@
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(fwdParams.src_dims);
- key_creator.AddAsKey(fwdParams.axis);
+ key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_fmt));
+ key_creator.AddAsKey<int>(fwdParams.axis);
+
return key_creator.GetKey();
}
@@ -252,29 +251,27 @@
return;
}
- // If input is in MKL layout, then simply grab input layout; otherwise,
- // construct input Tf layout. For TF layout, although input shape
- // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
- // layout
+ // If the input is in MKL layout, take the format from the input's MKL
+ // layout; otherwise, fall back to the TF layout chosen above.
auto src_fmt = src_mkl_shape.IsMklTensor()
? static_cast<mkldnn::memory::format>(
src_mkl_shape.GetMklLayout().data.format)
: layout_type;
- // get a softmax fwd from primitive pool
+ // Get a softmax fwd from primitive pool
MklSoftmaxParams fwdParams(src_dims, src_fmt, axis);
MklSoftmaxPrimitive<T>* softmax_fwd =
MklSoftmaxPrimitiveFactory<T>::Get(fwdParams);
- // add: output
+ // Add output
Tensor* output_tensor = nullptr;
MklDnnShape output_mkl_shape;
TensorShape output_tf_shape; // shape of output TF tensor.
auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->dst_primitive_desc();
- // if input is MKL shape, output is also MKL shape.
- // if input is TF shape, output is also TF shape
+ // If input is MKL shape, output is also MKL shape.
+ // If input is TF shape, output is also TF shape.
if (src_mkl_shape.IsMklTensor()) {
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
@@ -292,7 +289,7 @@
const T* src_data = src_tensor.flat<T>().data();
T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
- // execute softmax
+ // Execute softmax
softmax_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +