| local File = torch.getmetatable('torch.File') |
| |
| function File:writeBool(value) |
| if value then |
| self:writeInt(1) |
| else |
| self:writeInt(0) |
| end |
| end |
| |
| function File:readBool() |
| return (self:readInt() == 1) |
| end |
| |
| local TYPE_NIL = 0 |
| local TYPE_NUMBER = 1 |
| local TYPE_STRING = 2 |
| local TYPE_TABLE = 3 |
| local TYPE_TORCH = 4 |
| local TYPE_BOOLEAN = 5 |
| local TYPE_FUNCTION = 6 |
| |
| function File:isWritableObject(object) |
| local typename = type(object) |
| local typeidx |
| if type(object) ~= 'boolean' and not object then |
| typeidx = TYPE_NIL |
| elseif torch.typename(object) and torch.factory(torch.typename(object)) then |
| typeidx = TYPE_TORCH |
| elseif typename == 'table' then |
| typeidx = TYPE_TABLE |
| elseif typename == 'number' then |
| typeidx = TYPE_NUMBER |
| elseif typename == 'string' then |
| typeidx = TYPE_STRING |
| elseif typename == 'boolean' then |
| typeidx = TYPE_BOOLEAN |
| elseif typename == 'function' and pcall(string.dump, object) then |
| typeidx = TYPE_FUNCTION |
| end |
| return typeidx |
| end |
| |
| function File:writeObject(object) |
| -- we use an environment to keep a record of written objects |
| if not torch.getenv(self).writeObjects then |
| torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) |
| end |
| |
| -- if nil object, only write the type and return |
| if type(object) ~= 'boolean' and not object then |
| self:writeInt(TYPE_NIL) |
| return |
| end |
| |
| -- check the type we are dealing with |
| local typeidx = self:isWritableObject(object) |
| if not typeidx then |
| error(string.format('Unwritable object <%s>', type(object))) |
| end |
| self:writeInt(typeidx) |
| |
| if typeidx == TYPE_NUMBER then |
| self:writeDouble(object) |
| elseif typeidx == TYPE_BOOLEAN then |
| self:writeBool(object) |
| elseif typeidx == TYPE_STRING then |
| local stringStorage = torch.CharStorage():string(object) |
| self:writeInt(#stringStorage) |
| self:writeChar(stringStorage) |
| elseif typeidx == TYPE_FUNCTION then |
| local upvalues = {} |
| while true do |
| local name,value = debug.getupvalue(object, #upvalues+1) |
| if not name then break end |
| table.insert(upvalues, value) |
| end |
| local dumped = string.dump(object) |
| local stringStorage = torch.CharStorage():string(dumped) |
| self:writeInt(#stringStorage) |
| self:writeChar(stringStorage) |
| self:writeObject(upvalues) |
| elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE then |
| -- check it exists already (we look at the pointer!) |
| local objects = torch.getenv(self).writeObjects |
| local objectsRef = torch.getenv(self).writeObjectsRef |
| local index = objects[torch.pointer(object)] |
| |
| if index then |
| -- if already exists, write only its index |
| self:writeInt(index) |
| else |
| -- else write the object itself |
| index = objects.nWriteObject or 0 |
| index = index + 1 |
| objects[torch.pointer(object)] = index |
| objectsRef[object] = index -- we make sure the object is not going to disappear |
| self:writeInt(index) |
| objects.nWriteObject = index |
| |
| if typeidx == TYPE_TORCH then |
| local version = torch.CharStorage():string('V ' .. torch.version(object)) |
| local className = torch.CharStorage():string(torch.typename(object)) |
| self:writeInt(#version) |
| self:writeChar(version) |
| self:writeInt(#className) |
| self:writeChar(className) |
| if object.write then |
| object:write(self) |
| elseif type(object) == 'table' then |
| local var = {} |
| for k,v in pairs(object) do |
| if self:isWritableObject(v) then |
| var[k] = v |
| else |
| print(string.format('$ Warning: cannot write object field <%s>', k)) |
| end |
| end |
| self:writeObject(var) |
| else |
| error(string.format('<%s> is a non-serializable Torch object', torch.typename(object))) |
| end |
| else -- it is a table |
| local size = 0; for k,v in pairs(object) do size = size + 1 end |
| self:writeInt(size) |
| for k,v in pairs(object) do |
| self:writeObject(k) |
| self:writeObject(v) |
| end |
| end |
| end |
| else |
| error('Unwritable object') |
| end |
| end |
| |
| function File:readObject() |
| -- we use an environment to keep a record of read objects |
| if not torch.getenv(self).writeObjects then |
| torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) |
| end |
| |
| -- read the typeidx |
| local typeidx = self:readInt() |
| |
| -- is it nil? |
| if typeidx == TYPE_NIL then |
| return nil |
| end |
| |
| if typeidx == TYPE_NUMBER then |
| return self:readDouble() |
| elseif typeidx == TYPE_BOOLEAN then |
| return self:readBool() |
| elseif typeidx == TYPE_STRING then |
| local size = self:readInt() |
| return self:readChar(size):string() |
| elseif typeidx == TYPE_FUNCTION then |
| local size = self:readInt() |
| local dumped = self:readChar(size):string() |
| local func = loadstring(dumped) |
| local upvalues = self:readObject() |
| for index,upvalue in ipairs(upvalues) do |
| debug.setupvalue(func, index, upvalue) |
| end |
| return func |
| elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH then |
| -- read the index |
| local index = self:readInt() |
| |
| -- check it is loaded already |
| local objects = torch.getenv(self).readObjects |
| if objects[index] then |
| return objects[index] |
| end |
| |
| -- otherwise read it |
| if typeidx == TYPE_TORCH then |
| local version, className, versionNumber |
| version = self:readChar(self:readInt()):string() |
| versionNumber = tonumber(string.match(version, '^V (.*)$')) |
| if not versionNumber then |
| className = version |
| versionNumber = 0 -- file created before existence of versioning system |
| else |
| className = self:readChar(self:readInt()):string() |
| end |
| if not torch.factory(className) then |
| error(string.format('unknown Torch class <%s>', tostring(className))) |
| end |
| local object = torch.factory(className)() |
| objects[index] = object |
| if object.read then |
| object:read(self, versionNumber) |
| elseif type(object) == 'table' then |
| local var = self:readObject() |
| for k,v in pairs(var) do |
| object[k] = v |
| end |
| else |
| error(string.format('Cannot load object class <%s>', tostring(className))) |
| end |
| return object |
| else -- it is a table |
| local size = self:readInt() |
| local object = {} |
| objects[index] = object |
| for i = 1,size do |
| local k = self:readObject() |
| local v = self:readObject() |
| object[k] = v |
| end |
| return object |
| end |
| else |
| error('unknown object') |
| end |
| end |
| |
| -- simple helpers to save/load arbitrary objects/tables |
| function torch.save(filename, object, mode) |
| mode = mode or 'binary' |
| local file = torch.DiskFile(filename, 'w') |
| file[mode](file) |
| file:writeObject(object) |
| file:close() |
| end |
| |
| function torch.load(filename, mode) |
| mode = mode or 'binary' |
| local file = torch.DiskFile(filename, 'r') |
| file[mode](file) |
| local object = file:readObject() |
| file:close() |
| return object |
| end |
| |
| -- simple helpers to serialize/deserialize arbitrary objects/tables |
| function torch.serialize(object) |
| local f = torch.MemoryFile() |
| f:writeObject(object) |
| local s = f:storage():string() |
| f:close() |
| return s |
| end |
| |
| function torch.deserialize(str) |
| local x = torch.CharStorage():string(str) |
| local tx = torch.CharTensor(x) |
| local xp = torch.CharStorage(x:size(1)+1) |
| local txp = torch.CharTensor(xp) |
| txp:narrow(1,1,tx:size(1)):copy(tx) |
| txp[tx:size(1)+1] = 0 |
| local f = torch.MemoryFile(xp) |
| local object = f:readObject() |
| f:close() |
| return object |
| end |
| |
| -- public API (saveobj/loadobj are safe for global import) |
| torch.saveobj = torch.save |
| torch.loadobj = torch.load |