More trt tests (#6782)
diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc
index 500edbb..e4aedd1 100644
--- a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc
+++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc
@@ -170,7 +170,7 @@
OperatorDef op;
op.set_type("TensorRT");
- tensorrt::TrtLogger logger;
+ tensorrt::TrtLogger logger((nvinfer1::ILogger::Severity)(verbosity_));
auto trt_builder = tensorrt::TrtObject(nvinfer1::createInferBuilder(logger));
auto trt_network = tensorrt::TrtObject(trt_builder->createNetwork());
auto importer =
diff --git a/caffe2/python/trt/test_trt.py b/caffe2/python/trt/test_trt.py
index 2e32162..c5d1747 100644
--- a/caffe2/python/trt/test_trt.py
+++ b/caffe2/python/trt/test_trt.py
@@ -113,29 +113,63 @@
X = np.random.randn(52, 1, 3, 2).astype(np.float32)
self._test_relu_graph(X, 52, 50)
-
- @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
- def test_resnet50(self):
- input_blob_dims = (1, 3, 224, 224)
- model_dir = _download_onnx_model('resnet50')
+ def _test_onnx_importer(self, model_name, data_input_index = 0):
+ model_dir = _download_onnx_model(model_name)
model_def = onnx.load(os.path.join(model_dir, 'model.onnx'))
+ input_blob_dims = [int(x.dim_value) for x in model_def.graph.input[data_input_index].type.tensor_type.shape.dim]
op_inputs = [x.name for x in model_def.graph.input]
op_outputs = [x.name for x in model_def.graph.output]
- n, c, h, w = input_blob_dims
- data = np.random.randn(n, c, h, w).astype(np.float32)
- Y_c2 = c2.run_model(model_def, {op_inputs[0]: data})
- op = convert_onnx_model_to_trt_op(model_def)
+ print("{}".format(op_inputs))
+ data = np.random.randn(*input_blob_dims).astype(np.float32)
+ Y_c2 = c2.run_model(model_def, {op_inputs[data_input_index]: data})
+ op = convert_onnx_model_to_trt_op(model_def, verbosity=3)
device_option = core.DeviceOption(caffe2_pb2.CUDA, 0)
op.device_option.CopyFrom(device_option)
Y_trt = None
ws = Workspace()
with core.DeviceScope(device_option):
- ws.FeedBlob(op_inputs[0], data)
+ ws.FeedBlob(op_inputs[data_input_index], data)
ws.RunOperatorsOnce([op])
output_values = [ws.FetchBlob(name) for name in op_outputs]
Y_trt = namedtupledict('Outputs', op_outputs)(*output_values)
np.testing.assert_allclose(Y_c2, Y_trt, rtol=1e-3)
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_resnet50(self):
+ self._test_onnx_importer('resnet50')
+
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_bvlc_alexnet(self):
+ self._test_onnx_importer('bvlc_alexnet')
+
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_densenet121(self):
+ self._test_onnx_importer('densenet121', -1)
+
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_inception_v1(self):
+ self._test_onnx_importer('inception_v1', -1)
+
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_inception_v2(self):
+ self._test_onnx_importer('inception_v2')
+
+ # Doesn't work yet due to recent change of reshape definition of Reshape node in ONNX
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_shufflenet(self):
+ self._test_onnx_importer('shufflenet')
+
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_squeezenet(self):
+ self._test_onnx_importer('squeezenet', -1)
+
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_vgg16(self):
+ self._test_onnx_importer('vgg16')
+
+ @unittest.skipIf('TEST_C2_TRT' not in os.environ, "No TensortRT support")
+ def test_vgg19(self):
+ self._test_onnx_importer('vgg19', -1)
class TensorRTTransformTest(TestCase):
def _model_dir(self, model):
@@ -155,7 +189,7 @@
downloadFromURLToFile(url, dest,
show_progress=False)
except TypeError:
-# show_progress not supported prior to
+ # show_progress not supported prior to
# Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
# (Sep 17, 2017)
downloadFromURLToFile(url, dest)