commit | a25d3b4d8c632589828e923df8f5328084e9b833 | [log] [tgz] |
---|---|---|
author | serega <sergey.melderis@gmail.com> | Wed Oct 31 11:03:49 2018 -0700 |
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | Wed Oct 31 11:05:40 2018 -0700 |
tree | 7a6917c131b94eaa9b7a8515868f4c8244e62f61 | |
parent | 488d393ea604184eb30e9eaa67b6c60291152e21 [diff] |
Use byte tensor for mnist labels. (#13363) Summary: The C++ mnist example https://github.com/goldsborough/examples/blob/cpp/cpp/mnist/mnist.cpp does not work because the labels are not correctly loaded. Currently it achieves 100 % accuracy. Specifying byte dtype fixes the issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/13363 Differential Revision: D12860258 Pulled By: goldsborough fbshipit-source-id: ad7b9256e4fc627240e25c79de9d47b31da18d38
diff --git a/torch/csrc/api/src/data/datasets/mnist.cpp b/torch/csrc/api/src/data/datasets/mnist.cpp index ba4a437..911e5c3 100644 --- a/torch/csrc/api/src/data/datasets/mnist.cpp +++ b/torch/csrc/api/src/data/datasets/mnist.cpp
@@ -90,7 +90,7 @@ expect_int32(targets, kTargetMagicNumber); expect_int32(targets, count); - auto tensor = torch::empty(count); + auto tensor = torch::empty(count, torch::kByte); targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count); return tensor.to(torch::kInt64); }