blob: 5e64db8c5e9db91106259bf37a779255f8cfec3b [file] [log] [blame]
--- Test harness class: holds registered test functions, runs them through
-- Tester:pcall (xpcall with a traceback handler) and accumulates failure
-- reports in self.errors for Tester:report.
local Tester = torch.class('torch.Tester')

--- Constructor: start with no registered tests and no recorded failures.
function Tester:__init()
   self.errors = {}       -- failure report strings, appended in order
   self.tests = {}        -- registered test functions, in registration order
   self.testnames = {}    -- names parallel to self.tests (same indices)
   self.curtestname = ''  -- name of the test currently executing
   -- Initialise the assertion counter here too, so asserts and :report()
   -- work even when called outside :run() (which resets it to 0).
   self.countasserts = 0
end
--- Core assertion primitive used by all assert* helpers.
-- Increments self.countasserts and, if `condition` is falsy, appends a
-- failure report (current test name, message, and caller frames taken from
-- a traceback) to self.errors.
-- @param condition value tested for truthiness
-- @param message description of the violation (coerced with tostring)
function Tester:assert_sub (condition, message)
   -- Tolerate use before :run() has reset the counter.
   self.countasserts = (self.countasserts or 0) + 1
   if not condition then
      local trace = debug.traceback('tester', 2)
      -- Drop the 'tester' header and 'stack traceback:' lines and keep the
      -- next two frames. On a shallow stack the pattern does not match and
      -- match returns nil, so fall back to the full trace instead of
      -- failing on the concatenation below.
      trace = trace:match('[^\n]+\n[^\n]+\n([^\n]+\n[^\n]+)\n') or trace
      self.errors[#self.errors+1] = self.curtestname .. '\n' .. tostring(message) .. '\n' .. trace .. '\n'
   end
end
--- Record a failure unless `condition` is truthy.
-- @param condition value tested for truthiness
-- @param message optional description; tostring'd so nil is accepted
--   (PUC Lua 5.1's string.format rejects nil for %s)
function Tester:assert (condition, message)
   self:assert_sub(condition, string.format('%s\n%s condition=%s',
      tostring(message), ' BOOL violation ', tostring(condition)))
end
--- Record a failure unless `val < condition`.
-- @param val left operand
-- @param condition right operand (the bound)
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertlt (val, condition, message)
   self:assert_sub(val < condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' LT(<) violation ', tostring(val), tostring(condition)))
end
--- Record a failure unless `val > condition`.
-- @param val left operand
-- @param condition right operand (the bound)
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertgt (val, condition, message)
   self:assert_sub(val > condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' GT(>) violation ', tostring(val), tostring(condition)))
end
--- Record a failure unless `val <= condition`.
-- @param val left operand
-- @param condition right operand (the bound)
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertle (val, condition, message)
   self:assert_sub(val <= condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' LE(<=) violation ', tostring(val), tostring(condition)))
end
--- Record a failure unless `val >= condition`.
-- @param val left operand
-- @param condition right operand (the bound)
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertge (val, condition, message)
   self:assert_sub(val >= condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' GE(>=) violation ', tostring(val), tostring(condition)))
end
--- Record a failure unless `val == condition` (plain Lua equality; tables
-- compare by reference, see assertTableEq for deep comparison).
-- @param val left operand
-- @param condition expected value
-- @param message optional description; tostring'd so nil is accepted
function Tester:asserteq (val, condition, message)
   self:assert_sub(val == condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' EQ(==) violation ', tostring(val), tostring(condition)))
end
--- Record a failure unless `val ~= condition` (plain Lua inequality).
-- @param val left operand
-- @param condition value that `val` must differ from
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertne (val, condition, message)
   self:assert_sub(val ~= condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' NE(~=) violation ', tostring(val), tostring(condition)))
end
--- Record a failure unless the largest absolute elementwise difference
-- between `ta` and `tb` is strictly less than `condition`.
-- @param ta, tb objects supporting `-`, `:abs()` and `:max()` (torch
--   tensors; presumably of matching size — TODO confirm expectations)
-- @param condition tolerance; the check is `max|ta - tb| < condition`
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertTensorEq(ta, tb, condition, message)
   local err = (ta - tb):abs():max()
   self:assert_sub(err < condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' TensorEQ(==) violation ', tostring(err), tostring(condition)))
end
--- Record a failure unless the largest absolute elementwise difference
-- between `ta` and `tb` is at least `condition`.
-- @param ta, tb objects supporting `-`, `:abs()` and `:max()` (torch
--   tensors; presumably of matching size — TODO confirm expectations)
-- @param condition threshold; the check is `max|ta - tb| >= condition`
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertTensorNe(ta, tb, condition, message)
   local err = (ta - tb):abs():max()
   self:assert_sub(err >= condition, string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' TensorNE(~=) violation ', tostring(err), tostring(condition)))
end
--- Deep (structural) equality.
-- Non-table values are compared with `==`. Two tables are equal when they
-- have the same set of keys and the values under each key are themselves
-- deeply equal. Cyclic tables are not supported (recursion would not
-- terminate).
-- @return boolean
local function areTablesEqual(ta, tb)
   if type(ta) ~= 'table' or type(tb) ~= 'table' then
      return ta == tb
   end
   -- Every entry of tb must have a deeply-equal counterpart in ta ...
   for k, v in pairs(tb) do
      if not areTablesEqual(ta[k], v) then return false end
   end
   -- ... and ta must not have keys that tb lacks. Values under shared keys
   -- were compared above, so only presence is checked here; comparing the
   -- subtrees again would double the work at every nesting level.
   for k in pairs(ta) do
      if tb[k] == nil then return false end
   end
   return true
end
--- Record a failure unless `ta` and `tb` are deeply equal
-- (structural comparison via areTablesEqual).
-- @param ta, tb values to compare
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertTableEq(ta, tb, message)
   -- Report the two operands being compared (tostring of a table shows its
   -- identity, which at least distinguishes the arguments).
   self:assert_sub(areTablesEqual(ta, tb), string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' TableEQ(==) violation ', tostring(ta), tostring(tb)))
end
--- Record a failure if `ta` and `tb` are deeply equal
-- (structural comparison via areTablesEqual).
-- @param ta, tb values to compare
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertTableNe(ta, tb, message)
   -- Report the two operands being compared, labelled as an NE violation.
   self:assert_sub(not areTablesEqual(ta, tb), string.format('%s\n%s val=%s, condition=%s',
      tostring(message), ' TableNE(~=) violation ', tostring(ta), tostring(tb)))
end
--- Record a failure unless calling `f` (with no arguments) raises an error.
-- @param f function expected to fail
-- @param message optional description; tostring'd so nil is accepted
function Tester:assertError(f, message)
   -- Only success/failure matters; the error value itself is discarded.
   local status = pcall(f)
   self:assert_sub(status == false, string.format('%s\n%s condition=%s',
      tostring(message), ' ERROR violation ', 'should have errored'))
end
--- Call `f` under xpcall with debug.traceback as the handler. If it raises,
-- append a 'Function call failed' report for the current test to
-- self.errors.
-- @param f function called with no arguments
-- @return stat true if f returned normally
-- @return result f's first return value, or the handler's result on error
-- @return pass true only if f returned normally and recorded no new errors
function Tester:pcall(f)
   local nerr = #self.errors
   local stat, result = xpcall(f, debug.traceback)
   if not stat then
      -- debug.traceback hands back non-string error objects (e.g. tables)
      -- unchanged; coerce so the concatenation cannot raise here, outside
      -- the protected call.
      self.errors[#self.errors+1] = self.curtestname .. '\n Function call failed \n' .. tostring(result) .. '\n'
   end
   return stat, result, stat and (nerr == #self.errors)
end
--- Print a summary line (assert count, test count, error count) followed by
-- every recorded failure, each one followed by an 80-character rule.
-- @param tests optional list of tests that were run; defaults to all
--   registered tests (only its length is used)
function Tester:report(tests)
   tests = tests or self.tests
   local rule = string.rep('-', 80)
   local summary = 'Completed ' .. self.countasserts
      .. ' asserts in ' .. #tests
      .. ' tests with ' .. #self.errors .. ' errors'
   print(summary)
   print()
   print(rule)
   for _, failure in ipairs(self.errors) do
      print(failure)
      print(rule)
   end
end
--- Run the registered tests (or a named subset) while drawing a one-line
-- progress bar, then print the summary via self:report.
-- Progress bar characters: '_' pending or passed, '|' running, '*' failed.
-- @param run_tests optional test name (string) or list of names (table).
--   When given, only tests whose registered name matches are run, in
--   registration order; a test is scheduled once per matching entry.
function Tester:run(run_tests)
   -- Replace the single character at `pos` in `str` with `ch`.
   local function setChar(str, pos, ch)
      return str:sub(1, pos - 1) .. ch .. str:sub(pos + 1)
   end
   -- Return to column 0, write `text`, and flush so it shows immediately.
   local function redraw(text)
      io.write('\r' .. text)
      io.flush()
   end

   self.countasserts = 0

   if type(run_tests) == 'string' then
      run_tests = {run_tests}
   end

   local tests, testnames = self.tests, self.testnames
   if type(run_tests) == 'table' then
      tests, testnames = {}, {}
      for idx in ipairs(self.tests) do
         for _, wanted in ipairs(run_tests) do
            if self.testnames[idx] == wanted then
               tests[#tests + 1] = self.tests[idx]
               testnames[#testnames + 1] = self.testnames[idx]
            end
         end
      end
   end

   print('Running ' .. #tests .. ' tests')
   local bar = string.rep('_', #tests)
   local line = ''
   io.write(bar .. '\r')
   for idx, testfn in ipairs(tests) do
      self.curtestname = testnames[idx]
      redraw(string.rep(' ', #line))   -- blank out the previous status line
      line = setChar(bar, idx, '|') .. ' ==> ' .. self.curtestname
      redraw(line)
      local _, _, passed = self:pcall(testfn)
      bar = setChar(bar, idx, passed and '_' or '*')
      collectgarbage()
   end
   redraw(string.rep(' ', #line))
   line = bar .. ' ==> Done '
   redraw(line)
   print()
   print()
   self:report(tests)
end
--- Register a test function, or every entry of a table of them.
-- Table entries are added recursively using their keys as names (in
-- `pairs` order, which is unspecified).
-- @param f function, or table whose values are functions/tables
-- @param name name for a single function; defaults to 'unknown'
function Tester:add(f,name)
   local kind = type(f)
   if kind == "function" then
      local slot = #self.tests + 1
      self.tests[slot] = f
      self.testnames[slot] = name or 'unknown'
   elseif kind == "table" then
      for key, entry in pairs(f) do
         self:add(entry, key)
      end
   else
      error('Tester:add(f) expects a function or a table of functions')
   end
end