1#ifndef _COMM_GRID_3D_H_
2#define _COMM_GRID_3D_H_
// Constructor: sets up a 3D process grid of `nlayers` layers, each a
// gridRows x gridCols grid, on a duplicate of `world`, and creates the
// per-layer and per-fiber communicators.
//
// world     communicator to duplicate; all ranks/sizes below refer to the copy
// nlayers   number of layers (stored in gridLayers)
// nrowproc  rows per layer (stored in gridRows)
// ncolproc  columns per layer (stored in gridCols)
// special   stored in the member of the same name; GetRank tests it
//
// NOTE(review): this listing looks like it was extracted from a rendered
// source page. The leading integers (21, 24, ...) appear to be the original
// line numbers, and the gaps between them suggest that braces, guard
// conditions and the statements following each error message are missing.
// The comments below describe only what is visible; restore the body from
// the real header before compiling.
21 CommGrid3D(MPI_Comm world,
int nlayers,
int nrowproc,
int ncolproc,
bool special =
false): gridLayers(nlayers), gridRows(nrowproc), gridCols(ncolproc), special(special)
// Work on a private duplicate of the caller's communicator; this process's
// rank and the total process count are read from the duplicate.
24 MPI_Comm_dup(world, & world3D);
25 MPI_Comm_rank(world3D, & myrank);
26 MPI_Comm_size(world3D, & nproc);
// Diagnostic for an invalid layer count.
// NOTE(review): the condition guarding this message is not visible here —
// presumably a check that nlayers is at least 1.
29 cerr <<
"A 3D grid can not be created with less than one layer" << endl;
// The processes must divide evenly among the layers.
32 if(nproc % nlayers != 0){
33 cerr <<
"Number of processes is not divisible by number of layers" << endl;
// Perfect-square test on nlayers: truncate the float square root to int,
// square it, and compare with the original value.
37 if(((
int)std::sqrt((
float)nlayers) * (
int)std::sqrt((
float)nlayers)) != nlayers){
39 cerr <<
"Number of layers is not a square number" << endl;
// Number of processes in each layer.
44 int procPerLayer = nproc / nlayers;
// Both per-layer dimensions passed as 0: derive the row count from the
// truncated square root of procPerLayer.
// NOTE(review): no assignment to gridCols is visible in this listing, yet the
// product check below needs one to pass — presumably a line is missing.
46 if(gridRows == 0 && gridCols == 0)
49 gridRows = (int)std::sqrt((
float)procPerLayer);
53 if(gridRows * gridCols != procPerLayer)
55 cerr <<
"This version of the Combinatorial BLAS only works on a square logical processor grid in a layer of the 3D grid" << endl;
// Sanity check: layers x rows x columns must account for every process.
60 assert((nproc == (gridRows * gridCols * gridLayers)));
// First rank mapping visible in this listing. The ranks are viewed as a 2D
// grid that is nCol2D wide and tiled with sqrtLayer x sqrtLayer blocks: the
// position inside a block becomes rankInFiber, and the block's position
// (using gridCols as the row stride) becomes rankInLayer.
// NOTE(review): presumably this runs only when `special` is set — the
// enclosing condition is not visible.
63 int nCol2D = (int)std::sqrt((
float)nproc);
64 int rankInRow2D = myrank / nCol2D;
65 int rankInCol2D = myrank % nCol2D;
66 int sqrtLayer = (int)std::sqrt((
float)nlayers);
67 rankInFiber = (rankInCol2D % sqrtLayer) * sqrtLayer + (rankInRow2D % sqrtLayer);
68 rankInLayer = (rankInRow2D / sqrtLayer) * gridCols + (rankInCol2D / sqrtLayer);
// Processes with the same rankInFiber share a layer communicator (ordered by
// rankInLayer); processes with the same rankInLayer share a fiber
// communicator (ordered by rankInFiber).
69 MPI_Comm_split(world3D, rankInFiber, rankInLayer, &layerWorld);
70 MPI_Comm_split(world3D, rankInLayer, rankInFiber, &fiberWorld);
// Second rank mapping: each consecutive run of procPerLayer ranks forms one
// layer, and the offset inside the run is the rank within that layer.
// NOTE(review): presumably the alternative to the mapping above; as listed,
// with no branch between them, it would simply overwrite those values and
// split the communicator a second time.
73 rankInFiber = myrank / procPerLayer;
74 rankInLayer = myrank % procPerLayer;
75 MPI_Comm_split(world3D, rankInFiber, rankInLayer, &layerWorld);
76 MPI_Comm_split(world3D, rankInLayer, rankInFiber, &fiberWorld);
// Wrap this process's layer communicator in a gridRows x gridCols CommGrid
// held by commGridLayer.
79 commGridLayer.reset(
new CommGrid(layerWorld, gridRows, gridCols));
// Frees the three communicators this class creates: the duplicate of the
// input communicator and the two produced by MPI_Comm_split.
// NOTE(review): the enclosing function header and braces are not visible in
// this listing — presumably this is the destructor body; confirm against the
// real header.
83 MPI_Comm_free(&world3D);
84 MPI_Comm_free(&fiberWorld);
85 MPI_Comm_free(&layerWorld);
// Converts (layer, row, column) coordinates into a single linear rank.
// When `special` is not set, the layout is layer-major and row-major inside
// each layer: layerrank * gridRows * gridCols + rowrank * gridCols + colrank.
// NOTE(review): the code for the case where `special` is set, and the closing
// brace, are not visible in this listing.
90 int GetRank(
int layerrank,
int rowrank,
int colrank) {
91 if(!special)
return layerrank * gridRows * gridCols + rowrank * gridCols + colrank;
99 int GetSize() {
return gridLayers * gridRows * gridCols; }
// 2D communication grid for this process's layer; the constructor resets it
// to a CommGrid built from layerWorld, gridRows and gridCols.
// NOTE(review): the leading "119" looks like a line number left over from
// extraction rather than part of the declaration.
119 std::shared_ptr<CommGrid> commGridLayer;
// NOTE(review): the lines below are bare signatures with no bodies and no
// terminators. Two of them repeat the constructor and GetRank signatures that
// already appear earlier in this file, so this looks like a symbol index left
// over from a rendered documentation page rather than class content —
// confirm against the real header and drop before compiling.
MPI_Comm & GetFiberWorld()
MPI_Comm & GetLayerWorld()
CommGrid3D(MPI_Comm world, int nlayers, int nrowproc, int ncolproc, bool special=false)
int GetRank(int layerrank, int rowrank, int colrank)
std::shared_ptr< CommGrid > GetCommGridLayer()