blob: 959d2d878ba10c9a43898905cde6421809bcb11a [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
#include <map>
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#if !defined(PyObject_HEAD)
// Python's headers define the PyObject_HEAD macro; if it is absent, Python.h
// has not been included yet. In that case forward-declare CPython's object
// struct and alias it as PyObject, so the declarations below can use
// PyObject* without pulling in Python.h.
struct _object;
using PyObject = _object;
#endif  // !defined(PyObject_HEAD)
namespace tensorflow {
// Global registry mapping C API error codes to the corresponding custom Python
// exception type. This is used to expose the exception types to C extension
// code (i.e. so we can raise custom exceptions via SWIG).
//
// Init() must be called exactly once at the beginning of the process before
// Lookup() can be used.
//
// Example usage:
// TF_Status* status = TF_NewStatus();
// TF_Foo(..., status);
//
// if (TF_GetCode(status) != TF_OK) {
// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status));
// // Arguments to OpError base class. Set `node_def` and `op` to None.
// PyObject* args =
// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status));
// PyErr_SetObject(exc_type, args);
// Py_DECREF(args);
// TF_DeleteStatus(status);
// return NULL;
// }
class PyExceptionRegistry {
public:
// Initializes the process-wide registry. Should be called exactly once near
// the beginning of the process. The arguments are the various Python
// exception types (e.g. `cancelled_exc` corresponds to
// errors.CancelledError).
static void Init(PyObject* code_to_exc_type_map);
// Returns the Python exception type corresponding to `code`. Init() must be
// called before using this function. `code` should not be TF_OK.
static PyObject* Lookup(TF_Code code);
static inline PyObject* Lookup(error::Code code) {
return Lookup(static_cast<TF_Code>(code));
}
private:
static PyExceptionRegistry* singleton_;
PyExceptionRegistry() = default;
// Maps error codes to the corresponding Python exception type.
std::map<TF_Code, PyObject*> exc_types_;
};
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_