#pragma once

#include <atomic>
#include <iostream>

#include "ATen/Type.h"

namespace at {

struct Type;
class Scalar;
struct TensorImpl {
  explicit TensorImpl(Type * type)
  : refcount(1), is_scalar(false), type_(type) {}

  Type & type() const {
    return *type_;
  }
  virtual const char * toString() const = 0;
  virtual IntList sizes() = 0;
  virtual IntList strides() = 0;
  virtual int64_t dim() = 0;
  virtual Scalar localScalar() = 0;
  virtual void assign_(Scalar s) = 0;
  virtual void * unsafeGetTH() = 0;
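
  // Reference counting is intrusive: a Tensor handle shares this impl by
  // calling retain() when it copies the pointer and release() when it is
  // done with it; the impl deletes itself once the count reaches zero.
  // A minimal sketch of the expected handle behavior (hypothetical,
  // simplified; not the actual Tensor implementation):
  //
  //   Tensor::Tensor(const Tensor & rhs) : pImpl(rhs.pImpl) {
  //     pImpl->retain();   // share ownership
  //   }
  //   Tensor::~Tensor() {
  //     pImpl->release();  // may delete the impl
  //   }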
  void retain() {
    ++refcount;
  }
  virtual void release() {
    if(--refcount == 0) {
      delete this;
    }
  }
  virtual ~TensorImpl() {}

  friend struct Type;

  // The 0-dim patchup of TH requires us to have a flag marking whether a
  // Tensor should be treated as 0-dim. The generated wrapper code
  // manipulates this flag. The setter should never be exposed in Tensor's
  // public API, because eventually we would like isScalar() to just be
  // dim() == 0.
  bool isScalar() const {
    return is_scalar;
  }
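
  // Example: TH has no true 0-dim tensors, so a value that should behave
  // as 0-dim is stored as a 1-d tensor with sizes() == {1}; is_scalar
  // records that it should nonetheless present itself as 0-dim (this is
  // the shape that maybeScalar() below checks for).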
  // This is called by the generated wrapper code when there are conditions
  // under which this output tensor should be a scalar, e.g. when all inputs
  // to a function 'add' were scalars; then condition_when_scalar == true.
  // We also prevent the tensor from being marked as a scalar if it does not
  // have the right shape after all.
  TensorImpl* maybeScalar(bool condition_when_scalar) {
    is_scalar = false; // force dim() to tell the truth for TH
    is_scalar = condition_when_scalar && dim() == 1 && sizes()[0] == 1;
    return this;
  }
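
  // A sketch of how the generated wrapper code is expected to use
  // maybeScalar (hypothetical wrapper body, for illustration only):
  //
  //   Tensor add(const Tensor & self, const Tensor & other) {
  //     TensorImpl * result = /* run the underlying TH kernel */;
  //     bool all_scalar = self.isScalar() && other.isScalar();
  //     return Tensor(result->maybeScalar(all_scalar));
  //   }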
  void setScalar(bool s) {
    is_scalar = s;
  }

private:
  std::atomic<int> refcount;
  bool is_scalar;
  Type * type_;
};

} // namespace at