Revert D6035393: [caffe2] expose observers to python, add multiple observers per observable
Summary:
This reverts commit 4563cf0203095fa979bb2160621cd16dd22ff830
bypass-lint
Differential Revision: D6035393
fbshipit-source-id: 090fba774ce433904f7ef769dda75c2fbbf784a8
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 07d77a4..51779be 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -28,7 +28,6 @@
add_subdirectory(mkl)
add_subdirectory(mobile)
add_subdirectory(mpi)
-add_subdirectory(observers)
add_subdirectory(operators)
add_subdirectory(perfkernels)
add_subdirectory(python)
diff --git a/caffe2/contrib/CMakeLists.txt b/caffe2/contrib/CMakeLists.txt
index f571ece..bcaccba 100644
--- a/caffe2/contrib/CMakeLists.txt
+++ b/caffe2/contrib/CMakeLists.txt
@@ -2,6 +2,7 @@
add_subdirectory(gloo)
add_subdirectory(nccl)
add_subdirectory(nnpack)
+add_subdirectory(observers)
add_subdirectory(shm_mutex)
# Finally pass the src lists back to the parent
diff --git a/caffe2/observers/CMakeLists.txt b/caffe2/contrib/observers/CMakeLists.txt
similarity index 100%
rename from caffe2/observers/CMakeLists.txt
rename to caffe2/contrib/observers/CMakeLists.txt
diff --git a/caffe2/observers/time_observer.cc b/caffe2/contrib/observers/time_observer.cc
similarity index 74%
rename from caffe2/observers/time_observer.cc
rename to caffe2/contrib/observers/time_observer.cc
index 630d632..7680f00 100644
--- a/caffe2/observers/time_observer.cc
+++ b/caffe2/contrib/observers/time_observer.cc
@@ -14,21 +14,23 @@
* limitations under the License.
*/
-#include "time_observer.h"
+#include "caffe2/contrib/observers/time_observer.h"
#include "caffe2/core/logging.h"
namespace caffe2 {
template <>
-bool TimeObserverBase<NetBase>::Start() {
- CAFFE_THROW(
- "This function is overridden by TimeObserver<NetBase>.\
- If it was called there is an issue with compilation.");
- return false;
+bool TimeObserver<NetBase>::Start() {
+ for (auto* op : subject_->GetOperators()) {
+ op->SetObserver(caffe2::make_unique<TimeObserver<OperatorBase>>(op));
+ }
+ start_time_ = timer_.MilliSeconds();
+ ++iterations_;
+ return true;
}
template <>
-bool TimeObserverBase<NetBase>::Stop() {
+bool TimeObserver<NetBase>::Stop() {
double current_run = timer_.MilliSeconds() - start_time_;
total_time_ += current_run;
VLOG(1) << "This net iteration took " << current_run << " ms to complete.\n";
@@ -36,14 +38,14 @@
}
template <>
-bool TimeObserverBase<OperatorBase>::Start() {
+bool TimeObserver<OperatorBase>::Start() {
start_time_ = timer_.MilliSeconds();
++iterations_;
return true;
}
template <>
-bool TimeObserverBase<OperatorBase>::Stop() {
+bool TimeObserver<OperatorBase>::Stop() {
double current_run = timer_.MilliSeconds() - start_time_;
total_time_ += current_run;
VLOG(1) << "This operator iteration took " << current_run
diff --git a/caffe2/contrib/observers/time_observer.h b/caffe2/contrib/observers/time_observer.h
new file mode 100644
index 0000000..ea7be4c
--- /dev/null
+++ b/caffe2/contrib/observers/time_observer.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
+#define CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
+
+#include <unordered_map>
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/observer.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/timer.h"
+
+namespace caffe2 {
+
+template <class T>
+class TimeObserver final : public ObserverBase<T> {
+ public:
+ explicit TimeObserver<T>(T* subject) : ObserverBase<T>(subject) {}
+ inline float average_time() const {
+ return total_time_ / iterations_;
+ }
+ float average_time_children() const {
+ float sum = 0.0f;
+ for (auto* op : this->subject_->GetOperators()) {
+ auto* observer =
+ dynamic_cast_if_rtti<TimeObserver<OperatorBase>*>(op->GetObserver());
+ sum += observer->average_time();
+ }
+ return sum / this->subject_->GetOperators().size();
+ }
+ ~TimeObserver() {}
+
+ private:
+ Timer timer_;
+ float start_time_ = 0.0f;
+ float total_time_ = 0.0f;
+ int iterations_ = 0;
+
+ bool Start() override;
+ bool Stop() override;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
diff --git a/caffe2/observers/time_observer_test.cc b/caffe2/contrib/observers/time_observer_test.cc
similarity index 89%
rename from caffe2/observers/time_observer_test.cc
rename to caffe2/contrib/observers/time_observer_test.cc
index 6233631..b6937b3 100644
--- a/caffe2/observers/time_observer_test.cc
+++ b/caffe2/contrib/observers/time_observer_test.cc
@@ -14,11 +14,11 @@
* limitations under the License.
*/
+#include "caffe2/contrib/observers/time_observer.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator.h"
-#include "time_observer.h"
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
@@ -33,9 +33,13 @@
public:
using OperatorBase::OperatorBase;
bool Run(int /* unused */) override {
- StartAllObservers();
+ if (observer_) {
+ observer_->Start();
+ }
std::this_thread::sleep_for(std::chrono::milliseconds(3000));
- StopAllObservers();
+ if (observer_) {
+ observer_->Stop();
+ }
return true;
}
};
@@ -76,9 +80,9 @@
unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
unique_ptr<TimeObserver<NetBase>> net_ob =
make_unique<TimeObserver<NetBase>>(net.get());
- const auto* ob = dynamic_cast_if_rtti<const TimeObserver<NetBase>*>(
- net->AttachObserver(std::move(net_ob)));
+ net->SetObserver(std::move(net_ob));
net->Run();
+ auto* ob = dynamic_cast_if_rtti<TimeObserver<NetBase>*>(net->GetObserver());
CAFFE_ENFORCE(ob);
LOG(INFO) << "av time children: " << ob->average_time_children();
LOG(INFO) << "av time: " << ob->average_time();
diff --git a/caffe2/core/net.cc b/caffe2/core/net.cc
index 36fd3bb..27b2f5e 100644
--- a/caffe2/core/net.cc
+++ b/caffe2/core/net.cc
@@ -124,7 +124,7 @@
}
VLOG(1) << "Adding a global observer to a net";
if (net) {
- net->AttachObserver(GlobalNetObserverCreator(net.get()));
+ net->SetObserver(GlobalNetObserverCreator(net.get()));
}
return net;
}
diff --git a/caffe2/core/net.h b/caffe2/core/net.h
index 7baaa50..c30f752 100644
--- a/caffe2/core/net.h
+++ b/caffe2/core/net.h
@@ -45,10 +45,9 @@
class OperatorBase;
class Workspace;
-
// Net is a thin struct that owns all the operators together with the operator
// contexts.
-class NetBase : public Observable<NetBase> {
+class NetBase {
public:
NetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
virtual ~NetBase() noexcept {}
@@ -100,6 +99,18 @@
*/
virtual vector<OperatorBase*> GetOperators() const = 0;
+ void SetObserver(std::unique_ptr<NetObserver> observer) {
+ observer_ = std::move(observer);
+ }
+
+ void RemoveObserver() {
+ observer_ = nullptr;
+ }
+
+ NetObserver* GetObserver() {
+ return observer_.get();
+ }
+
const string& Name() const {
return name_;
}
@@ -108,6 +119,7 @@
vector<string> external_input_;
vector<string> external_output_;
string name_;
+ std::unique_ptr<NetObserver> observer_;
vector<const Event*> events_;
DISABLE_COPY_AND_ASSIGN(NetBase);
diff --git a/caffe2/core/net_dag.cc b/caffe2/core/net_dag.cc
index da27eb2..b342056 100644
--- a/caffe2/core/net_dag.cc
+++ b/caffe2/core/net_dag.cc
@@ -434,8 +434,9 @@
}
bool DAGNetBase::RunAsync() {
- StartAllObservers();
-
+ if (observer_) {
+ observer_->Start();
+ }
// Lock run_in_progress_ to prevent concurrent Run()s.
std::unique_lock<std::mutex> run_lock(run_in_progress_);
VLOG(1) << "Running parallel net.";
@@ -498,8 +499,9 @@
op.operator_->debug_def().type(),
") has some runtime parents left.");
}
-
- StopAllObservers();
+ if (observer_) {
+ observer_->Stop();
+ }
// If the above while loop finished, we know that the current run finished.
return success_;
}
diff --git a/caffe2/core/net_simple.cc b/caffe2/core/net_simple.cc
index 07e04b4..1e95d11 100644
--- a/caffe2/core/net_simple.cc
+++ b/caffe2/core/net_simple.cc
@@ -58,8 +58,9 @@
}
bool SimpleNet::RunAsync() {
- StartAllObservers();
-
+ if (observer_) {
+ observer_->Start();
+ }
const auto& net_name = name_.c_str();
VLOG(1) << "Running net " << name_;
for (auto& op : operators_) {
@@ -76,7 +77,9 @@
return false;
}
}
- StopAllObservers();
+ if (observer_) {
+ observer_->Stop();
+ }
return true;
}
diff --git a/caffe2/core/net_simple_async.cc b/caffe2/core/net_simple_async.cc
index f6de367..3de39c1 100644
--- a/caffe2/core/net_simple_async.cc
+++ b/caffe2/core/net_simple_async.cc
@@ -59,8 +59,9 @@
}
bool AsyncSimpleNet::RunAsync() {
- StartAllObservers();
-
+ if (observer_) {
+ observer_->Start();
+ }
const auto& net_name = name_.c_str();
VLOG(1) << "Running net " << name_;
for (auto& op : operators_) {
@@ -77,7 +78,9 @@
return false;
}
}
- StopAllObservers();
+ if (observer_) {
+ observer_->Stop();
+ }
return true;
}
diff --git a/caffe2/core/observer.h b/caffe2/core/observer.h
index a2416e7..58f8270 100644
--- a/caffe2/core/observer.h
+++ b/caffe2/core/observer.h
@@ -44,46 +44,4 @@
T* subject_;
};
-/**
- * Inherit to make your class observable.
- */
-template <class T>
-class Observable {
- public:
- using Observer = ObserverBase<T>;
-
- /* Returns a reference to the observer after addition. */
- const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
- const Observer* weak_observer = observer.get();
- observers_[weak_observer] = std::move(observer);
- return weak_observer;
- }
-
- /* Returns a unique_ptr to the observer. */
- std::unique_ptr<Observer> DetachObserver(const Observer* observer) {
- std::unique_ptr<Observer> strong_observer = std::move(observers_[observer]);
- observers_.erase(observer);
- return strong_observer;
- }
-
- size_t NumObservers() {
- return observers_.size();
- }
-
- void StartAllObservers() {
- for (const auto& observer : observers_) {
- observer.second->Start();
- }
- }
-
- void StopAllObservers() {
- for (const auto& observer : observers_) {
- observer.second->Stop();
- }
- }
-
- protected:
- std::map<const Observer*, std::unique_ptr<ObserverBase<T>>> observers_;
-};
-
} // namespace caffe2
diff --git a/caffe2/core/observer_test.cc b/caffe2/core/observer_test.cc
index b069b22..43f2449 100644
--- a/caffe2/core/observer_test.cc
+++ b/caffe2/core/observer_test.cc
@@ -45,7 +45,7 @@
bool DummyObserver<NetBase>::Start() {
vector<OperatorBase*> operators = subject_->GetOperators();
for (auto& op : operators) {
- op->AttachObserver(caffe2::make_unique<DummyObserver<OperatorBase>>(op));
+ op->SetObserver(caffe2::make_unique<DummyObserver<OperatorBase>>(op));
}
counter.fetch_add(1000);
return true;
@@ -73,8 +73,10 @@
public:
using OperatorBase::OperatorBase;
bool Run(int /* unused */) override {
- StartAllObservers();
- StopAllObservers();
+ if (observer_)
+ observer_->Start();
+ if (observer_)
+ observer_->Stop();
return true;
}
};
@@ -120,29 +122,12 @@
EXPECT_EQ(caffe2::dynamic_cast_if_rtti<SimpleNet*>(net.get()), net.get());
unique_ptr<DummyObserver<NetBase>> net_ob =
make_unique<DummyObserver<NetBase>>(net.get());
- net.get()->AttachObserver(std::move(net_ob));
+ net.get()->SetObserver(std::move(net_ob));
net.get()->Run();
auto count_after = counter.load();
EXPECT_EQ(1212, count_after - count_before);
}
-TEST(ObserverTest, TestUniqueMap) {
- auto count_before = counter.load();
- Workspace ws;
- ws.CreateBlob("in");
- NetDef net_def;
- unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
- EXPECT_EQ(caffe2::dynamic_cast_if_rtti<SimpleNet*>(net.get()), net.get());
- unique_ptr<DummyObserver<NetBase>> net_ob =
- make_unique<DummyObserver<NetBase>>(net.get());
- auto* ref = net.get()->AttachObserver(std::move(net_ob));
- net.get()->Run();
- unique_ptr<Observable<NetBase>::Observer> test =
- net.get()->DetachObserver(ref);
- auto count_after = counter.load();
- EXPECT_EQ(1212, count_after - count_before);
-}
-
TEST(ObserverTest, TestNotifyAfterDetach) {
auto count_before = counter.load();
Workspace ws;
@@ -151,8 +136,8 @@
unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
unique_ptr<DummyObserver<NetBase>> net_ob =
make_unique<DummyObserver<NetBase>>(net.get());
- auto* ob = net.get()->AttachObserver(std::move(net_ob));
- net.get()->DetachObserver(ob);
+ net.get()->SetObserver(std::move(net_ob));
+ net.get()->RemoveObserver();
net.get()->Run();
auto count_after = counter.load();
EXPECT_EQ(0, count_after - count_before);
@@ -167,35 +152,9 @@
EXPECT_EQ(caffe2::dynamic_cast_if_rtti<DAGNetBase*>(net.get()), net.get());
unique_ptr<DummyObserver<NetBase>> net_ob =
make_unique<DummyObserver<NetBase>>(net.get());
- net.get()->AttachObserver(std::move(net_ob));
+ net.get()->SetObserver(std::move(net_ob));
net.get()->Run();
auto count_after = counter.load();
EXPECT_EQ(1212, count_after - count_before);
}
-
-TEST(ObserverTest, TestMultipleNetBase) {
- Workspace ws;
- ws.CreateBlob("in");
- NetDef net_def;
- unique_ptr<NetBase> net(CreateNetTestHelper(&ws, true));
- EXPECT_EQ(caffe2::dynamic_cast_if_rtti<NetBase*>(net.get()), net.get());
-
- // There may be some default observers
- const size_t prev_num = net.get()->NumObservers();
- const int num_tests = 100;
- vector<const Observable<NetBase>::Observer*> observers;
- for (int i = 0; i < num_tests; ++i) {
- unique_ptr<DummyObserver<NetBase>> net_ob =
- make_unique<DummyObserver<NetBase>>(net.get());
- observers.emplace_back(net.get()->AttachObserver(std::move(net_ob)));
- }
-
- net.get()->Run();
-
- for (const auto& observer : observers) {
- net.get()->DetachObserver(observer);
- }
-
- EXPECT_EQ(net.get()->NumObservers(), prev_num);
}
-} // namespace caffe2
diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h
index 809d8b9..0715f4f 100644
--- a/caffe2/core/operator.h
+++ b/caffe2/core/operator.h
@@ -38,10 +38,7 @@
namespace caffe2 {
-class OperatorBase;
-typedef ObserverBase<OperatorBase> OperatorObserver;
-
-class OperatorBase : public Observable<OperatorBase> {
+class OperatorBase {
public:
explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
virtual ~OperatorBase() noexcept {}
@@ -188,6 +185,14 @@
}
public:
+ void SetObserver(std::unique_ptr<ObserverBase<OperatorBase>> observer) {
+ observer_ = std::move(observer);
+ }
+
+ void RemoveObserver() {
+ observer_ = nullptr;
+ }
+
void RecordLastFailedOpNetPosition() {
if (net_position_ != kNoNetPositionSet) {
VLOG(1) << "Operator with id " << net_position_ << " failed";
@@ -229,6 +234,14 @@
public:
static constexpr int kNoNetPositionSet = -1;
+ ObserverBase<OperatorBase>* GetObserver() {
+ return observer_.get();
+ }
+
+ const ObserverBase<OperatorBase>* GetObserver() const {
+ return observer_.get();
+ }
+
private:
Workspace* operator_ws_;
std::shared_ptr<const OperatorDef> operator_def_;
@@ -240,6 +253,7 @@
int net_position_{kNoNetPositionSet};
protected:
+ std::unique_ptr<ObserverBase<OperatorBase>> observer_;
// An event used by asynchronous execution.
Event event_;
@@ -310,16 +324,18 @@
// instead of Run().
bool Run(int stream_id = 0) final {
try {
- StartAllObservers();
-
+ if (observer_) {
+ observer_->Start();
+ }
context_.SwitchToDevice(stream_id);
bool result = RunOnDevice();
if (!result) {
this->RecordLastFailedOpNetPosition();
}
context_.FinishDeviceComputation(); // throws on error
-
- StopAllObservers();
+ if (observer_) {
+ observer_->Stop();
+ }
return result;
} catch (EnforceNotMet& err) {
diff --git a/caffe2/observers/README.md b/caffe2/observers/README.md
deleted file mode 100644
index 3b85ffc..0000000
--- a/caffe2/observers/README.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# Observers
-
-## Usage
-
-Observers are a small framework that allow users to attach code to the execution of SimpleNets and Operators.
-
-An example of an Observer is the `TimeObserver`, used as follows:
-
-### C++
-
-```
-unique_ptr<TimeObserver<NetBase>> net_ob =
- make_unique<TimeObserver<NetBase>>(net.get());
-auto* ob = net->AttachObserver(std::move(net_ob));
-net->Run();
-LOG(INFO) << "av time children: " << ob->average_time_children();
-LOG(INFO) << "av time: " << ob->average_time();
-```
-
-### Python
-
-```
-model.net.AttachTimeObserver()
-ws.RunNet(model.net)
-ob = model.net.GetObserver()
-
-print("av time children:", ob.average_time_children())
-print("av time:", ob.average_time())
-```
-
-
-## Implementing An Observer
-
-To implement an observer you must inherit from `ObserverBase` and implement the `Start` and `Stop` functions.
-
-Observers are instantiated with a `subject` of a generic type, such as a `Net` or `Operator`. The observer framework is built to be generic enough to "observe" various other types, however.
diff --git a/caffe2/observers/time_observer.h b/caffe2/observers/time_observer.h
deleted file mode 100644
index 70eeef3..0000000
--- a/caffe2/observers/time_observer.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/**
- * Copyright (c) 2016-present, Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
-#define CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
-
-#include <unordered_map>
-
-#include "caffe2/core/common.h"
-#include "caffe2/core/net.h"
-#include "caffe2/core/observer.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/core/timer.h"
-
-namespace caffe2 {
-
-template <class T>
-class TimeObserverBase : public ObserverBase<T> {
- public:
- explicit TimeObserverBase<T>(T* subject) : ObserverBase<T>(subject) {}
- inline float average_time() const {
- return total_time_ / iterations_;
- }
- ~TimeObserverBase() {}
-
- bool Start() override;
- bool Stop() override;
-
- protected:
- Timer timer_;
- float start_time_ = 0.0f;
- float total_time_ = 0.0f;
- int iterations_ = 0;
-};
-
-template <class T>
-class TimeObserver final : public TimeObserverBase<T> {
- public:
- explicit TimeObserver<T>(T* subject) : TimeObserverBase<T>(subject) {}
-};
-
-template <>
-class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> {
- public:
- explicit TimeObserver<NetBase>(NetBase* subject)
- : TimeObserverBase<NetBase>(subject) {}
- float average_time_children() const {
- float sum = 0.0f;
- for (const auto* observer : operator_observers_) {
- sum += observer->average_time();
- }
- return sum / subject_->GetOperators().size();
- }
-
- bool Start() override {
- for (auto* op : subject_->GetOperators()) {
- const auto* observer = op->AttachObserver(
- caffe2::make_unique<TimeObserver<OperatorBase>>(op));
- CAFFE_ENFORCE(observer != nullptr);
- operator_observers_.push_back(
- dynamic_cast_if_rtti<const TimeObserver<OperatorBase>*>(observer));
- }
- start_time_ = timer_.MilliSeconds();
- ++iterations_;
- return true;
- }
-
- private:
- vector<const TimeObserver<OperatorBase>*> operator_observers_;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
diff --git a/caffe2/python/core.py b/caffe2/python/core.py
index 0246b23..4d59b1e 100644
--- a/caffe2/python/core.py
+++ b/caffe2/python/core.py
@@ -1875,16 +1875,6 @@
* [ScopedBlobReference(b) for b in outputs]
)
- # This returns a reference to the observer
- def AddObserver(self, observer_type):
- return C.add_observer_to_net(self._net.name, observer_type)
-
- def RemoveObserver(self, observer):
- C.remove_observer_from_net(self._net.name, observer)
-
- def NumObservers(self):
- return C.num_observers_on_net(self._net.name)
-
@property
def external_inputs(self):
return [_get_blob_ref(x) for x in self._net.external_input]
diff --git a/caffe2/python/observer_test.py b/caffe2/python/observer_test.py
deleted file mode 100644
index 2184dc9..0000000
--- a/caffe2/python/observer_test.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-import numpy as np
-import unittest
-
-from caffe2.python import model_helper, brew
-import caffe2.python.workspace as ws
-
-
-class TestObservers(unittest.TestCase):
- def setUp(self):
- ws.ResetWorkspace()
- self.model = model_helper.ModelHelper()
- brew.fc(self.model, "data", "y",
- dim_in=4, dim_out=2,
- weight_init=('ConstantFill', dict(value=1.0)),
- bias_init=('ConstantFill', dict(value=0.0)),
- axis=0)
- ws.FeedBlob("data", np.zeros([4], dtype='float32'))
-
- ws.RunNetOnce(self.model.param_init_net)
- ws.CreateNet(self.model.net)
-
- def testObserver(self):
- ob = self.model.net.AddObserver("TimeObserver")
- ws.RunNet(self.model.net)
- print(ob.average_time())
- num = self.model.net.NumObservers()
- self.model.net.RemoveObserver(ob)
- assert(self.model.net.NumObservers() + 1 == num)
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 8c54ce9..cfaa786 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -25,7 +25,6 @@
#include "caffe2/core/predictor.h"
#include "caffe2/core/transform.h"
#include "caffe2/mkl/mkl_utils.h"
-#include "caffe2/observers/time_observer.h"
#include "caffe2/utils/cpuid.h"
#include "caffe2/utils/string_utils.h"
#include "google/protobuf/io/coded_stream.h"
@@ -377,21 +376,6 @@
CAFFE_ENFORCE(net->Run());
});
- py::class_<ObserverBase<NetBase>>(m, "Observer")
- .def(
- "average_time",
- [](ObserverBase<NetBase>* ob) {
- auto* cast_ob = dynamic_cast_if_rtti<TimeObserver<NetBase>*>(ob);
- CAFFE_ENFORCE(
- cast_ob, "Observer does not implement this function.");
- return cast_ob->average_time();
- })
- .def("average_time_children", [](ObserverBase<NetBase>* ob) {
- auto* cast_ob = dynamic_cast_if_rtti<TimeObserver<NetBase>*>(ob);
- CAFFE_ENFORCE(cast_ob, "Observer does not implement this function.");
- return cast_ob->average_time_children();
- });
-
py::class_<Blob>(m, "Blob")
.def(
"serialize",
@@ -902,51 +886,6 @@
return true;
});
m.def(
- "add_observer_to_net",
- [](const std::string& net_name, const std::string& observer_type) {
- CAFFE_ENFORCE(gWorkspace);
- CAFFE_ENFORCE(
- gWorkspace->GetNet(net_name), "Can't find net ", net_name);
- py::gil_scoped_release g;
-
- NetBase* net = gWorkspace->GetNet(net_name);
- const Observable<NetBase>::Observer* observer = nullptr;
-
-#define REGISTER_PYTHON_EXPOSED_OBSERVER(ob_type) \
- { \
- if (observer_type.compare(#ob_type) == 0) { \
- unique_ptr<ob_type<NetBase>> net_ob = \
- make_unique<ob_type<NetBase>>(net); \
- observer = net->AttachObserver(std::move(net_ob)); \
- } \
- }
-
- REGISTER_PYTHON_EXPOSED_OBSERVER(TimeObserver);
-
-#undef REGISTER_PYTHON_EXPOSED_OBSERVER
- CAFFE_ENFORCE(observer != nullptr);
- return py::cast(observer);
- });
- m.def(
- "remove_observer_from_net",
- [](const std::string& net_name, const ObserverBase<NetBase>* observer) {
- CAFFE_ENFORCE(gWorkspace);
- CAFFE_ENFORCE(
- gWorkspace->GetNet(net_name), "Can't find net ", net_name);
- py::gil_scoped_release g;
-
- NetBase* net = gWorkspace->GetNet(net_name);
- net->DetachObserver(observer);
- });
- m.def("num_observers_on_net", [](const std::string& net_name) {
- CAFFE_ENFORCE(gWorkspace);
- CAFFE_ENFORCE(gWorkspace->GetNet(net_name), "Can't find net ", net_name);
- py::gil_scoped_release g;
-
- NetBase* net = gWorkspace->GetNet(net_name);
- return net->NumObservers();
- });
- m.def(
"benchmark_net",
[](const std::string& name,
size_t warmup_runs,
diff --git a/caffe2/share/contrib/observers/perf_observer.cc b/caffe2/share/contrib/observers/perf_observer.cc
index ad8ce83..000883b 100644
--- a/caffe2/share/contrib/observers/perf_observer.cc
+++ b/caffe2/share/contrib/observers/perf_observer.cc
@@ -50,8 +50,7 @@
whenever we measure operator delay */
const auto& operators = subject_->GetOperators();
for (auto* op : operators) {
- observerMap_[op] = op->AttachObserver(
- caffe2::make_unique<PerfOperatorObserver>(op, this));
+ op->SetObserver(caffe2::make_unique<PerfOperatorObserver>(op, this));
}
}
@@ -73,7 +72,7 @@
for (int idx = 0; idx < operators.size(); ++idx) {
const auto* op = operators[idx];
auto name = getObserverName(op, idx);
- double delay = static_cast<const PerfOperatorObserver*>(observerMap_[op])
+ double delay = static_cast<const PerfOperatorObserver*>(op->GetObserver())
->getMilliseconds();
std::pair<std::string, double> name_delay_pair = {name, delay};
operator_delays.push_back(name_delay_pair);
@@ -83,9 +82,8 @@
/* clear all operator delay after use so that we don't spent time
collecting the operator delay info in later runs */
for (auto* op : operators) {
- op->DetachObserver(observerMap_[op]);
+ op->RemoveObserver();
}
- observerMap_.clear();
}
return true;
}
diff --git a/caffe2/share/contrib/observers/perf_observer.h b/caffe2/share/contrib/observers/perf_observer.h
index 60702cb..7ac4ed2 100644
--- a/caffe2/share/contrib/observers/perf_observer.h
+++ b/caffe2/share/contrib/observers/perf_observer.h
@@ -1,11 +1,8 @@
#pragma once
#include "caffe2/core/net.h"
-#include "caffe2/core/observer.h"
#include "caffe2/core/timer.h"
-#include <unordered_map>
-
namespace caffe2 {
class PerfNetObserver : public NetObserver {
@@ -31,7 +28,6 @@
};
LogType logType_;
unsigned int numRuns_;
- std::unordered_map<const OperatorBase*, ObserverBase*> observerMap_;
caffe2::Timer timer_;
};