Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/internal/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ struct cudecompCommInfo {
nvshmem_team_t nvshmem_team = NVSHMEM_TEAM_INVALID;
uint64_t* nvshmem_signals = nullptr;
#endif

bool mnnvl_active = false; // flag to indicate whether communicator has MNNVL connections
};

// Structure to contain data for transpose performance sample
Expand Down Expand Up @@ -349,6 +351,23 @@ static void setCommInfo(cudecompHandle_t& handle, cudecompGridDesc_t& grid_desc,
if (count != e.second) { count = gcd(count, e.second); }
}
}

// Check if any cliques contain multiple nodes (i.e. there are MNNVL connections in this communicator)
std::map<unsigned int, std::string> clique_to_hostname;
for (int i = 0; i < info.nranks; ++i) {
int peer_rank_global = getGlobalRank(handle, grid_desc, comm_axis, i);
unsigned int clique = handle->rank_to_clique[peer_rank_global];
std::string hostname = std::string(handle->hostnames[peer_rank_global].data());
if (clique_to_hostname.count(clique)) {
if (clique_to_hostname[clique] != hostname) {
// Multiple hostnames in clique detected, MNNVL connections are present
info.mnnvl_active = true;
break;
}
} else {
clique_to_hostname[clique] = hostname;
}
}
}

info.npergroup = count;
Expand Down
26 changes: 15 additions & 11 deletions include/internal/halo.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG
CHECK_CUDA(cudaEventRecord(current_sample->halo_start_event, stream));
}

int count = 0;
for (int i = 0; i < 3; ++i) {
if (i == ax) continue;
if (i == dim) break;
count++;
}

auto comm_axis = (count == 0) ? CUDECOMP_COMM_COL : CUDECOMP_COMM_ROW;
int comm_rank = (comm_axis == CUDECOMP_COMM_COL) ? grid_desc->col_comm_info.rank : grid_desc->row_comm_info.rank;
auto& comm_info = (comm_axis == CUDECOMP_COMM_COL) ? grid_desc->col_comm_info : grid_desc->row_comm_info;

// Select correct case based on pencil memory order and transfer dim
int c;
if (dim != pinfo_h.order[0] && dim != pinfo_h.order[1]) {
Expand All @@ -103,16 +114,6 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG
return;
} else {
// For multi-rank cases, check if halos include ranks other than nearest neighbor process (unsupported currently).
int count = 0;
for (int i = 0; i < 3; ++i) {
if (i == ax) continue;
if (i == dim) break;
count++;
}

auto comm_axis = (count == 0) ? CUDECOMP_COMM_COL : CUDECOMP_COMM_ROW;
int comm_rank = (comm_axis == CUDECOMP_COMM_COL) ? grid_desc->col_comm_info.rank : grid_desc->row_comm_info.rank;

auto splits =
getSplits(grid_desc->config.gdims_dist[dim], grid_desc->config.pdims[comm_axis == CUDECOMP_COMM_COL ? 0 : 1],
grid_desc->config.gdims[dim] - grid_desc->config.gdims_dist[dim]);
Expand Down Expand Up @@ -143,9 +144,12 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG
bool input_has_padding = anyNonzeros(padding);

if (c == 2 && (input_has_padding || haloBackendRequiresNvshmem(grid_desc->config.halo_comm_backend) ||
(managed && haloBackendRequiresMpi(grid_desc->config.halo_comm_backend)))) {
(managed && haloBackendRequiresMpi(grid_desc->config.halo_comm_backend)) ||
(handle->cuda_cumem_enable && comm_info.mnnvl_active &&
haloBackendRequiresMpi(grid_desc->config.halo_comm_backend)))) {
// For padded input, always stage to work space.
// For managed memory, always stage to work space if using MPI.
// If using MPI and communicator has MNNVL connections, stage to work space if fabric-allocated.
// For any memory, always stage to workspace if using NVSHMEM.
// Can revisit for NVSHMEM if input is NVSHMEM allocated.
c = 1;
Expand Down
11 changes: 8 additions & 3 deletions include/internal/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,16 @@ static void cudecompTranspose_(int ax, int dir, const cudecompHandle_t handle, c
// in to workspace (which should be nvshmem allocated). Can revisit support for input/output
// arrays allocated with nvshmem.
enable = false;
} else if (transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) &&
(isManagedPointer(input) || isManagedPointer(output))) {
} else if (transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend)) {
// Note: For MPI, disable special cases if input or output pointers are to managed memory
// since MPI performance directly from managed memory is not great
enable = false;
if (isManagedPointer(input) || isManagedPointer(output)) { enable = false; }

// Note: For MPI, disable special cases if communicator has an MNNVL connection and the workspace
// is fabric allocated. This forces MPI comms to always use the fabric allocated workspace
// which is more performant.
auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
if (handle->cuda_cumem_enable && comm_info.mnnvl_active) { enable = false; }
}

if (enable) {
Expand Down