File size: 4,867 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#pragma once

#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>

namespace at {

// Bring c10::maybe_wrap_dim into namespace at; the overloads declared below
// share its name and forward to it.
//
// If dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
// range [-1, 0]. This is a special case for scalar tensors and manifests in
// e.g. torch.sum(scalar_tensor, 0). Otherwise, dim should be in the range
// [-dim_post_expr, dim_post_expr-1].
using c10::maybe_wrap_dim;

// Overload for a raw TensorImpl pointer: `dim` is wrapped against the rank
// reported by `tensor->dim()`. `tensor` is dereferenced unconditionally, so it
// must not be null.
inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
  const auto rank = tensor->dim();
  return maybe_wrap_dim(dim, rank);
}

// Overload for a list of tensors: the rank of the first tensor is the
// reference rank. An empty list cannot be wrapped, so `dim` is returned
// unchanged and the underlying implementation is relied on to throw an error
// if necessary.
inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
  return tensors.empty() ? dim : maybe_wrap_dim(dim, tensors[0].dim());
}

inline int64_t maybe_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  if (tensor_sizes.empty()) {
    // can't wrap empty list; rely on underlying implementation to throw error
    // if necessary
    return dim;
  }
  return maybe_wrap_dim(dim, static_cast<int64_t>(tensor_sizes[0].size()));
}

// Wraps, in place, each of the `ndims` entries of the array `dims` for a
// tensor of rank `dim_post_expr`, so that callers may pass negative indices.
//
// Rank-0 handling (`dim_post_expr <= 0`):
//   * `wrap_scalars == true`: the rank is treated as 1, so entries in the
//     range [-1, 0] are accepted.
//   * `wrap_scalars == false`: an IndexError is raised unless `ndims == 0`;
//     with no entries to wrap the function returns without touching `dims`.
//
// In every other case an IndexError is raised for any entry outside
// [-dim_post_expr, dim_post_expr). Entries before the offending one have
// already been wrapped when the error is raised.
inline void maybe_wrap_dims_n(
    int64_t* dims,
    int64_t ndims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  if (dim_post_expr <= 0) {
    if (!wrap_scalars) {
      // The message arguments are only used when the check fails, i.e. when
      // there is at least one entry to report.
      TORCH_CHECK_INDEX(
          ndims == 0,
          "Dimension specified as ",
          dims[0],
          " but tensor has no dimensions");
      return;
    }
    dim_post_expr = 1; // this will make range [-1, 0]
  }

  const int64_t lowest = -dim_post_expr;
  const int64_t highest = dim_post_expr - 1;
  for (const auto idx : c10::irange(ndims)) {
    int64_t& entry = dims[idx];
    TORCH_CHECK_INDEX(
        lowest <= entry && entry <= highest,
        "Dimension out of range (expected to be in range of [",
        lowest,
        ", ",
        highest,
        "], but got ",
        entry,
        ")");
    if (entry < 0) {
      entry += dim_post_expr;
    }
  }
}

// Container front-end for maybe_wrap_dims_n: wraps every element of `dims`
// in place for a tensor of rank `dim_post_expr`, so that negative indices may
// be used. `Container` must expose `data()` and `size()` over contiguous
// int64_t storage.
//
// If `wrap_scalars` is true, rank-0 tensors accept dimensions in the range
// [-1, 0]. Otherwise, an IndexError is raised for dimensions not in the range
// [-dim_post_expr, dim_post_expr).
template <typename Container>
inline void maybe_wrap_dims(
    Container& dims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  const auto count = static_cast<int64_t>(dims.size());
  maybe_wrap_dims_n(dims.data(), count, dim_post_expr, wrap_scalars);
}

// Historically, size [0] tensors were the only possible empty tensors, so
// empty tensors could not be concatenated unless every other tensor was
// 1-dimensional. Such tensors were therefore "skipped", both for wrap
// dimension behavior and for dimension size checking. That behavior is kept
// for backwards compatibility, but only for this exact size: other empty
// sizes are not skipped.
//
// Wraps `dim` against the rank of the first shape that is not [0]. If every
// shape is skipped, or the list is empty, `dim` is returned unchanged.
inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  for (const auto& shape : tensor_sizes) {
    const bool is_legacy_empty = shape.size() == 1 && shape[0] == 0;
    if (!is_legacy_empty) {
      return maybe_wrap_dim(dim, static_cast<int64_t>(shape.size()));
    }
  }
  return dim;
}

// SymInt counterpart of legacy_cat_wrap_dim: skips 1-d shapes whose single
// extent compares equal to 0 under TORCH_GUARD_SIZE_OBLIVIOUS, and wraps
// `dim` against the rank of the first shape that is not skipped. If every
// shape is skipped, or the list is empty, `dim` is returned unchanged.
inline int64_t legacy_cat_wrap_dim_symint(
    int64_t dim,
    const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
  for (const auto& shape : tensor_sizes) {
    // Short-circuit keeps the symbolic comparison restricted to 1-d shapes.
    if (shape.size() == 1 &&
        TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) {
      continue;
    }
    return maybe_wrap_dim(dim, static_cast<int64_t>(shape.size()));
  }
  return dim;
}

// Tensor-list counterpart of legacy_cat_wrap_dim: skips 1-d tensors whose
// single extent compares equal to 0 under TORCH_GUARD_SIZE_OBLIVIOUS, and
// wraps `dim` against the rank of the first tensor that is not skipped. If
// every tensor is skipped, or the list is empty, `dim` is returned unchanged.
inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const MaterializedITensorListRef& tensors) {
  for (const Tensor& candidate : tensors) {
    // Short-circuit keeps the symbolic comparison restricted to 1-d tensors.
    const bool is_legacy_empty = candidate.dim() == 1 &&
        TORCH_GUARD_SIZE_OBLIVIOUS(candidate.sym_sizes()[0].sym_eq(0));
    if (is_legacy_empty) {
      continue;
    }
    return maybe_wrap_dim(dim, candidate.dim());
  }
  return dim;
}

// Wraps every entry of `dims_to_wrap` in place against a tensor of rank
// `tensor_total_dims`, by passing each entry through maybe_wrap_dim.
inline void wrap_all_dims(
    std::vector<int64_t>& dims_to_wrap,
    int64_t tensor_total_dims) {
  for (int64_t& entry : dims_to_wrap) {
    entry = maybe_wrap_dim(entry, tensor_total_dims);
  }
}

} // namespace at