| #include "general.h" |
| #include "utils.h" |
| |
| #include <sys/time.h> |
| |
/* Type id of torch.LongStorage; filled in by torch_utils_init(). */
static const void *torch_LongStorage_id = NULL;

/* Type id of the current default tensor type. NULL until it is set through
   __setdefaulttensortype (Lua) or torch_setdefaulttensorid() (C). */
static const void *torch_default_tensor_id = NULL;
| |
| THLongStorage* torch_checklongargs(lua_State *L, int index) |
| { |
| THLongStorage *storage; |
| int i; |
| int narg = lua_gettop(L)-index+1; |
| |
| if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id)) |
| { |
| THLongStorage *storagesrc = luaT_toudata(L, index, torch_LongStorage_id); |
| storage = THLongStorage_newWithSize(storagesrc->size); |
| THLongStorage_copy(storage, storagesrc); |
| } |
| else |
| { |
| storage = THLongStorage_newWithSize(narg); |
| for(i = index; i < index+narg; i++) |
| { |
| if(!lua_isnumber(L, i)) |
| { |
| THLongStorage_free(storage); |
| luaL_argerror(L, i, "number expected"); |
| } |
| THLongStorage_set(storage, i-index, lua_tonumber(L, i)); |
| } |
| } |
| return storage; |
| } |
| |
| int torch_islongargs(lua_State *L, int index) |
| { |
| int narg = lua_gettop(L)-index+1; |
| |
| if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id)) |
| { |
| return 1; |
| } |
| else |
| { |
| int i; |
| |
| for(i = index; i < index+narg; i++) |
| { |
| if(!lua_isnumber(L, i)) |
| return 0; |
| } |
| return 1; |
| } |
| return 0; |
| } |
| |
| |
| |
| static int torch_lua_tic(lua_State* L) |
| { |
| struct timeval tv; |
| gettimeofday(&tv,NULL); |
| double ttime = (double)tv.tv_sec + (double)(tv.tv_usec)/1000000.0; |
| lua_pushnumber(L,ttime); |
| return 1; |
| } |
| |
| static int torch_lua_toc(lua_State* L) |
| { |
| struct timeval tv; |
| gettimeofday(&tv,NULL); |
| double toctime = (double)tv.tv_sec + (double)(tv.tv_usec)/1000000.0; |
| lua_Number tictime = luaL_checknumber(L,1); |
| lua_pushnumber(L,toctime-tictime); |
| return 1; |
| } |
| |
| static int torch_lua_setdefaulttensortype(lua_State *L) |
| { |
| const void *id; |
| |
| luaL_checkstring(L, 1); |
| |
| if(!(id = luaT_typename2id(L, lua_tostring(L, 1)))) \ |
| return luaL_error(L, "<%s> is not a string describing a torch object", lua_tostring(L, 1)); \ |
| |
| torch_default_tensor_id = id; |
| |
| return 0; |
| } |
| |
| static int torch_lua_getdefaulttensortype(lua_State *L) |
| { |
| lua_pushstring(L, luaT_id2typename(L, torch_default_tensor_id)); |
| return 1; |
| } |
| |
| void torch_setdefaulttensorid(const void* id) |
| { |
| torch_default_tensor_id = id; |
| } |
| |
| const void* torch_getdefaulttensorid() |
| { |
| return torch_default_tensor_id; |
| } |
| |
| static const struct luaL_Reg torch_utils__ [] = { |
| {"__setdefaulttensortype", torch_lua_setdefaulttensortype}, |
| {"getdefaulttensortype", torch_lua_getdefaulttensortype}, |
| {"tic", torch_lua_tic}, |
| {"toc", torch_lua_toc}, |
| {"factory", luaT_lua_factory}, |
| {"getconstructortable", luaT_lua_getconstructortable}, |
| {"id", luaT_lua_id}, |
| {"typename", luaT_lua_typename}, |
| {"typename2id", luaT_lua_typename2id}, |
| {"isequal", luaT_lua_isequal}, |
| {"getenv", luaT_lua_getenv}, |
| {"setenv", luaT_lua_setenv}, |
| {"newmetatable", luaT_lua_newmetatable}, |
| {"setmetatable", luaT_lua_setmetatable}, |
| {"getmetatable", luaT_lua_getmetatable}, |
| {"version", luaT_lua_version}, |
| {"pointer", luaT_lua_pointer}, |
| {NULL, NULL} |
| }; |
| |
| void torch_utils_init(lua_State *L) |
| { |
| torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); |
| luaL_register(L, NULL, torch_utils__); |
| } |