在DLRM中有对训练集做处理的函数,我们对训练序列做了研究,
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
# WARNING: notice that we are processing the batch at once. We implicitly
# assume that the data is laid out such that:
# 1. each embedding is indexed with a group of sparse indices,
# corresponding to a single lookup
# 2. for each embedding the lookups are further organized into a batch
# 3. for a list of embedding tables there is a list of batched lookups
ly = []
for k, sparse_index_group_batch in enumerate(lS_i):
sparse_offset_group_batch = lS_o[k]
# embedding lookup
# We are using EmbeddingBag, which implicitly uses sum operator.
# The embeddings are represented as tall matrices, with sum
# happening vertically across 0 axis, resulting in a row vector
# E = emb_l[k]
if v_W_l[k] is not None:
per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch)
else:
per_sample_weights = None
if:
....
else:
E = emb_l[k]
V = E(
sparse_index_group_batch,
sparse_offset_group_batch,
per_sample_weights=per_sample_weights,
)
ly.append(V)
重点是这个地方,其中E是所有打包好的Embedding:
其中第一维为这个Embedding table中包括的vector的数量,第二维64为vector的维度(有64个float)。
sparse_index_group_batch
以及sparse_offset_group_batch
为训练时需要的index以及offset,Embedding table会根据index找具体的vector。
offset需要注意,offset = torch.LongTensor([0,1,4]).to(0)
代表三个样本,第一个样本是0 ~ 1,第二个是1 ~ 4,第三个是4(网上解释的都不够清楚,所以我这里通过代码实际跑了一下测出来是这个结果) 。且左闭右开[0,1)这种形式取整数(已经根据代码进行过验证)。
详细解释一下流程:
首先在apply_emb
函数中每次循环会取出当前第k个Emb table:E = emb_l[k]
,其中k是当前所在轮数。
对于index数组与offset数组:
我们能看到,第一个tensor是index,有五个元素,代表我要取的当前table中的vector的编号(共5个)。
而后面的offset就代表我取出来的这5个数组哪些要进行reduce操作(加和等)。
例如我如果取offset为[0,3],则代表0,1,2相加进行reduce,3,4进行reduce。所以最终出来的数字个数就是offset的size。
IS_I以及IS_O生成的位置
在dlrm_data_pytorch.py中的collate_wrapper_criteo_offset()
函数里:
def collate_wrapper_criteo_offset(list_of_tuples):
# where each tuple is (X_int, X_cat, y)
transposed_data = list(zip(*list_of_tuples))
X_int = torch.log(torch.tensor(transposed_data[0], dtype=torch.float) + 1)
X_cat = torch.tensor(transposed_data[1], dtype=torch.long)
T = torch.tensor(transposed_data[2], dtype=torch.float32).view(-1, 1)
batchSize = X_cat.shape[0]
featureCnt = X_cat.shape[1]
lS_i = [X_cat[:, i] for i in range(featureCnt)]
lS_o = [torch.tensor(range(batchSize)) for _ in range(featureCnt)]
return X_int, torch.stack(lS_o), torch.stack(lS_i), T
在这里生成访问序列,首先将传入的数据解析为X_cat,当bs=2时,X_cat为:
tensor([[ 0, 17, 36684, 11838, 1, 0, 145, 9, 0, 1176,
24, 34569, 24, 5, 24, 15109, 0, 19, 14, 3,
32351, 0, 1, 4159, 32, 5050],
[ 3, 12, 33818, 19987, 0, 5, 1426, 1, 0, 8616,
729, 31879, 658, 1, 50, 26833, 1, 12, 89, 0,
29850, 0, 1, 1637, 3, 1246]])
其中每一个tensor有26个数字,代表26个Embedding table。每一个数字代表其中每个table需要访问的vector。(比如0代表访问第一个table的0号vector)
下面将访问序列打包,IS_i为:
[tensor([0, 3]), tensor([17, 12]), tensor([36684, 33818]), tensor([11838, 19987]), tensor([1, 0]), tensor([0, 5]), tensor([ 145, 1426]), tensor([9, 1]), tensor([0, 0]), tensor([1176, 8616]), tensor([ 24, 729]), tensor([34569, 31879]), tensor([ 24, 658]), tensor([5, 1]), tensor([24, 50]), tensor([15109, 26833]), tensor([0, 1]), tensor([19, 12]), tensor([14, 89]), tensor([3, 0]), tensor([32351, 29850]), tensor([0, 0]), tensor([1, 1]), tensor([4159, 1637]), tensor([32, 3]), tensor([5050, 1246])]
这里bs为2,所以[tensor([0, 3])
代表访问第一个table的0,3个vactor。
这里我们要再次理解一下数据集的含义,这里每一个table都是用户的一个特征(所在城市、年龄等),所以每一个用户也就是每个table拥有一个数值,所以当bs=2时,这里的tensor[0,3]代表对两个用户进行训练,其中第一个用户的第一个table取值是0号vector,第二个用户第一个table取值是3号vector。