// File size: 3,564 Bytes
// 9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#pragma once

#include <ATen/EmptyTensor.h>
#include <ATen/Formatting.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <algorithm>

// Makes TypeName non-copyable by deleting its copy constructor and its
// copy-assignment operator. Because the expansion declares a constructor, it
// must be used inside the class body of TypeName. The expansion deliberately
// omits the final ';' so the call site supplies it, e.g.
//   AT_DISALLOW_COPY_AND_ASSIGN(MyType);
// NB: the deleted operator= is declared as returning void rather than
// TypeName&; since it is deleted, this only affects the declared signature.
#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;         \
  void operator=(const TypeName&) = delete

namespace at {

// Exported declaration only; the definition lives outside this header.
// NOTE(review): going by the name, this is presumably a test hook that
// deliberately triggers an AddressSanitizer report. The meaning of the int
// argument and return value is not visible here — confirm against the
// definition.
TORCH_API int _crash_if_asan(int);

// Converts a TensorList (i.e. ArrayRef<Tensor>) to a vector of TensorImpl*,
// validating every element along the way.
// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
// Once cat is ported entirely to ATen this can be deleted!
//
// @param tensors      the sequence argument to unwrap.
// @param name         argument name; only used in error messages.
// @param pos          argument position; only used in error messages.
// @param device_type  device type every element is required to have.
// @param scalar_type  scalar type every element is required to have.
// @return one raw TensorImpl* per element, in the same order as `tensors`.
//         The vector only copies the pointers, so the caller must keep the
//         tensors alive for as long as the result is used.
// Fails via TORCH_CHECK if an element is not strided (dense), or if its device
// type or scalar type differs from the expected one.
inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
    ArrayRef<Tensor> tensors,
    const char* name,
    int pos,
    c10::DeviceType device_type,
    ScalarType scalar_type) {
  std::vector<TensorImpl*> unwrapped;
  unwrapped.reserve(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    const auto& expr = tensors[i];
    // The condition is passed to TORCH_CHECK directly rather than as
    // `if (!ok) TORCH_CHECK(false, ...)`. The message arguments are still only
    // formatted on failure, and TORCH_CHECK's diagnostics see the real
    // condition instead of a literal `false`.
    TORCH_CHECK(
        expr.layout() == Layout::Strided,
        "Expected dense tensor but got ",
        expr.layout(),
        " for sequence element ",
        i,
        " in sequence argument at position #",
        pos,
        " '",
        name,
        "'");
    TORCH_CHECK(
        expr.device().type() == device_type,
        "Expected object of device type ",
        device_type,
        " but got device type ",
        expr.device().type(),
        " for sequence element ",
        i,
        " in sequence argument at position #",
        pos,
        " '",
        name,
        "'");
    TORCH_CHECK(
        expr.scalar_type() == scalar_type,
        "Expected object of scalar type ",
        scalar_type,
        " but got scalar type ",
        expr.scalar_type(),
        " for sequence element ",
        i,
        " in sequence argument at position #",
        pos,
        " '",
        name,
        "'");
    unwrapped.emplace_back(expr.unsafeGetTensorImpl());
  }
  return unwrapped;
}

// Validates an int-list argument and expands it to exactly N values.
//
// If N > 1 and the list has a single element, that element is broadcast to
// all N entries (e.g. {3} becomes {3, 3} for N == 2). Otherwise the list must
// contain exactly N elements, which are copied in order.
//
// @tparam N    number of values expected.
// @param list  the supplied values.
// @param name  argument name; only used in the error message.
// @param pos   argument position; only used in the error message.
// @return a std::array holding the N (possibly broadcast) values.
// Fails via TORCH_CHECK when the list length is neither N nor, for N > 1, 1.
template <size_t N>
std::array<int64_t, N> check_intlist(
    ArrayRef<int64_t> list,
    const char* name,
    int pos) {
  // An empty list needs no special casing: it can only be accepted when
  // N == 0, and is rejected by the length check below otherwise.
  auto res = std::array<int64_t, N>();
  if (list.size() == 1 && N > 1) {
    res.fill(list[0]);
    return res;
  }
  TORCH_CHECK(
      list.size() == N,
      "Expected a list of ",
      N,
      " ints but got ",
      list.size(),
      " for argument #",
      pos,
      " '",
      name,
      "'");
  std::copy_n(list.begin(), N, res.begin());
  return res;
}

// Re-exports at::detail::check_size_nonnegative into namespace `at`, so that
// code in `at` can call it unqualified. NOTE(review): it is presumably
// declared in ATen/EmptyTensor.h, which is included above — confirm.
using at::detail::check_size_nonnegative;

namespace detail {

// Exported template declarations only. The definitions, and the set of T they
// are instantiated for, live outside this header.
// NOTE(review): going by the names, these presumably build a Tensor from
// `values` using `options`:
//   - the *_cpu variants target CPU;
//   - the *_backend variants target other backends;
//   - the *complex* variants handle complex element types.
// Confirm against the definitions.
template <typename T>
TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_backend(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
} // namespace detail

} // namespace at