-- Code generator for the torch.conv2 / xcorr2 / conv3 / xcorr3 dispatch
-- functions.
-- NOTE(review): `wrap` is used below but not required by this file; presumably
-- the build loads it before running this script. Re-enable the next line to
-- load it from here instead.
-- require 'wrap'

local generator = wrap.CInterface.new()

-- Dispatch bookkeeping, filled by interface:wrap and drained by
-- interface:dispatchregister. String keys mark names already emitted; the
-- array part holds {name=..., wrapname=...} records in emission order.
generator.dispatchregistry = {}

interface = generator
--- Emit the C dispatch function `torch_<name>` the first time `name` is seen.
-- The generated C function chooses a tensor type from, in order: the 1st
-- argument, the 2nd argument, a trailing type-name string, or the default
-- tensor type. It then fetches `<name>` from the table found at stack index -2
-- (NOTE(review): presumably pushed by torch_istensorid; confirm) and calls it
-- with the remaining arguments. Otherwise it raises "... does not implement".
-- Each new name is also recorded in `self.dispatchregistry` so that
-- dispatchregister can emit the matching luaL_Reg table.
-- @param name  Lua-side function name (e.g. "conv2"); substituted verbatim
--              for NAME in the C template.
-- @param ...   per-type wrap arguments (C function name, argument specs).
--              Currently ignored, because the CInterface.wrap call is disabled.
function interface:wrap(name, ...)
   -- usual stuff (per-type wrapper generation) is disabled.
   -- NOTE(review): presumably the per-type C functions are written by hand
   -- elsewhere; confirm before re-enabling this line.
   --wrap.CInterface.wrap(self, name, ...)

   -- The registry doubles as a set (registry[name] = true) and as the ordered
   -- list of entries consumed by dispatchregister.
   local registry = self.dispatchregistry
   if registry[name] then
      return
   end
   registry[name] = true
   table.insert(registry, {name=name, wrapname=string.format("torch_%s", name)})

   local template = [[
static int torch_NAME(lua_State *L)
{
int narg = lua_gettop(L);
const void *id;

if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */
{
if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */
{
if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */
lua_pop(L, 1);
else if(!(id = torch_istensorid(L, torch_getdefaulttensorid())))
luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor");
}
}

lua_pushstring(L, "NAME");
lua_rawget(L, -2);
if(lua_isfunction(L, -1))
{
lua_insert(L, 1);
lua_pop(L, 2); /* the two tables we put on the stack above */
lua_call(L, lua_gettop(L)-1, LUA_MULTRET);
}
else
return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id));

return lua_gettop(L);
}
]]

   -- '%' is special in a gsub replacement string; double it so `name` is
   -- inserted literally. The parentheses keep only gsub's first result,
   -- dropping the substitution count.
   local literal = (string.gsub(name, '%%', '%%%%'))
   self:print((string.gsub(template, 'NAME', literal)))
end
| |
--- Emit `static const struct luaL_Reg <name> [] = { ... };`.
-- The table lists every dispatch function recorded by interface:wrap since the
-- last call. It is terminated by a {NULL, NULL} sentinel and followed by a
-- blank line. Afterwards the registry is reset.
-- @param name  C identifier of the emitted luaL_Reg array.
function interface:dispatchregister(name)
   local out = self.txt
   local function emit(line)
      out[#out + 1] = line
   end

   emit(string.format('static const struct luaL_Reg %s [] = {', name))

   local entries = self.dispatchregistry
   for i = 1, #entries do
      local entry = entries[i]
      emit(string.format('{"%s", %s},', entry.name, entry.wrapname))
   end

   emit('{NULL, NULL}')
   emit('};')
   emit('')

   -- Start a fresh registry for any later batch of dispatch functions.
   self.dispatchregistry = {}
end
| |
-- Banner at the top of the generated C file, followed by an empty line.
for _, line in ipairs({'/* WARNING: autogenerated file */', ''}) do
   interface:print(line)
end
| |
-- For every tensor type, request the convolution-family wrappers.
-- interface:wrap emits each torch_<name> dispatcher only once, so the
-- dispatchers are generated on the first (FloatTensor) iteration. Later
-- iterations only matter if per-type wrapping is re-enabled in interface:wrap.
for _, Tensor in ipairs({"FloatTensor", "DoubleTensor", "IntTensor", "LongTensor", "ByteTensor", "CharTensor", "ShortTensor"}) do

   -- Maps a Lua-side name to torch_<Tensor>_<name>. It is redefined on each
   -- iteration so that it captures the current Tensor.
   -- NOTE(review): presumably consulted by wrap.CInterface when generating
   -- per-type wrappers; confirm.
   function interface.luaname2wrapname(self, name)
      return string.format('torch_%s_%s', Tensor, name)
   end

   -- TH-level C function name for `name` on the current tensor type.
   local function cname(name)
      return string.format('TH%s_%s', Tensor, name)
   end

   for _, name in ipairs({"conv2", "xcorr2", "conv3", "xcorr3"}) do
      interface:wrap(name,
                     cname(name),
                     {{name=Tensor, default=true, returned=true},
                      {name=Tensor, default=true, returned=true},
                      {name=Tensor},
                      {name=Tensor}})
   end

   -- NOTE(review): no per-type registration (interface:register) or
   -- luaL_register-based init function is generated for these tensor types.
   -- A previously commented-out block here was copied from the Lapack wrapper
   -- and referenced torch_<Tensor>Lapack__ / torch_TensorLapack_init.
end
| |
-- Close the batch: emit the luaL_Reg table listing every dispatcher above.
interface:dispatchregister("torch_TensorConv__")

-- Write the generated C to the file named by the first command-line argument,
-- or to standard output when none is given. `arg` is the global set by the
-- standalone lua interpreter; it may be absent when this file is loaded some
-- other way, so guard before indexing it.
local outfile = arg and arg[1]
if outfile then
   interface:tofile(outfile)
else
   interface:tostdio()
end