| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <stdint.h> |
| #include <cstddef> |
| #include <functional> |
| #include <memory> |
| #include <string> |
| #include <vector> |
| |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "unicode/appendable.h" // TF:icu |
| #include "unicode/schriter.h" // TF:icu |
| #include "unicode/uchar.h" // TF:icu |
| #include "unicode/ucnv.h" // TF:icu |
| #include "unicode/ucnv_err.h" // TF:icu |
| #include "unicode/umachine.h" // TF:icu |
| #include "unicode/uniset.h" // TF:icu |
| #include "unicode/unistr.h" // TF:icu |
| #include "unicode/uset.h" // TF:icu |
| #include "unicode/utypes.h" // TF:icu |
| #include "tensorflow/core/framework/bounds_check.h" |
| #include "tensorflow/core/framework/kernel_def_builder.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/framework/tensor_types.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/kernels/string_util.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/stringpiece.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/util/bcast.h" |
| #include "tensorflow/core/util/ptr_util.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| void Encode(const UnicodeEncoding encoding, const icu::UnicodeString& in, |
| string* out) { |
| if (encoding == UnicodeEncoding::UTF8) { |
| out->clear(); |
| in.toUTF8String(*out); |
| } else if (encoding == UnicodeEncoding::UTF16BE) { |
| // TODO(gbillock): consider using the |
| // extract(char *dest, int32_t destCapacity, UConverter *cnv) |
| // for UTF16/32 |
| out->clear(); // subtle: must come before reserve() |
| out->reserve(2 * in.length() + 1); |
| const char16_t* buf = in.getBuffer(); |
| for (int i = 0; i < in.length(); ++i) { |
| // Emit big-endian encoding for UTF-16 always. |
| out->push_back((buf[i] & 0xFF00) >> 8); |
| out->push_back(buf[i] & 0x00FF); |
| } |
| } else if (encoding == UnicodeEncoding::UTF32BE) { |
| out->clear(); // subtle: must come before reserve() |
| out->reserve(4 * in.countChar32() + 1); |
| icu::StringCharacterIterator it(in); |
| UChar32 ch; |
| while (it.hasNext()) { |
| ch = it.next32PostInc(); |
| out->push_back((ch & 0xFF000000) >> 24); |
| out->push_back((ch & 0x00FF0000) >> 16); |
| out->push_back((ch & 0x0000FF00) >> 8); |
| out->push_back((ch & 0x000000FF)); |
| } |
| } |
| } |
| |
| // This error callback is only useful for finding illegal encoding errors when |
| // we want to be strict -- otherwise illegal encodings are replaced on read |
| // with 0xFFFD and signaled to the callback. |
| void unicode_error_callback(const void* context, UConverterToUnicodeArgs* args, |
| const char* codeUnits, int32_t length, |
| UConverterCallbackReason reason, |
| UErrorCode* pErrorCode) { |
| // Careful: this depends on setting up the context settings when the |
| // callback is registered. |
| bool* format_error = const_cast<bool*>(static_cast<const bool*>(context)); |
| |
| if (reason == UCNV_UNASSIGNED || reason == UCNV_ILLEGAL || |
| reason == UCNV_IRREGULAR) { |
| *format_error = true; |
| } |
| |
| // Side note: the default behavior in this case is that without a substitution |
| // made by the callback, the UConverter will signal an error to the iterator |
| // making the string iteration bail out. Instead, forward to the built-in |
| // substitution handler. |
| UCNV_TO_U_CALLBACK_SUBSTITUTE(nullptr, args, codeUnits, length, reason, |
| pErrorCode); |
| } |
| |
| // Iterates through a source string given the provided input UConverter specific |
| // to the encoding for that string. Calls a provided callback for each codepoint |
| // consumed. Provides the callback with the codepoint and the number of bytes |
| // consumed from the input string to produce it. If there are invalid encoding |
| // loci in the source string, they will be provided as a 0xFFFD codepoint to |
| // the callback, unless the "fail_on_formatting_error" arg is set, in which |
| // case the callback will be passed the signal that there is such an invalid |
| // encoding position. |
| // callback: function(UChar32 codepoint, int num_bytes_consumed_from_source_str, |
| // bool fatal_format_error) |
| void IterateUnicodeString(const string& str, UConverter* converter, |
| std::function<void(UChar32, int, bool)> callback) { |
| const char* source = str.data(); |
| const char* limit = str.data() + str.length(); |
| UErrorCode status = U_ZERO_ERROR; |
| |
| UConverterToUCallback oldAction = nullptr; |
| const void* oldContext = nullptr; |
| bool format_error = false; |
| |
| // Subtle. You can't make a function pointer from a std::function. :-( |
| // Instead, we pass the boolean pointer as the "context" object. |
| ucnv_setToUCallBack(converter, unicode_error_callback, &format_error, |
| &oldAction, &oldContext, &status); |
| if (U_FAILURE(status)) { |
| LOG(ERROR) << "Could not set unicode error callback on converter"; |
| return; |
| } |
| |
| while (source < limit) { |
| const char* source_pre_fetch = source; |
| // Note: ucnv_getNextUChar returns 0xFFFD on an encoding error. |
| UChar32 next_char = ucnv_getNextUChar(converter, &source, limit, &status); |
| if (U_FAILURE(status)) { |
| source = limit; |
| } |
| int bytes_consumed = source - source_pre_fetch; |
| callback(next_char, bytes_consumed, format_error); |
| format_error = false; |
| } |
| |
| ucnv_setToUCallBack(converter, oldAction, oldContext, nullptr, nullptr, |
| &status); |
| } |
| |
| // Lifecycle wrapper for UConverter making it easier to use with thread_local. |
| // TODO(gbillock): Consider whether to use the higher-level convert API and |
| // create a specialized fast code path for UTF8. |
| class WrappedConverter { |
| public: |
| WrappedConverter() {} |
| |
| ~WrappedConverter() { |
| if (converter_) { |
| ucnv_close(converter_); |
| } |
| } |
| |
| void init(const string& name) { |
| if (converter_ && name == name_) { |
| // Note: this reset is not typically needed, but if not done, then in some |
| // cases the cached converter will maintain state of input endianness |
| // which isn't valid from input to input in every batched case. |
| ucnv_reset(converter_); |
| return; |
| } |
| |
| if (converter_) { |
| ucnv_close(converter_); |
| converter_ = nullptr; |
| name_ = ""; |
| } |
| |
| UErrorCode status = U_ZERO_ERROR; |
| converter_ = ucnv_open(name.c_str(), &status); |
| if (U_FAILURE(status)) { |
| if (converter_) { |
| ucnv_close(converter_); |
| converter_ = nullptr; |
| } |
| } else { |
| name_ = name; |
| } |
| } |
| |
| UConverter* converter_ = nullptr; |
| string name_; |
| }; |
| |
| struct ErrorOptions { |
| UChar32 subst = 0xFFFD; |
| bool elide_replacement = false; |
| bool replace_control_chars = false; |
| bool error_on_malformatting = false; |
| }; |
| |
| Status GetErrorOptions(OpKernelConstruction* ctx, ErrorOptions* out) { |
| *out = ErrorOptions(); |
| |
| string error_policy; |
| TF_RETURN_IF_ERROR(ctx->GetAttr("errors", &error_policy)); |
| |
| if (error_policy == "replace") { |
| out->elide_replacement = false; |
| } else if (error_policy == "ignore") { |
| out->elide_replacement = true; |
| } else if (error_policy == "strict") { |
| out->error_on_malformatting = true; |
| } else { |
| return errors::InvalidArgument( |
| "errors policy must be one of 'strict', 'replace', or 'ignore'"); |
| } |
| |
| int32 replacement_char; |
| TF_RETURN_IF_ERROR(ctx->GetAttr("replacement_char", &replacement_char)); |
| |
| if (replacement_char >= UCHAR_MIN_VALUE && |
| replacement_char <= UCHAR_MAX_VALUE) { |
| out->subst = replacement_char; |
| } else { |
| return errors::InvalidArgument( |
| "replacement_char out of unicode codepoint range"); |
| } |
| |
| if (ctx->HasAttr("replace_control_characters")) { |
| TF_RETURN_IF_ERROR(ctx->GetAttr("replace_control_characters", |
| &(out->replace_control_chars))); |
| } |
| |
| return Status::OK(); |
| } |
| |
| inline bool ShouldHandleFormatError(const ErrorOptions& error_options, |
| UChar32 ch, bool format_error) { |
| return ((error_options.replace_control_chars && ch <= 0x1F) || format_error); |
| } |
| |
| } // namespace |
| |
| class UnicodeTranscodeOp : public OpKernel { |
| public: |
| explicit UnicodeTranscodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
| OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_)); |
| |
| string output_encoding; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding", &output_encoding)); |
| OP_REQUIRES_OK(ctx, |
| ParseUnicodeEncoding(output_encoding, &output_encoding_)); |
| |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding", &input_encoding_)); |
| // Make a temporary UConverter to ensure it will create without error |
| // at execution time (and to warm any data caches the converter needs). |
| // This instance is not used. |
| std::unique_ptr<WrappedConverter> input_encoder = |
| absl::make_unique<WrappedConverter>(); |
| input_encoder->init(input_encoding_); |
| OP_REQUIRES(ctx, input_encoder->converter_, |
| errors::InvalidArgument( |
| "Could not create converter for input encoding: " + |
| input_encoding_)); |
| } |
| |
| void Compute(OpKernelContext* ctx) override { |
| const Tensor* input_tensor; |
| OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); |
| |
| static thread_local std::unique_ptr<WrappedConverter> input_encoder; |
| if (!input_encoder) { |
| input_encoder.reset(new WrappedConverter()); |
| } |
| input_encoder->init(input_encoding_); |
| OP_REQUIRES(ctx, input_encoder->converter_, |
| errors::InvalidArgument( |
| "Could not create converter for input encoding: " + |
| input_encoding_)); |
| |
| // Output may be forwardable from input, in which case work in-place. |
| Tensor* output_tensor; |
| std::unique_ptr<Tensor> maybe_forwarded = |
| ctx->forward_input(0 /*input_index*/, 0 /*output_index*/, |
| tensorflow::DT_STRING, input_tensor->shape(), |
| ctx->input_memory_type(0), ctx->input_alloc_attr(0)); |
| if (maybe_forwarded) { |
| output_tensor = maybe_forwarded.get(); |
| OP_REQUIRES_OK(ctx, ctx->set_output("output", *output_tensor)); |
| } else { |
| OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(), |
| &output_tensor)); |
| output_tensor->flat<tstring>() = input_tensor->flat<tstring>(); |
| } |
| |
| auto output_flat = output_tensor->flat<tstring>(); |
| bool found_any_format_error = false; |
| for (size_t i = 0; i < output_flat.size(); ++i) { |
| Transcode(&(output_flat(i)), input_encoder->converter_, |
| &found_any_format_error); |
| } |
| if (error_options_.error_on_malformatting && found_any_format_error) { |
| ctx->CtxFailure( |
| errors::InvalidArgument("Invalid formatting on input string")); |
| } |
| } |
| |
| private: |
| // Consume a codepoint from the input string and add it to the buffer. |
| // This function takes care of any replacement configuration on invalid or |
| // out-of-range inputs. |
| void TranslateCodepoints(icu::UnicodeString* s, bool* found_any_format_error, |
| UChar32 ch, int src_bytes, bool format_error) { |
| if (ShouldHandleFormatError(error_options_, ch, format_error)) { |
| *found_any_format_error = true; |
| if (error_options_.elide_replacement) { |
| return; |
| } else { |
| ch = error_options_.subst; |
| } |
| } |
| s->append(ch); |
| } |
| |
| // Transcode the string from input encoding to the output_encoding_. If |
| // non-valid characters are encountered, use the subst_/elide_replacement_ |
| // config to handle them. |
| void Transcode(string* s, UConverter* input_encoder, |
| bool* found_any_format_error) { |
| icu::UnicodeString source; |
| IterateUnicodeString( |
| *s, input_encoder, |
| std::bind(&UnicodeTranscodeOp::TranslateCodepoints, this, &source, |
| found_any_format_error, std::placeholders::_1, |
| std::placeholders::_2, std::placeholders::_3)); |
| |
| Encode(output_encoding_, source, s); |
| } |
| |
| string input_encoding_; |
| ErrorOptions error_options_; |
| UnicodeEncoding output_encoding_ = UnicodeEncoding::UTF8; |
| }; |
| |
| REGISTER_KERNEL_BUILDER(Name("UnicodeTranscode").Device(DEVICE_CPU), |
| UnicodeTranscodeOp); |
| |
| template <typename SPLITS_TYPE> |
| class UnicodeDecodeBaseOp : public OpKernel { |
| public: |
| explicit UnicodeDecodeBaseOp(OpKernelConstruction* ctx, bool generate_offsets) |
| : OpKernel(ctx), generate_offsets_(generate_offsets) { |
| OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding", &input_encoding_)); |
| // Make a temporary UConverter to ensure it will create without error |
| // at execution time (and to warm any data caches the converter needs). |
| // This instance is not used. |
| std::unique_ptr<WrappedConverter> input_encoder = |
| absl::make_unique<WrappedConverter>(); |
| input_encoder->init(input_encoding_); |
| OP_REQUIRES(ctx, input_encoder->converter_, |
| errors::InvalidArgument( |
| "Could not create converter for input encoding: " + |
| input_encoding_)); |
| } |
| |
| void Decode(OpKernelContext* ctx, std::vector<UChar32>* char_values, |
| std::vector<SPLITS_TYPE>* offset_values, int* current_offset, |
| SPLITS_TYPE* next_row_split, UChar32 char_value, int char_length, |
| bool found_any_format_error) { |
| if (error_options_.error_on_malformatting && found_any_format_error) { |
| ctx->CtxFailure( |
| errors::InvalidArgument("Invalid formatting on input string")); |
| } |
| UChar32 decoded_value = char_value; |
| if (ShouldHandleFormatError(error_options_, char_value, |
| found_any_format_error)) { |
| if (error_options_.elide_replacement && (offset_values != nullptr)) { |
| *current_offset += char_length; |
| return; |
| } else { |
| decoded_value = error_options_.subst; |
| } |
| } |
| |
| // Emit the char value. |
| char_values->push_back(decoded_value); |
| |
| // Emit the byte offset |
| if (offset_values != nullptr) { |
| offset_values->push_back(*current_offset); |
| *current_offset += char_length; |
| } |
| *next_row_split += 1; |
| } |
| |
| void Compute(OpKernelContext* ctx) override { |
| const Tensor* input_tensor; |
| OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); |
| |
| // Go through all the strings in `input`. |
| const auto& input_vec = input_tensor->flat<tstring>(); |
| |
| std::unique_ptr<WrappedConverter> input_encoder = |
| absl::make_unique<WrappedConverter>(); |
| input_encoder->init(input_encoding_); |
| OP_REQUIRES(ctx, input_encoder->converter_, |
| errors::InvalidArgument( |
| "Could not create converter for input encoding: " + |
| input_encoding_)); |
| |
| std::vector<UChar32> char_values; |
| std::vector<SPLITS_TYPE> offset_values; |
| |
| Tensor* output_row_splits; |
| OP_REQUIRES_OK(ctx, ctx->allocate_output("row_splits", |
| {input_tensor->NumElements() + 1}, |
| &output_row_splits)); |
| auto out_row_splits = output_row_splits->vec<SPLITS_TYPE>(); |
| |
| int row_split_index = 0; |
| SPLITS_TYPE next_row_split = 0; |
| for (int i = 0; i < input_vec.size(); ++i) { |
| const string& input = input_vec(i); |
| // Convert input strings into unicode values. Output to a list of |
| // char_values, record row splits and char_to_byte_starts, which are all |
| // the fields needed to construct a RaggedTensor. |
| out_row_splits(row_split_index) = next_row_split; |
| row_split_index++; |
| int current_offset = 0; |
| IterateUnicodeString( |
| input, input_encoder->converter_, |
| std::bind(&UnicodeDecodeBaseOp::Decode, this, ctx, &char_values, |
| &offset_values, ¤t_offset, &next_row_split, |
| std::placeholders::_1, std::placeholders::_2, |
| std::placeholders::_3)); |
| } |
| out_row_splits(row_split_index) = next_row_split; |
| |
| Tensor* output_char_values; |
| OP_REQUIRES_OK( |
| ctx, ctx->allocate_output( |
| "char_values", {static_cast<SPLITS_TYPE>(char_values.size())}, |
| &output_char_values)); |
| auto out_char_values = output_char_values->vec<int32>(); |
| if (generate_offsets_) { |
| DCHECK(offset_values.size() == char_values.size()); |
| Tensor* output_offset_values; |
| OP_REQUIRES_OK(ctx, ctx->allocate_output( |
| "char_to_byte_starts", |
| {static_cast<SPLITS_TYPE>(offset_values.size())}, |
| &output_offset_values)); |
| auto out_offset_values = output_offset_values->vec<SPLITS_TYPE>(); |
| |
| // Load output tensors from intermediate value arrays. |
| for (int i = 0; i < char_values.size(); ++i) { |
| out_char_values(i) = static_cast<int32>(char_values[i]); |
| out_offset_values(i) = offset_values[i]; |
| } |
| } else { |
| for (int i = 0; i < char_values.size(); ++i) { |
| out_char_values(i) = static_cast<int32>(char_values[i]); |
| } |
| } |
| } |
| |
| private: |
| string input_encoding_; |
| ErrorOptions error_options_; |
| bool generate_offsets_ = false; |
| }; |
| |
| template <typename SPLITS_TYPE> |
| class UnicodeDecodeOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> { |
| public: |
| explicit UnicodeDecodeOp(OpKernelConstruction* ctx) |
| : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, false) {} |
| }; |
| |
| template <typename SPLITS_TYPE> |
| class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> { |
| public: |
| explicit UnicodeDecodeWithOffsetsOp(OpKernelConstruction* ctx) |
| : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, true) {} |
| }; |
| |
| REGISTER_KERNEL_BUILDER( |
| Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint<int64>("Tsplits"), |
| UnicodeDecodeOp<int64>); |
| REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets") |
| .Device(DEVICE_CPU) |
| .TypeConstraint<int64>("Tsplits"), |
| UnicodeDecodeWithOffsetsOp<int64>); |
| REGISTER_KERNEL_BUILDER( |
| Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits"), |
| UnicodeDecodeOp<int32>); |
| REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets") |
| .Device(DEVICE_CPU) |
| .TypeConstraint<int32>("Tsplits"), |
| UnicodeDecodeWithOffsetsOp<int32>); |
| |
| template <typename SPLITS_TYPE> |
| class UnicodeEncodeOp : public OpKernel { |
| public: |
| explicit UnicodeEncodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
| string encoding_tmp; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding", &encoding_tmp)); |
| OP_REQUIRES_OK(ctx, ParseUnicodeEncoding(encoding_tmp, &encoding_)); |
| OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_)); |
| } |
| |
| /** |
| * Encodes Unicode codepoints into the desired string representation. |
| * |
| * We lose a dimension while encoding, since a series of integer codepoints is |
| * encoded into a single string. |
| * |
| * This accepts two input tensors: a rank 1 tensor of code point values and |
| * a single rank 1 tensor of splits which determine where each string begins |
| * and ends from the provided code points. |
| */ |
| void Compute(OpKernelContext* context) override { |
| // Get inputs |
| const Tensor& input_tensor = context->input(0); |
| const auto input_tensor_flat = input_tensor.flat<int32>(); |
| const Tensor& input_splits = context->input(1); |
| const auto input_splits_flat = input_splits.flat<SPLITS_TYPE>(); |
| |
| // Since we limit to a 2-D input (flat_values of rank 1 and a single splits |
| // tensor), our output dimension will be 1 with it's size equal to the |
| // number of splits (outer dimension or ragged tensor). |
| TensorShape output_shape({input_splits.dim_size(0) - 1}); |
| Tensor* output_tensor; |
| OP_REQUIRES_OK(context, context->allocate_output("output", output_shape, |
| &output_tensor)); |
| auto output_tensor_flat = output_tensor->flat<tstring>(); |
| |
| // Use a single index over the flattened input values tensor. |
| int idx = 0; |
| // Loop through our split dimension to create a new string at each split. |
| for (int i = 1; i < input_splits_flat.size(); ++i) { |
| icu::UnicodeString unicode_string; |
| icu::UnicodeStringAppendable appendable_unicode_string(unicode_string); |
| for (; idx < input_splits_flat(i); ++idx) { |
| int32 code_point = input_tensor_flat(idx); |
| // Check for invalid code point |
| if (code_point > UCHAR_MAX_VALUE || code_point < UCHAR_MIN_VALUE) { |
| if (error_options_.error_on_malformatting) { |
| context->CtxFailure(errors::InvalidArgument( |
| "Code point value out of valid Unicode range.")); |
| return; |
| } else if (!error_options_.elide_replacement) { |
| code_point = error_options_.subst; |
| } |
| } |
| appendable_unicode_string.appendCodePoint(code_point); |
| } |
| // Encode our string and save in the output. |
| string result; |
| Encode(encoding_, unicode_string, &result); |
| output_tensor_flat(i - 1) = result; |
| } |
| } |
| |
| private: |
| UnicodeEncoding encoding_; |
| ErrorOptions error_options_; |
| }; |
| |
| REGISTER_KERNEL_BUILDER( |
| Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint<int64>("Tsplits"), |
| UnicodeEncodeOp<int64>); |
| REGISTER_KERNEL_BUILDER( |
| Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits"), |
| UnicodeEncodeOp<int32>); |
| |
| } // namespace tensorflow |