blob: 63568685f27486f7a14d6c8935292605a44506f0 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
#define TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
// The session state remembers the tensors we choose to keep across
// multiple run calls.
class SessionState {
public:
// Get a tensor from the session state.
Status GetTensor(const string& handle, Tensor* tensor);
// Store a tensor in the session state.
Status AddTensor(const string& handle, const Tensor& tensor);
// Delete a tensdor from the session state.
Status DeleteTensor(const string& handle);
int64 GetNewId();
static const char* kTensorHandleResourceTypeName;
private:
mutex state_lock_;
// For generating unique ids for tensors stored in the session.
int64 tensor_id_ = 0;
// The live tensors in the session. A map from tensor handle to tensor.
std::unordered_map<string, Tensor> tensors_;
};
// The tensor store remembers the tensors we choose to keep for the
// current run call. It is available to every op kernel.
class TensorStore {
public:
struct TensorAndKey {
Tensor tensor;
int64 id;
string device_name;
string GetHandle(const string& tensor_name) {
return strings::StrCat(tensor_name, ";", id, ";", device_name);
}
};
// Add the named tensor to the tensor store for this run.
Status AddTensor(const string& name, const TensorAndKey& tk);
// Save the tensors in the tensor store of this run to the session.
Status SaveTensors(const std::vector<string>& output_names,
SessionState* session_state);
// Returns true if no tensors have been added to this store.
bool empty() {
mutex_lock l(lock_);
return tensors_.empty();
}
private:
mutex lock_;
// The tensors that will be saved to session state when this run completes.
// A map from tensor string name to tensor.
std::unordered_map<string, TensorAndKey> tensors_ GUARDED_BY(lock_);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_