model_dump: Handle dict rendering (#57657)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57657
Test Plan: Clicked around a model with some dicts in it.
Reviewed By: malfet
Differential Revision: D28531397
Pulled By: dreiss
fbshipit-source-id: 069690f147e91eadd76fec5f5ca4eec057abcb98
diff --git a/torch/utils/model_dump/code.js b/torch/utils/model_dump/code.js
index 4380ead..d5edb80 100644
--- a/torch/utils/model_dump/code.js
+++ b/torch/utils/model_dump/code.js
@@ -124,6 +124,10 @@
// TODO: Maybe show simple lists and tuples on one line.
return true;
}
+ if (data.__is_dict__) {
+ // TODO: Maybe show simple (empty?) dicts on one line.
+ return true;
+ }
if (data.__module_type__) {
return true;
}
@@ -133,7 +137,7 @@
if (data.__qtensor__) {
return false;
}
- throw new Error("TODO: handle dict, etc.");
+ throw new Error("Can't handle data type.", data);
}
renderHeadline(data) {
@@ -159,6 +163,9 @@
if (data.__tuple_values__) {
return "tuple((";
}
+ if (data.__is_dict__) {
+ return "dict({";
+ }
if (data.__module_type__) {
return data.__module_type__ + "()";
}
@@ -181,7 +188,7 @@
return this.renderTensor(
"qtensor", dtype, key, device, numel, offset, size, stride, grad, extra_parts);
}
- throw new Error("TODO: handle dict, etc.");
+ throw new Error("Can't handle data type.", data);
}
renderTensor(
@@ -237,6 +244,18 @@
// Handled the same as lists.
return this.renderBody(indent, data.__tuple_values__);
}
+ if (data.__is_dict__) {
+ let new_indent = indent + "\u00A0\u00A0";
+ let parts = [];
+ for (let idx = 0; idx < data.keys.length; idx++) {
+ if (typeof(data.keys[idx]) != "string") {
+ parts.push(html`<br/>${new_indent}Non-string key`);
+ } else {
+ parts.push(html`<br/><${ModelData} prefix=${data.keys[idx] + ": "} indent=${new_indent} data=${data.values[idx]} />`);
+ }
+ }
+ return parts;
+ }
if (data.__module_type__) {
const mstate = data.state;
if (mstate === null || typeof(mstate) != "object") {
@@ -245,6 +264,7 @@
let new_indent = indent + "\u00A0\u00A0";
let parts = [];
if (mstate.__is_dict__) {
+ // TODO: Less copy/paste between this and normal dicts.
for (let idx = 0; idx < mstate.keys.length; idx++) {
if (typeof(mstate.keys[idx]) != "string") {
parts.push(html`<br/>${new_indent}Non-string key`);
@@ -267,7 +287,7 @@
if (data.__qtensor__) {
throw "Should not reach here."
}
- throw new Error("TODO: handle dict, etc.");
+ throw new Error("Can't handle data type.", data);
}
render({data, indent, prefix}, {shown}) {