Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/cudecomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

#define CUDECOMP_MAJOR 0
#define CUDECOMP_MINOR 6
#define CUDECOMP_PATCH 1
#define CUDECOMP_PATCH 2

#ifdef __cplusplus
extern "C" {
Expand Down
3 changes: 2 additions & 1 deletion include/internal/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ struct cudecompHandle {

cutensorHandle_t cutensor_handle; // cuTENSOR handle;
#if CUTENSOR_MAJOR >= 2
cutensorPlanPreference_t cutensor_plan_pref; // cuTENSOR plan preference;
cutensorPlanPreference_t cutensor_plan_pref; // cuTENSOR plan preference;
bool cutensor_needs_permute_chunking = false; // Flag to enable large tensor workaround
#endif

std::vector<std::array<char, MPI_MAX_PROCESSOR_NAME>> hostnames; // list of hostnames by rank
Expand Down
44 changes: 44 additions & 0 deletions include/internal/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,50 @@ static void localPermute(const cudecompHandle_t handle, const std::array<int64_t
if (extent_out[i] == 0) return;
}

// Workaround for an out-of-bounds host write bug in cuTENSOR triggered when the
// total number of tensor elements exceeds INT32_MAX/2. We split the tensor so each
// cuTENSOR call stays below that limit.
static constexpr int64_t CUTENSOR_EXTENT_LIMIT = (int64_t)std::numeric_limits<int32_t>::max() / 2;
int64_t total_elems = extent_in[0] * extent_in[1] * extent_in[2];
if (handle->cutensor_needs_permute_chunking && total_elems > CUTENSOR_EXTENT_LIMIT) {

// Always pass explicit strides when splitting
std::array<int64_t, 3> actual_strides_in = strides_in;
if (!anyNonzeros(strides_in)) { actual_strides_in = {extent_in[1] * extent_in[2], extent_in[2], 1}; }
std::array<int64_t, 3> actual_strides_out = strides_out;
if (!anyNonzeros(strides_out)) { actual_strides_out = {extent_out[1] * extent_out[2], extent_out[2], 1}; }
// Try to split on input dims, starting with outermost dim.
std::array<int, 3> inv_order_out;
for (int i = 0; i < 3; ++i)
inv_order_out[order_out[i]] = i;
int split_dim_in = -1;
int64_t elems_per_slice = 0;
for (int j = 2; j >= 0; --j) {
elems_per_slice = total_elems / extent_in[j];
if (elems_per_slice <= CUTENSOR_EXTENT_LIMIT) {
split_dim_in = j;
break;
}
}

if (split_dim_in >= 0) {
// Run localPermute multiple times, once per chunk.
int64_t chunk = std::max((int64_t)1, CUTENSOR_EXTENT_LIMIT / elems_per_slice);
for (int64_t offset = 0; offset < extent_in[split_dim_in]; offset += chunk) {
int64_t this_chunk = std::min(chunk, extent_in[split_dim_in] - offset);
std::array<int64_t, 3> chunk_extent_in = extent_in;
chunk_extent_in[split_dim_in] = this_chunk;
localPermute(handle, chunk_extent_in, order_out, actual_strides_in, actual_strides_out,
input + offset * actual_strides_in[split_dim_in],
output + offset * actual_strides_out[inv_order_out[split_dim_in]], stream);
}
return;
}
// All pairwise products exceed the limit so splitting isn't possible (requires each dimension > sqrt(INT32_MAX/2)
// ~= 32768). This is not a realistic scenario, but throw an error here for completeness.
THROW_INTERNAL_ERROR("Input too large to work around CUTENSOR large-tensor bug");
}

auto strides_in_ptr = anyNonzeros(strides_in) ? strides_in.data() : nullptr;
auto strides_out_ptr = anyNonzeros(strides_out) ? strides_out.data() : nullptr;

Expand Down
5 changes: 5 additions & 0 deletions src/cudecomp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) {
CHECK_CUTENSOR(cutensorCreate(&handle->cutensor_handle));
CHECK_CUTENSOR(cutensorCreatePlanPreference(handle->cutensor_handle, &handle->cutensor_plan_pref,
CUTENSOR_ALGO_DEFAULT, CUTENSOR_JIT_MODE_NONE));
// cuTENSOR versions 2.3.x - 2.5.x have a bug where cutensorCreatePlan performs an out-of-bounds
// host write when the total number of tensor elements exceeds INT32_MAX/2. Set a flag
// to enable workaround in localPermute to split large tensors.
size_t cutensor_ver = cutensorGetVersion();
handle->cutensor_needs_permute_chunking = (cutensor_ver >= 20300 && cutensor_ver < 20600);
#else
CHECK_CUTENSOR(cutensorInit(&handle->cutensor_handle));
#endif
Expand Down
Loading