diff --git a/include/cudecomp.h b/include/cudecomp.h
index b7ca852..9d5e5dd 100644
--- a/include/cudecomp.h
+++ b/include/cudecomp.h
@@ -29,7 +29,7 @@
 
 #define CUDECOMP_MAJOR 0
 #define CUDECOMP_MINOR 6
-#define CUDECOMP_PATCH 1
+#define CUDECOMP_PATCH 2
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/include/internal/common.h b/include/internal/common.h
index f11e9ca..24fa781 100644
--- a/include/internal/common.h
+++ b/include/internal/common.h
@@ -76,7 +76,8 @@ struct cudecompHandle {
 
   cutensorHandle_t cutensor_handle; // cuTENSOR handle;
 #if CUTENSOR_MAJOR >= 2
-  cutensorPlanPreference_t cutensor_plan_pref; // cuTENSOR plan preference;
+  cutensorPlanPreference_t cutensor_plan_pref;  // cuTENSOR plan preference;
+  bool cutensor_needs_permute_chunking = false; // Flag to enable large tensor workaround
 #endif
 
   std::vector<std::string> hostnames; // list of hostnames by rank
diff --git a/include/internal/transpose.h b/include/internal/transpose.h
index 84cbe77..d4140c9 100644
--- a/include/internal/transpose.h
+++ b/include/internal/transpose.h
@@ -82,3 +82,47 @@ static void localPermute(const cudecompHandle_t handle, const std::array<int64_t, 3>& extent_in,
 
+  // Workaround for a cuTENSOR large-tensor bug (see cudecompInit): if the total element count
+  // exceeds INT32_MAX/2, split the permutation into chunks along one input dimension and
+  // recurse, so that each cuTENSOR call stays under the limit.
+  static const int64_t CUTENSOR_EXTENT_LIMIT = std::numeric_limits<int32_t>::max() / 2;
+  int64_t total_elems = extent_in[0] * extent_in[1] * extent_in[2];
+  if (handle->cutensor_needs_permute_chunking && total_elems > CUTENSOR_EXTENT_LIMIT) {
+
+    // Always pass explicit strides when splitting
+    std::array<int64_t, 3> actual_strides_in = strides_in;
+    if (!anyNonzeros(strides_in)) { actual_strides_in = {extent_in[1] * extent_in[2], extent_in[2], 1}; }
+    std::array<int64_t, 3> actual_strides_out = strides_out;
+    if (!anyNonzeros(strides_out)) { actual_strides_out = {extent_out[1] * extent_out[2], extent_out[2], 1}; }
+
+    // Try to split on input dims, starting with outermost dim.
+    // inv_order_out maps an input dim index to its position in the output ordering, so the
+    // per-chunk output offset can use the matching output stride.
+    std::array<int, 3> inv_order_out;
+    for (int i = 0; i < 3; ++i)
+      inv_order_out[order_out[i]] = i;
+
+    int split_dim_in = -1;
+    int64_t elems_per_slice = 0;
+    for (int j = 2; j >= 0; --j) {
+      elems_per_slice = total_elems / extent_in[j];
+      if (elems_per_slice <= CUTENSOR_EXTENT_LIMIT) {
+        split_dim_in = j;
+        break;
+      }
+    }
+
+    if (split_dim_in >= 0) {
+      // Run localPermute multiple times, once per chunk.
+      int64_t chunk = std::max((int64_t)1, CUTENSOR_EXTENT_LIMIT / elems_per_slice);
+      for (int64_t offset = 0; offset < extent_in[split_dim_in]; offset += chunk) {
+        int64_t this_chunk = std::min(chunk, extent_in[split_dim_in] - offset);
+        std::array<int64_t, 3> chunk_extent_in = extent_in;
+        chunk_extent_in[split_dim_in] = this_chunk;
+        localPermute(handle, chunk_extent_in, order_out, actual_strides_in, actual_strides_out,
+                     input + offset * actual_strides_in[split_dim_in],
+                     output + offset * actual_strides_out[inv_order_out[split_dim_in]], stream);
+      }
+      return;
+    }
+
+    // All pairwise products exceed the limit so splitting isn't possible (requires each dimension > sqrt(INT32_MAX/2)
+    // ~= 32768). This is not a realistic scenario, but throw an error here for completeness.
+    THROW_INTERNAL_ERROR("Input too large to work around CUTENSOR large-tensor bug");
+  }
+
   auto strides_in_ptr = anyNonzeros(strides_in) ? strides_in.data() : nullptr;
   auto strides_out_ptr = anyNonzeros(strides_out) ? strides_out.data() : nullptr;
diff --git a/src/cudecomp.cc b/src/cudecomp.cc
index ae07c18..6cbdffa 100644
--- a/src/cudecomp.cc
+++ b/src/cudecomp.cc
@@ -462,6 +462,11 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) {
   CHECK_CUTENSOR(cutensorCreate(&handle->cutensor_handle));
   CHECK_CUTENSOR(cutensorCreatePlanPreference(handle->cutensor_handle, &handle->cutensor_plan_pref,
                                               CUTENSOR_ALGO_DEFAULT, CUTENSOR_JIT_MODE_NONE));
+  // cuTENSOR versions 2.3.x - 2.5.x have a bug where cutensorCreatePlan performs an out-of-bounds
+  // host write when the total number of tensor elements exceeds INT32_MAX/2. Set a flag
+  // to enable workaround in localPermute to split large tensors.
+  size_t cutensor_ver = cutensorGetVersion();
+  handle->cutensor_needs_permute_chunking = (cutensor_ver >= 20300 && cutensor_ver < 20600);
 #else
   CHECK_CUTENSOR(cutensorInit(&handle->cutensor_handle));
 #endif