blob: f35453f267e06823172d94d13bd01b6403c83128 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
// Parses LIBSVM-formatted lines ("label idx:val idx:val ...") into a label
// tensor plus the COO components (indices, values, dense shape) of a sparse
// feature tensor with `num_features` columns appended as the innermost dim.
//
// Template parameters:
//   T      — dtype of the feature values ("dtype" attr).
//   Tlabel — dtype of the labels ("label_dtype" attr).
template <typename T, typename Tlabel>
class DecodeLibsvmOp : public OpKernel {
 public:
  explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_));
    // The dense shape's innermost dimension is num_features_, so it must be
    // a positive count.
    OP_REQUIRES(ctx, (num_features_ >= 1),
                errors::InvalidArgument("Invalid number of features \"",
                                        num_features_, "\""));
  }

  // Outputs:
  //   0: labels, same shape as the string input.
  //   1: sparse indices, int64 matrix of shape [nnz, input_rank + 1].
  //   2: sparse values, vector of length nnz with dtype T.
  //   3: dense shape, input shape with num_features_ appended.
  void Compute(OpKernelContext* ctx) override {
    const Tensor* input_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
    const auto& input_flat = input_tensor->flat<tstring>();
    const int64 ndims = input_tensor->shape().dims();

    Tensor* label_tensor;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor));
    auto label = label_tensor->flat<Tlabel>();

    std::vector<T> out_values;
    // Each entry is (flat index into the input, feature column).
    std::vector<std::pair<int64, int64>> out_indices;
    for (int i = 0; i < input_flat.size(); ++i) {
      StringPiece line(input_flat(i));
      str_util::RemoveWhitespaceContext(&line);

      // The first whitespace-delimited token is the label.
      StringPiece piece;
      OP_REQUIRES(ctx, str_util::ConsumeNonWhitespace(&line, &piece),
                  errors::InvalidArgument("No label found for input[", i,
                                          "]: \"", input_flat(i), "\""));

      Tlabel label_value;
      OP_REQUIRES(ctx,
                  strings::SafeStringToNumeric<Tlabel>(piece, &label_value),
                  errors::InvalidArgument("Label format incorrect: ", piece));

      label(i) = label_value;

      // Remaining tokens are "index:value" pairs.
      str_util::RemoveLeadingWhitespace(&line);
      while (str_util::ConsumeNonWhitespace(&line, &piece)) {
        size_t p = piece.find(':');
        OP_REQUIRES(ctx, (p != StringPiece::npos),
                    errors::InvalidArgument("Invalid feature \"", piece, "\""));

        int64 feature_index;
        OP_REQUIRES(
            ctx, strings::safe_strto64(piece.substr(0, p), &feature_index),
            errors::InvalidArgument("Feature format incorrect: ", piece));
        OP_REQUIRES(ctx, (feature_index >= 0),
                    errors::InvalidArgument(
                        "Feature index should be >= 0, got ", feature_index));
        // NOTE(review): feature_index is not checked against num_features_;
        // an index >= num_features_ yields indices outside the reported dense
        // shape. Preserved here for backward compatibility — confirm whether
        // downstream consumers tolerate this before tightening.

        T feature_value;
        OP_REQUIRES(
            ctx,
            strings::SafeStringToNumeric<T>(piece.substr(p + 1),
                                            &feature_value),
            errors::InvalidArgument("Feature format incorrect: ", piece));

        out_values.emplace_back(feature_value);
        out_indices.emplace_back(std::pair<int64, int64>(i, feature_index));

        str_util::RemoveLeadingWhitespace(&line);
      }
    }

    Tensor* indices_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            1,
                            TensorShape({static_cast<int64>(out_indices.size()),
                                         ndims + 1}),
                            &indices_tensor));
    auto indices = indices_tensor->matrix<int64>();
    // Translate each flat input index into a shaped (multi-dimensional)
    // index, like np.unravel_index: factors[j] is the number of flat
    // positions spanned by one step along dimension j.
    //
    // BUG FIX: for a scalar (rank-0) input `factors` is empty, and the
    // previous unconditional `factors[ndims - 1] = 1` wrote out of bounds.
    // Guard the computation on ndims > 0; with rank 0 the unravel loop below
    // is a no-op and column 0 correctly receives the feature index.
    std::vector<int64> factors(ndims);
    if (ndims > 0) {
      factors[ndims - 1] = 1;
      for (int j = ndims - 2; j >= 0; j--) {
        factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1);
      }
    }
    for (size_t i = 0; i < out_indices.size(); i++) {
      int64 value = out_indices[i].first;
      for (int j = 0; j < ndims; j++) {
        indices(i, j) = value / factors[j];
        value = value % factors[j];
      }
      // Innermost column holds the parsed feature index.
      indices(i, ndims) = out_indices[i].second;
    }

    Tensor* values_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(
                       2, TensorShape({static_cast<int64>(out_values.size())}),
                       &values_tensor));
    auto values = values_tensor->vec<T>();
    std::copy_n(out_values.begin(), out_values.size(), &values(0));

    Tensor* shape_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            3, TensorShape({ndims + 1}),
                            &shape_tensor));
    auto shape = shape_tensor->flat<int64>();
    for (int i = 0; i < ndims; i++) {
      shape(i) = input_tensor->shape().dim_size(i);
    }
    shape(ndims) = num_features_;
  }

 private:
  int64 num_features_;  // Width of the sparse tensor's innermost dimension.
};
// Registers a CPU DecodeLibsvm kernel for feature dtype `type`, once for
// each of the four supported label dtypes (int32, int64, float, double).
// (Comments cannot appear inside the macro body: a trailing // would swallow
// the line-continuation backslash.)
#define REGISTER_KERNEL(type)                           \
  REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")          \
                              .Device(DEVICE_CPU)       \
                              .TypeConstraint<type>("dtype")      \
                              .TypeConstraint<int32>("label_dtype"),     \
                          DecodeLibsvmOp<type, int32>); \
  REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")          \
                              .Device(DEVICE_CPU)       \
                              .TypeConstraint<type>("dtype")      \
                              .TypeConstraint<int64>("label_dtype"),     \
                          DecodeLibsvmOp<type, int64>); \
  REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")          \
                              .Device(DEVICE_CPU)       \
                              .TypeConstraint<type>("dtype")      \
                              .TypeConstraint<float>("label_dtype"),     \
                          DecodeLibsvmOp<type, float>); \
  REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")          \
                              .Device(DEVICE_CPU)       \
                              .TypeConstraint<type>("dtype")      \
                              .TypeConstraint<double>("label_dtype"),    \
                          DecodeLibsvmOp<type, double>);

// Instantiate the registrations for every supported feature dtype.
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64);

#undef REGISTER_KERNEL
} // namespace tensorflow