#pragma once
#include <ATen/EmptyTensor.h>
#include <ATen/Formatting.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <algorithm>

#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;         \
  void operator=(const TypeName&) = delete
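
// Illustrative usage (not part of the original header; the class name below
// is hypothetical): expanding the macro inside a class body deletes the copy
// constructor and copy assignment operator, so instances cannot be copied:
//
//   class Widget {
//     AT_DISALLOW_COPY_AND_ASSIGN(Widget);
//    public:
//     Widget() = default;
//   };
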
namespace at {
TORCH_API int _crash_if_asan(int);
// Converts a TensorList (i.e. ArrayRef<Tensor>) to a vector of TensorImpl*.
// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
// Once cat is ported entirely to ATen this can be deleted!
inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
    ArrayRef<Tensor> tensors,
    const char* name,
    int pos,
    c10::DeviceType device_type,
    ScalarType scalar_type) {
  std::vector<TensorImpl*> unwrapped;
  unwrapped.reserve(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    const auto& expr = tensors[i];
    if (expr.layout() != Layout::Strided) {
      TORCH_CHECK(
          false,
          "Expected dense tensor but got ",
          expr.layout(),
          " for sequence element ",
          i,
          " in sequence argument at position #",
          pos,
          " '",
          name,
          "'");
    }
    if (expr.device().type() != device_type) {
      TORCH_CHECK(
          false,
          "Expected object of device type ",
          device_type,
          " but got device type ",
          expr.device().type(),
          " for sequence element ",
          i,
          " in sequence argument at position #",
          pos,
          " '",
          name,
          "'");
    }
    if (expr.scalar_type() != scalar_type) {
      TORCH_CHECK(
          false,
          "Expected object of scalar type ",
          scalar_type,
          " but got scalar type ",
          expr.scalar_type(),
          " for sequence element ",
          i,
          " in sequence argument at position #",
          pos,
          " '",
          name,
          "'");
    }
    unwrapped.emplace_back(expr.unsafeGetTensorImpl());
  }
  return unwrapped;
}
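
// Illustrative usage (not part of the original header; the tensors and
// argument name below are hypothetical): unwrapping a list of dense CPU
// float tensors before handing them to a legacy TH binding:
//
//   std::vector<at::Tensor> ts = {t0, t1};  // assumed dense CPU float tensors
//   auto impls = at::checked_dense_tensor_list_unwrap(
//       ts, "tensors", /*pos=*/1, c10::DeviceType::CPU, at::ScalarType::Float);
//
// impls[i] is ts[i].unsafeGetTensorImpl(); any element with a mismatched
// layout, device type, or scalar type raises an error naming the offending
// element and argument.
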
template <size_t N>
std::array<int64_t, N> check_intlist(
    ArrayRef<int64_t> list,
    const char* name,
    int pos) {
  if (list.empty()) {
    // TODO: is this necessary? We used to treat nullptr-vs-not in IntList
    // differently with strides as a way of faking optional.
    list = {};
  }
  auto res = std::array<int64_t, N>();
  if (list.size() == 1 && N > 1) {
    res.fill(list[0]);
    return res;
  }
  if (list.size() != N) {
    TORCH_CHECK(
        false,
        "Expected a list of ",
        N,
        " ints but got ",
        list.size(),
        " for argument #",
        pos,
        " '",
        name,
        "'");
  }
  std::copy_n(list.begin(), N, res.begin());
  return res;
}
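
// Illustrative usage (not part of the original header; the argument name is
// hypothetical): a single-element list is broadcast to all N slots, so both
// calls below yield {3, 3}; any other mismatched length raises an error
// naming the argument:
//
//   auto a = at::check_intlist<2>({3}, "kernel_size", /*pos=*/2);
//   auto b = at::check_intlist<2>({3, 3}, "kernel_size", /*pos=*/2);
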
using at::detail::check_size_nonnegative;
namespace detail {
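// Declarations for value-list tensor constructors defined elsewhere.
// Judging by their names, tensor_cpu handles CPU-backed TensorOptions,
// tensor_backend handles other device backends, and the *_complex_* variants
// handle complex element types; they are presumably the implementations
// behind the public at::tensor(...) factory. Illustrative direct call
// (an assumption, not the usual entry point):
//
//   at::Tensor t = at::detail::tensor_cpu<double>(
//       {1.0, 2.0, 3.0}, at::TensorOptions().dtype(at::kDouble));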
template <typename T>
TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);
template <typename T>
TORCH_API Tensor
tensor_backend(ArrayRef<T> values, const TensorOptions& options);
template <typename T>
TORCH_API Tensor
tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);
template <typename T>
TORCH_API Tensor
tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
} // namespace detail
} // namespace at