Add THNN conversion of {RReLU, Sigmoid, SmoothL1Criterion,SoftMax, SoftPlus}
7 files changed
tree: c49f1db588cc4defe4d9dec8528b1b1fb4dbf678
  1. generic/
  2. CMakeLists.txt
  3. init.c
  4. README.md
  5. THNN.h
README.md

API design guidelines

All functions should accept arguments in the following order. Dots represent any module-specific parameters or buffers, disregarding whether they are used for writing or reading. They should follow the order

[weight], [bias], [any buffers], [additional arguments], [optional arugments]

Modules

updateOutput: state, input, output, ...
updateGradInput: state, input, gradOutput, gradInput, ...
accGradParameters: state, input, gradOutput, [gradWeight], [gradBias], ...

e.g.

void THNN_(HardShrink_updateGradInput)(
          THNNState* state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput,
          real lambda)

Criterions

updateOutput: state, input, target, output, ...
updateGradInput: state, input, target, gradInput, ...

e.g.

void THNN_(ClassNLLCriterion_updateOutput)(
          THNNState* state,
          THTensor *input,
          THLongTensor *target,
          THTensor *output,
          THTensor *weights,
          THTensor *total_weight,
          bool sizeAverage)

Code style guide

void THNN_Linear_updateOutput(
          THTensor *input,
          THTensor *output,
          THTensor *weight,
          THTensor *bias);
//<- 10 ->

All arguments should start on a new line after function name, and they should be indented using 10 spaces.

Use 2 spaces for block indentation.

Conversion Steps

  1. copy old .c file to lib/THNN/generic
  • replace static int nn_ -> void THNN_
  • replace lua_State *L with ‘actual’ parameters (+ add THNNState* state)
  • remove any numeric values from return statements, remove the return at the end of the function body
  • remove old luaL_Reg & _init function
  1. add forward declarations to generic/THNN.h
  2. include the generic/xyz.c file in init.c
  3. add functions to ffi.lua
  4. copy & adapt lua file: specify module THNN for torch.class(), use THNN.errcheck
  5. include module lua file in init.lua
  6. add & run unit test to lua/tests/test.lua