blob: a1517586fc9718248144c2c027dadd61d7099eb8 [file] [log] [blame]
#include <catch.hpp>
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/parallel/data_parallel.h>
#include <torch/nn/pimpl.h>
#include <torch/tensor.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
using Catch::StartsWith;
using namespace torch::autograd;
using namespace torch::nn;
#ifdef USE_CUDA
TEST_CASE("Parallel/DifferentiableScatter", "[multi-cuda]") {
Scatter scatter(
{torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
auto input = torch::ones(10, torch::requires_grad(true));
auto output = scatter.apply({input});
REQUIRE(output.size() == 2);
REQUIRE(output[0].size(0) == 5);
REQUIRE(output[1].size(0) == 5);
REQUIRE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
.allclose(input));
auto sum = output[0].to({torch::kCUDA, 1}) + output[1];
sum.backward();
REQUIRE(input.grad().defined());
REQUIRE(input.grad().device().is_cpu());
REQUIRE(input.grad().sum().toCInt() == 10);
}
TEST_CASE("Parallel/DifferentiableGather", "[multi-cuda]") {
Gather gather(torch::Device(torch::kCUDA, 1));
auto a = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 0}));
auto b = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 1}));
auto outputs = gather.apply({a, b});
REQUIRE(outputs.size() == 1);
auto& output = outputs.front();
REQUIRE(output.size(0) == 10);
REQUIRE(output.device() == torch::Device(torch::kCUDA, 1));
auto chunks = output.chunk(2);
REQUIRE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
REQUIRE(chunks[1].allclose(b));
output.backward();
REQUIRE(a.grad().defined());
REQUIRE(a.grad().device() == torch::Device(torch::kCUDA, 0));
REQUIRE(a.grad().sum().toCInt() == 5);
REQUIRE(b.grad().defined());
REQUIRE(b.grad().device() == torch::Device(torch::kCUDA, 1));
REQUIRE(b.grad().sum().toCInt() == 5);
}
TEST_CASE("Parallel/Replicate", "[multi-cuda]") {
Linear linear(3, 4);
auto replicas = parallel::replicate(
linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
REQUIRE(replicas.size() == 2);
auto original_parameters = linear->parameters();
auto replica1_parameters = replicas[0]->parameters();
for (auto& parameter : replica1_parameters) {
REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 0));
}
replicas[0]->to(torch::kCPU);
REQUIRE(replica1_parameters.size() == original_parameters.size());
for (size_t i = 0; i < original_parameters.size(); ++i) {
REQUIRE(replica1_parameters[i]->allclose(*original_parameters[i]));
REQUIRE(
replica1_parameters[i].data<float>() !=
original_parameters[i].data<float>());
}
auto replica2_parameters = replicas[1]->parameters();
for (auto& parameter : replica2_parameters) {
REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 1));
}
replicas[1]->to(torch::kCPU);
REQUIRE(replica2_parameters.size() == original_parameters.size());
for (size_t i = 0; i < original_parameters.size(); ++i) {
REQUIRE(replica2_parameters[i]->allclose(*original_parameters[i]));
REQUIRE(
replica2_parameters[i].data<float>() !=
original_parameters[i].data<float>());
}
}
TEST_CASE("Parallel/ParallelApply", "[multi-cuda]") {
Linear a(3, 4);
Linear b(std::static_pointer_cast<LinearImpl>(a->clone()));
b->to({torch::kCUDA, 0});
Linear c(std::static_pointer_cast<LinearImpl>(a->clone()));
c->to({torch::kCUDA, 1});
std::vector<Linear> modules = {a, b, c};
std::vector<torch::Tensor> inputs = {
torch::ones({2, 3}),
torch::ones({2, 3}, torch::device({torch::kCUDA, 0})),
torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))};
auto outputs = parallel::parallel_apply(modules, inputs);
REQUIRE(outputs.size() == 3);
REQUIRE(outputs[0].device().is_cpu());
REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
REQUIRE(outputs[1].to(torch::kCPU).allclose(outputs[0]));
REQUIRE(outputs[2].device() == torch::Device(torch::kCUDA, 1));
REQUIRE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
}
TEST_CASE("Parallel/ParallelApplyWithDifferentOutputDevice", "[multi-cuda]") {
struct M : torch::nn::Module {
torch::Tensor forward(torch::Tensor input) {
return torch::ones({5}, torch::dtype(torch::kInt32));
}
};
std::vector<std::shared_ptr<M>> modules = {
std::make_shared<M>(), std::make_shared<M>(), std::make_shared<M>()};
std::vector<torch::Tensor> inputs = {
torch::empty({}), torch::empty({}), torch::empty({})};
std::vector<torch::Device> devices = {
{torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}};
auto outputs = parallel::parallel_apply(modules, inputs, devices);
REQUIRE(outputs.size() == 3);
REQUIRE(outputs[0].device().is_cuda());
REQUIRE(outputs[0].device() == torch::Device(torch::kCUDA, 1));
REQUIRE(outputs[1].device().is_cuda());
REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
REQUIRE(outputs[2].device().is_cpu());
}
TEST_CASE("Parallel/ParallelApplyRethrowsException", "[multi-cuda]") {
struct M : torch::nn::Cloneable<M> {
void reset() override {}
torch::Tensor forward(torch::Tensor input) {
throw std::runtime_error("Badness!");
}
};
auto m = std::make_shared<M>();
auto input = torch::ones({10, 3});
REQUIRE_THROWS_WITH(
parallel::data_parallel(m, input), StartsWith("Badness!"));
}
TEST_CASE(
"Parallel/DataParallelPlacesTheOutputOnTheRequestedDevice",
"[multi-cuda]") {
struct M : torch::nn::Cloneable<M> {
void reset() override {}
torch::Tensor forward(torch::Tensor input) {
// Intermediate tensors should be on the replica's current device.
intermediate_tensor = torch::rand(5);
// The returned tensor should be on the output device.
return torch::ones(3);
}
torch::Tensor intermediate_tensor;
};
auto m = std::make_shared<M>();
auto input = torch::ones({10, 3});
{
auto output = parallel::data_parallel(
m,
input,
/*devices=*/at::nullopt,
/*output_device=*/torch::Device(torch::kCUDA, 1));
REQUIRE(output.defined());
REQUIRE(output.device().is_cuda());
REQUIRE(output.device().index() == 1);
}
{
// Verify for the single-device case (where we don't scatter/gather).
auto output = parallel::data_parallel(
m,
input,
/*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
/*output_device=*/torch::Device(torch::kCUDA, 1));
REQUIRE(m->intermediate_tensor.defined());
REQUIRE(m->intermediate_tensor.device().is_cuda());
REQUIRE(m->intermediate_tensor.device().index() == 0);
REQUIRE(output.defined());
REQUIRE(output.device().is_cuda());
REQUIRE(output.device().index() == 1);
}
}
TEST_CASE("Parallel/DataParallelUsesAllAvailableCUDADevices", "[cuda]") {
struct M : torch::nn::Cloneable<M> {
void reset() override {}
torch::Tensor forward(torch::Tensor input) {
return torch::tensor(torch::DefaultTensorOptions::get().device().index());
}
};
auto m = std::make_shared<M>();
auto input = torch::ones({10, 3});
auto output = parallel::data_parallel(m, input);
const auto device_count = torch::cuda::device_count();
REQUIRE(output.numel() == device_count);
for (size_t i = 0; i < device_count; ++i) {
REQUIRE(output[i].toCInt() == i);
}
}
#endif