blob: 68b282277eefb4d294f975474856303cc0bf8101 [file] [log] [blame]
#include "gtest/gtest.h"
#include "ATen/ATen.h"
#include "ATen/cudnn/Descriptors.h"
#include "ATen/cudnn/Handle.h"
#include "test_seed.h"
using namespace at;
using namespace at::native;
TEST(CUDNNTest, CUDNNTestCUDA) {
manual_seed(123, at::kCUDA);
#if CUDNN_VERSION < 7000
auto handle = getCudnnHandle();
DropoutDescriptor desc1, desc2;
desc1.initialize_rng(handle, 0.5, 42, TensorOptions().device(DeviceType::CUDA).dtype(kByte));
desc2.set(handle, 0.5, desc1.state);
bool isEQ;
isEQ = (desc1.desc()->dropout == desc2.desc()->dropout);
ASSERT_TRUE(isEQ);
isEQ = (desc1.desc()->nstates == desc2.desc()->nstates);
ASSERT_TRUE(isEQ);
isEQ = (desc1.desc()->states == desc2.desc()->states);
ASSERT_TRUE(isEQ);
#endif
}