blob: 1f5d4b9260ce211743150f66261fcbf2bdf3bab8 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
#define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/platform/logging.h"
namespace stream_executor {
// Represents a scalar that is either held by value on the host or stored in
// GPU device memory.
template <typename ElemT>
class HostOrDeviceScalar {
public:
// Not marked as explicit because when using this constructor, we usually want
// to set this to a compile-time constant.
HostOrDeviceScalar(ElemT value) : value_(value), is_pointer_(false) {}
explicit HostOrDeviceScalar(const DeviceMemory<ElemT>& pointer)
: pointer_(pointer), is_pointer_(true) {
CHECK_EQ(1, pointer.ElementCount());
}
bool is_pointer() const { return is_pointer_; }
const DeviceMemory<ElemT>& pointer() const {
CHECK(is_pointer());
return pointer_;
}
const ElemT& value() const {
CHECK(!is_pointer());
return value_;
}
private:
union {
ElemT value_;
DeviceMemory<ElemT> pointer_;
};
bool is_pointer_;
};
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_