NCCL Tree源码
https://github.com/NVIDIA/nccl/blob/master/src/misc/trees.c
https://github.com/NVIDIA/nccl/pull/172
为了让代码能够独立编译运行,这里对NCCL的原始代码做了少量修改
#include <iostream>
using namespace std;
// Single binary tree: writes the parent of `rank` to *u and its two children
// to *d0 / *d1 (-1 where there is none).
int GetBtree(int nranks, int rank, int* u, int* d0, int* d1);
// Double tree: writes parent/children of `rank` in the first tree to
// *s0, *d0_0, *d0_1 and in the second tree to *s1, *d1_0, *d1_1.
int GetDtree(int nranks, int rank, int* s0, int* d0_0, int* d0_1, int* s1, int* d1_0, int* d1_1);
// Test driver: for every rank of a fixed 12-rank configuration, prints the
// parent and the two children of that rank in both trees returned by GetDtree.
// First line per rank: first tree (u, d0, d1); second line: second tree
// (du, dd0, dd1).
int main()
{
    const int nranks = 12;
    for (int i = 0; i < nranks; i++) {
        // Fresh outputs per rank; GetDtree overwrites all six of them.
        int u = 0, d0 = 0, d1 = 0, du = 0, dd0 = 0, dd1 = 0;
        GetDtree(nranks, i, &u, &d0, &d1, &du, &dd0, &dd1);
        std::cout << "i=" << i << ";\tu=" << u << ",\td0=" << d0 << ",\td1=" << d1 << '\n';
        std::cout << "\tdu=" << du << ",\tdd0=" << dd0 << ",\tdd1=" << dd1 << '\n';
    }
    return 0;
}
// Builds a single binary tree over ranks 0..nranks-1 and reports where `rank`
// sits in it.
//   *u  : parent of `rank`, or -1 when `rank` is 0 (the root)
//   *d0 : first child, or -1 if there is none
//   *d1 : second child, or -1 if there is none
// Always returns 0.
int GetBtree(int nranks, int rank, int* u, int* d0, int* d1) {
    // Similar to an IP netmask: scanning from the low bits upwards, the first
    // non-zero bit of `rank` marks the boundary of the subtree below this node.
    int mask = 1;
    while (mask < nranks && (mask & rank) == 0) {
        mask <<= 1;
    }

    // Rank 0 is the root: no parent and a single child.
    if (rank == 0) {
        *u = -1;
        *d0 = (nranks > 1) ? (mask >> 1) : -1;
        *d1 = -1;
        return 0;
    }

    // Parent: clear the lowest set bit and set the next higher one; if that
    // index does not exist, fall back to just clearing the lowest set bit.
    const int cleared = rank ^ mask;
    const int parent = cleared | (mask << 1);
    *u = (parent < nranks) ? parent : cleared;

    // Children sit half a subtree-width below and above this rank.
    int half = mask >> 1;
    int left = -1;
    int right = -1;
    if (half != 0) {
        left = rank - half;
        right = rank + half;
    }

    // Pull the upper child back in bounds by retrying with ever smaller
    // offsets; once the offset reaches 0 there is no upper child (-1).
    while (right >= nranks) {
        right = (half == 0) ? -1 : rank + half;
        half >>= 1;
    }

    *d0 = left;
    *d1 = right;
    return 0;
}
// Builds the two trees of the "double tree" and reports where `rank` sits in
// each of them.
//   First tree  (*s0, *d0_0, *d0_1): the plain binary tree from GetBtree.
//   Second tree (*s1, *d1_0, *d1_1): the same tree shape with renumbered ranks:
//     - even nranks: every rank is shifted by one (r -> r+1, wrapping around);
//     - odd  nranks: ranks are mirrored (r -> nranks-r), rank 0 staying fixed.
// In every output, -1 means "no parent" / "no child". Always returns 0.
// Expects nranks > 0 and 0 <= rank < nranks.
int GetDtree(int nranks, int rank, int* s0, int* d0_0, int* d0_1, int* s1, int* d1_0, int* d1_1) {
    // First tree: the binary tree as-is.
    GetBtree(nranks, rank, s0, d0_0, d0_1);

    int u, d0, d1;
    if (nranks % 2 == 0) {
        // Shift: look up the position of (rank-1) in the binary tree, then
        // shift the neighbours found there forward by one.
        const int shiftrank = (rank - 1 + nranks) % nranks;
        GetBtree(nranks, shiftrank, &u, &d0, &d1);
        *s1 = u == -1 ? -1 : (u + 1) % nranks;
        *d1_0 = d0 == -1 ? -1 : (d0 + 1) % nranks;
        *d1_1 = d1 == -1 ? -1 : (d1 + 1) % nranks;
    } else {
        // Mirror: look up the position of (nranks-rank) in the binary tree,
        // then mirror the neighbours found there back.
        // The modulo keeps both mappings inside [0, nranks): without it rank 0
        // would be looked up as the non-existent rank `nranks`, and a mirrored
        // neighbour 0 would be reported as `nranks`.
        const int mirrorrank = (nranks - rank) % nranks;
        GetBtree(nranks, mirrorrank, &u, &d0, &d1);
        *s1 = u == -1 ? -1 : (nranks - u) % nranks;
        *d1_0 = d0 == -1 ? -1 : (nranks - d0) % nranks;
        *d1_1 = d1 == -1 ? -1 : (nranks - d1) % nranks;
    }
    return 0;
}
NCCL-tree生成的ranking方案
NCCL的原始代码需要稍作修改,才能生成其注释中所画的ranking方案
奇数节点Ranking
对于奇数个rank,比如下图的13,NV建议对第二个tree采用mirror ranking,也就是结构上镜像
文档中原图
* 8---------0---------5
* ______/ \______ _____/ \______
* 4 12 1 9
* / \ / \ / \
* 2 6 10 3 7 10
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 2 4 6 8 11 12
*
这里文档中第二棵树的右下角画错了(10和11的位置不对)。按照mirror ranking的规则来画,程序实际跑出来的树应该是:
修正右下角
* 8---------0---------5
* ______/ \______ _____/ \______
* 4 12 1 9
* / \ / \ / \
* 2 6 10 3 7 11
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 2 4 6 8 12 10
如果不按照NV的做法,直接shift ranking,那么就是下图所示的结构,看上去也不错
在奇数节点中应用shift ranking
* 8---------0-----------------9
* ______/ \______ ______/ \______
* 4 12 5 1
* / \ / / \ /
* 2 6 10 3 7 11
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 2 4 6 8 10 12
偶数节点ranking
对于偶数个rank,比如下图的12,NV建议采用shift ranking
文档中原图
* or shift it by one rank (if nranks is even)
*
* 8---------0--------------9
* ______/ \ ______/ \
* 4 \ 5 \
* / \ \ / \ \
* 2 6 10 3 7 11
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 2 4 6 8 10 1
*/
这里文档又写错了:右下角的叶子节点应该是0而不是1。真正运行出来的结果是,1变成了第二棵树的root,而0变成了第二棵树中的叶子节点
修正根节点
* 8---------0--1-----------9
* ______/ \ ______/ \
* 4 \ 5 \
* / \ \ / \ \
* 2 6 10 3 7 11
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 2 4 6 8 10 0
我试试这个结构能不能做mirror ranking
结果发现两棵树的leaf节点还都是同一批奇数编号的rank(double tree希望一棵树里的leaf在另一棵树里充当中间节点),这就是偶数个rank时不能使用mirror ranking的原因
在偶数节点中应用mirror ranking
* 8---------0--------4
* ______/ \ / \______
* 4 \ / 8
* / \ \ / / \
* 2 6 10 2 6 10
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 1 3 5 7 9 11
*/
测试程序输出
nranks=13的时候(注意:下面i=5的du=13并不是合法的rank编号,按照上面修正后的镜像图,5的父节点应为0):
i=0; u=-1, d0=8, d1=-1
du=-1, dd0=5, dd1=-1
i=1; u=2, d0=-1, d1=-1
du=5, dd0=3, dd1=-1
i=2; u=4, d0=1, d1=3
du=3, dd0=-1, dd1=-1
i=3; u=2, d0=-1, d1=-1
du=1, dd0=4, dd1=2
i=4; u=8, d0=2, d1=6
du=3, dd0=-1, dd1=-1
i=5; u=6, d0=-1, d1=-1
du=13, dd0=9, dd1=1
i=6; u=4, d0=5, d1=7
du=7, dd0=-1, dd1=-1
i=7; u=6, d0=-1, d1=-1
du=9, dd0=8, dd1=6
i=8; u=0, d0=4, d1=12
du=7, dd0=-1, dd1=-1
i=9; u=10, d0=-1, d1=-1
du=5, dd0=11, dd1=7
i=10; u=12, d0=9, d1=11
du=11, dd0=-1, dd1=-1
i=11; u=10, d0=-1, d1=-1
du=9, dd0=12, dd1=10
i=12; u=8, d0=10, d1=-1
du=11, dd0=-1, dd1=-1
nranks=12的时候:
i=0; u=-1, d0=8, d1=-1
du=11, dd0=-1, dd1=-1
i=1; u=2, d0=-1, d1=-1
du=-1, dd0=9, dd1=-1
i=2; u=4, d0=1, d1=3
du=3, dd0=-1, dd1=-1
i=3; u=2, d0=-1, d1=-1
du=5, dd0=2, dd1=4
i=4; u=8, d0=2, d1=6
du=3, dd0=-1, dd1=-1
i=5; u=6, d0=-1, d1=-1
du=9, dd0=3, dd1=7
i=6; u=4, d0=5, d1=7
du=7, dd0=-1, dd1=-1
i=7; u=6, d0=-1, d1=-1
du=5, dd0=6, dd1=8
i=8; u=0, d0=4, d1=10
du=7, dd0=-1, dd1=-1
i=9; u=10, d0=-1, d1=-1
du=1, dd0=5, dd1=11
i=10; u=8, d0=9, d1=11
du=11, dd0=-1, dd1=-1
i=11; u=10, d0=-1, d1=-1
du=9, dd0=10, dd1=0