|
#include <c10/util/Exception.h> |
|
#include <utility> |
|
|
|
namespace at { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
inline std::pair<int64_t, int64_t> collapse_dims( |
|
T* sizes, |
|
T* strides, |
|
int64_t dims, |
|
const int excludeDim = -1) { |
|
TORCH_CHECK( |
|
excludeDim >= -1 && excludeDim < dims, |
|
"expected excluded dim between -1 and dims - 1"); |
|
|
|
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim; |
|
int64_t newIndex = -1; |
|
int64_t oldIndex = 0; |
|
int64_t remappedExcludedDim = -1; |
|
|
|
while (oldIndex < dims) { |
|
|
|
for (; oldIndex < stopDim; ++oldIndex) { |
|
if (sizes[oldIndex] == 1) { |
|
continue; |
|
} |
|
|
|
++newIndex; |
|
sizes[newIndex] = sizes[oldIndex]; |
|
strides[newIndex] = strides[oldIndex]; |
|
++oldIndex; |
|
break; |
|
} |
|
|
|
|
|
for (; oldIndex < stopDim; ++oldIndex) { |
|
if (sizes[oldIndex] == 1) { |
|
continue; |
|
} |
|
|
|
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) { |
|
sizes[newIndex] *= sizes[oldIndex]; |
|
strides[newIndex] = strides[oldIndex]; |
|
} else { |
|
++newIndex; |
|
sizes[newIndex] = sizes[oldIndex]; |
|
strides[newIndex] = strides[oldIndex]; |
|
} |
|
} |
|
|
|
|
|
if (oldIndex != dims) { |
|
|
|
++newIndex; |
|
sizes[newIndex] = sizes[oldIndex]; |
|
strides[newIndex] = strides[oldIndex]; |
|
remappedExcludedDim = newIndex; |
|
|
|
|
|
++oldIndex; |
|
stopDim = dims; |
|
} |
|
} |
|
|
|
|
|
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) { |
|
dims = 1; |
|
sizes[0] = 1; |
|
strides[0] = 1; |
|
|
|
return std::pair<int64_t, int64_t>(0, 1); |
|
} |
|
|
|
dims = newIndex + 1; |
|
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims); |
|
} |
|
|
|
} |
|
|