// blob: 9d1c2bd87fc1080403ec34a30bc056fa5feb49e5 [file] [log] [blame]
#pragma once
#include <Python.h>

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <THPP/THPP.h>
#include <ATen/ATen.h>
namespace torch {
/// Helper that pulls typed values out of a Python argument object one at a
/// time.
///
/// NOTE(review): only the declarations are visible in this header; the
/// behavioural notes below are inferred from the signatures and should be
/// confirmed against the implementation file.
struct TupleParser {
  /// @param args      Python object holding the arguments to be parsed
  ///                  (presumably a tuple — TODO confirm).
  /// @param num_args  Defaults to -1. Presumably the expected number of
  ///                  arguments, with -1 meaning "do not check" — TODO confirm.
  TupleParser(PyObject* args, int num_args = -1);

  // ---- parse overloads ---------------------------------------------------
  // Every overload takes the destination `x` by non-const reference and a
  // `param_name` string; `param_name` is presumably used to label errors
  // (see invalid_type) — verify in the implementation.

  // Scalars.
  void parse(bool& x, const std::string& param_name);
  void parse(int& x, const std::string& param_name);
  void parse(double& x, const std::string& param_name);

  // Strings and integer lists.
  void parse(std::string& x, const std::string& param_name);
  void parse(std::vector<int>& x, const std::string& param_name);

  // Tensors (THPP with unique/shared ownership, and ATen).
  void parse(std::unique_ptr<thpp::Tensor>& x, const std::string& param_name);
  void parse(std::shared_ptr<thpp::Tensor>& x, const std::string& param_name);
  void parse(at::Tensor& x, const std::string& param_name);

 protected:
  /// Builds (and returns by value, does not throw by itself as far as the
  /// signature shows) a std::runtime_error given the expected type
  /// description and the parameter name.
  std::runtime_error invalid_type(const std::string& expected,
                                  const std::string& param_name);

  /// Returns a PyObject* for the next argument; presumably advances `idx`
  /// — TODO confirm, including whether the reference is borrowed or owned.
  PyObject* next_arg();

 private:
  // The Python object passed to the constructor (ownership not visible here).
  PyObject* args;
  // Integer cursor; presumably the position of the next argument to read.
  int idx;
};
} // namespace torch