# Source blob: 4075d918da851328ac826c84ede82d45f0b4eaff
import sys
import os
import torch
# File the scripted dropout model is saved to by
# testEvalModeForLoadedModule_setup() and deleted from by
# testEvalModeForLoadedModule_shutdown(). Relative, so it resolves against
# the current working directory of the process.
testEvalModeForLoadedModule_module_path = "dropout_model.pt"
def testEvalModeForLoadedModule_setup():
    """Create the on-disk fixture: a scripted dropout model saved in train mode.

    A small ``torch.jit.ScriptModule`` wrapping ``Dropout(0.1)`` is built,
    switched to training mode, and serialized to
    ``testEvalModeForLoadedModule_module_path``.
    """
    # NOTE(review): judging by the name, the consumer presumably checks which
    # mode the module is in after loading -- confirm against the test that
    # reads this file.

    class Model(torch.jit.ScriptModule):
        """Scripted module whose forward pass is a single dropout layer."""

        def __init__(self):
            super(Model, self).__init__()
            self.dropout = torch.nn.Dropout(0.1)

        @torch.jit.script_method
        def forward(self, x):
            return self.dropout(x)

    # train() hands back the module, so construction and the mode switch
    # can be chained before saving.
    scripted = Model().train()
    scripted.save(testEvalModeForLoadedModule_module_path)
def testEvalModeForLoadedModule_shutdown():
    """Remove the saved model file, doing nothing when it is not there."""
    artifact = testEvalModeForLoadedModule_module_path
    if not os.path.exists(artifact):
        # Nothing was written (or it is already gone); nothing to clean up.
        return
    os.remove(artifact)
def setup():
    """Run every fixture-creation step (currently only the dropout model)."""
    testEvalModeForLoadedModule_setup()
def shutdown():
    """Run every fixture-cleanup step (currently only the dropout model)."""
    testEvalModeForLoadedModule_shutdown()
if __name__ == "__main__":
    # Dispatch on the first command-line argument. Commands that are not
    # listed here are ignored without any output.
    handlers = {"setup": setup, "shutdown": shutdown}
    action = handlers.get(sys.argv[1])
    if action is not None:
        action()