#pragma once

#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>

namespace at {
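
// Negative dims are wrapped by adding the tensor's rank, e.g.
// maybe_wrap_dim(-1, /*dim_post_expr=*/4) == 3. When the rank is 0 and
// scalar wrapping is enabled, the tensor is treated as 1-dimensional, so
// only dims in [-1, 0] are accepted.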
using c10::maybe_wrap_dim;

inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
  return maybe_wrap_dim(dim, tensor->dim());
}
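
// Wraps `dim` against the rank of the first tensor in `tensors`; an empty
// list returns `dim` unchanged.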
inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
  if (tensors.empty()) {
    // can't wrap a dim against an empty TensorList; return it unchanged and
    // let the underlying implementation raise an error if necessary.
    return dim;
  }
  return maybe_wrap_dim(dim, tensors[0].dim());
}
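
// Same as above, but the rank is taken from the first entry of
// `tensor_sizes`.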
inline int64_t maybe_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  if (tensor_sizes.empty()) {
    // can't wrap a dim against an empty list of sizes; return it unchanged
    // and let the underlying implementation raise an error if necessary.
    return dim;
  }
  return maybe_wrap_dim(dim, static_cast<int64_t>(tensor_sizes[0].size()));
}
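
// Wraps each of the `ndims` entries of `dims` in place for a tensor of rank
// `dim_post_expr`, so dims may be given as negative indices. If
// `wrap_scalars` is true, a 0-dim (scalar) tensor is treated as
// 1-dimensional, i.e. dims in [-1, 0] are accepted; otherwise specifying any
// dim for a tensor with no dimensions raises an IndexError.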
inline void maybe_wrap_dims_n(
    int64_t* dims,
    int64_t ndims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  if (dim_post_expr <= 0) {
    if (wrap_scalars) {
      dim_post_expr = 1;
    } else {
      TORCH_CHECK_INDEX(
          ndims == 0,
          "Dimension specified as ",
          dims[0],
          " but tensor has no dimensions");
      return;
    }
  }
  int64_t min = -dim_post_expr;
  int64_t max = dim_post_expr - 1;
  for (const auto i : c10::irange(ndims)) {
    auto& dim = dims[i];
    if (dim < min || dim > max) {
      TORCH_CHECK_INDEX(
          false,
          "Dimension out of range (expected to be in range of [",
          min,
          ", ",
          max,
          "], but got ",
          dim,
          ")");
    }
    if (dim < 0)
      dim += dim_post_expr;
  }
}
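
// Convenience overload of maybe_wrap_dims_n for contiguous containers such
// as std::vector<int64_t> or at::DimVector.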
template <typename Container>
inline void maybe_wrap_dims(
    Container& dims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  return maybe_wrap_dims_n(
      dims.data(),
      static_cast<int64_t>(dims.size()),
      dim_post_expr,
      wrap_scalars);
}
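
// Legacy cat behavior: tensors of shape [0] (historically the only
// representable empty tensors) are skipped when deciding which tensor's rank
// to wrap `dim` against. This is kept for backwards compatibility and applies
// only to that exact shape; other empty shapes are not skipped.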
inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  for (auto& sizes : tensor_sizes) {
    if (sizes.size() == 1 && sizes[0] == 0) {
      continue;
    }
    return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
  }
  return dim;
}
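
// Same as above, but for symbolic sizes (c10::SymInt); the test for size 0
// is performed under TORCH_GUARD_SIZE_OBLIVIOUS.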
inline int64_t legacy_cat_wrap_dim_symint(
    int64_t dim,
    const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
  for (auto& sizes : tensor_sizes) {
    if (sizes.size() == 1) {
      if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) {
        continue;
      }
    }
    return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
  }
  return dim;
}
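
// Same legacy-skip behavior, but computed directly from a materialized
// tensor list instead of precomputed size vectors.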
inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const MaterializedITensorListRef& tensors) {
  for (const Tensor& tensor : tensors) {
    if (tensor.dim() == 1) {
      if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) {
        continue;
      }
    }
    return maybe_wrap_dim(dim, tensor.dim());
  }
  return dim;
}
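
// Wraps every entry of `dims_to_wrap` in place for a tensor of rank
// `tensor_total_dims`; out-of-range dims raise an IndexError.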
inline void wrap_all_dims(
    std::vector<int64_t>& dims_to_wrap,
    int64_t tensor_total_dims) {
  for (const auto i : c10::irange(dims_to_wrap.size())) {
    dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
  }
}

} // namespace at