DLRM代码理解

在DLRM中有对训练集做处理的函数,我们对训练序列做了研究,

    def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
        # WARNING: notice that we are processing the batch at once. We implicitly
        # assume that the data is laid out such that:
        # 1. each embedding is indexed with a group of sparse indices,
        #   corresponding to a single lookup
        # 2. for each embedding the lookups are further organized into a batch
        # 3. for a list of embedding tables there is a list of batched lookups

        ly = []
        for k, sparse_index_group_batch in enumerate(lS_i):
            sparse_offset_group_batch = lS_o[k]

            # embedding lookup
            # We are using EmbeddingBag, which implicitly uses sum operator.
            # The embeddings are represented as tall matrices, with sum
            # happening vertically across 0 axis, resulting in a row vector
            # E = emb_l[k]

            if v_W_l[k] is not None:
                per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch)
            else:
                per_sample_weights = None

            if:
                ....
            else:
                E = emb_l[k]
                V = E(
                    sparse_index_group_batch,
                    sparse_offset_group_batch,
                    per_sample_weights=per_sample_weights,
                )
            
                ly.append(V)

重点是这个地方,其中E是所有打包好的Embedding:

image.png

其中第一维为这个Embedding table中包括的vector的数量,第二维64为vector的维度(有64个float)。

sparse_index_group_batch以及sparse_offset_group_batch为训练时需要的index以及offset,Embedding table会根据index找具体的vector。

offset需要注意,offset = torch.LongTensor([0,1,4]).to(0)代表三个样本,第一个样本是0 ~ 1,第二个是1 ~ 4,第三个是4(网上解释的都不够清楚,所以我这里通过代码实际跑了一下测出来是这个结果) 。且左闭右开[0,1)这种形式取整数(已经根据代码进行过验证)。

详细解释一下流程:

首先在apply_emb函数中每次循环会取出当前第k个Emb table:E = emb_l[k],其中k是当前所在轮数。

对于index数组与offset数组:

image.png

我们能看到,第一个tensor是index,有五个元素,代表我要取的当前table中的vector的编号(共5个)。

而后面的offset就代表我取出来的这5个数组哪些要进行reduce操作(加和等)。

例如我如果取offset为[0,3],则代表0,1,2相加进行reduce,3,4进行reduce。所以最终出来的数字个数就是offset的size。

IS_I以及IS_O生成的位置

在dlrm_data_pytorch.py中的collate_wrapper_criteo_offset()函数里:

def collate_wrapper_criteo_offset(list_of_tuples):
    # where each tuple is (X_int, X_cat, y)
    transposed_data = list(zip(*list_of_tuples))
    X_int = torch.log(torch.tensor(transposed_data[0], dtype=torch.float) + 1)
    X_cat = torch.tensor(transposed_data[1], dtype=torch.long)
    T = torch.tensor(transposed_data[2], dtype=torch.float32).view(-1, 1)

    batchSize = X_cat.shape[0]
    featureCnt = X_cat.shape[1]
    lS_i = [X_cat[:, i] for i in range(featureCnt)]
    lS_o = [torch.tensor(range(batchSize)) for _ in range(featureCnt)]
    return X_int, torch.stack(lS_o), torch.stack(lS_i), T

在这里生成访问序列,首先将传入的数据解析为X_cat,当bs=2时,X_cat为:

tensor([[    0,    17, 36684, 11838,     1,     0,   145,     9,     0,  1176,
            24, 34569,    24,     5,    24, 15109,     0,    19,    14,     3,
         32351,     0,     1,  4159,    32,  5050],
        [    3,    12, 33818, 19987,     0,     5,  1426,     1,     0,  8616,
           729, 31879,   658,     1,    50, 26833,     1,    12,    89,     0,
         29850,     0,     1,  1637,     3,  1246]])

其中每一个tensor有26个数字,代表26个Embedding table。每一个数字代表其中每个table需要访问的vector。(比如0代表访问第一个table的0号vector)

下面将访问序列打包,IS_i为:

[tensor([0, 3]), tensor([17, 12]), tensor([36684, 33818]), tensor([11838, 19987]), tensor([1, 0]), tensor([0, 5]), tensor([ 145, 1426]), tensor([9, 1]), tensor([0, 0]), tensor([1176, 8616]), tensor([ 24, 729]), tensor([34569, 31879]), tensor([ 24, 658]), tensor([5, 1]), tensor([24, 50]), tensor([15109, 26833]), tensor([0, 1]), tensor([19, 12]), tensor([14, 89]), tensor([3, 0]), tensor([32351, 29850]), tensor([0, 0]), tensor([1, 1]), tensor([4159, 1637]), tensor([32,  3]), tensor([5050, 1246])]

这里bs为2,所以[tensor([0, 3])代表访问第一个table的0,3个vactor。

这里我们要再次理解一下数据集的含义,这里每一个table都是用户的一个特征(所在城市、年龄等),所以每一个用户也就是每个table拥有一个数值,所以当bs=2时,这里的tensor[0,3]代表对两个用户进行训练,其中第一个用户的第一个table取值是0号vector,第二个用户第一个table取值是3号vector。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,937评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,503评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,712评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,668评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,677评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,601评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,975评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,637评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,881评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,621评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,710评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,387评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,971评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,947评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,189评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,805评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,449评论 2 342

推荐阅读更多精彩内容

  • 该文章为转载文章,作者简介:汪剑,现在在出门问问负责推荐与个性化。曾在微软雅虎工作,从事过搜索和推荐相关工作。 T...
    名字真的不重要阅读 5,187评论 0 3
  • 我是黑夜里大雨纷飞的人啊 1 “又到一年六月,有人笑有人哭,有人欢乐有人忧愁,有人惊喜有人失落,有的觉得收获满满有...
    陌忘宇阅读 8,520评论 28 53
  • 信任包括信任自己和信任他人 很多时候,很多事情,失败、遗憾、错过,源于不自信,不信任他人 觉得自己做不成,别人做不...
    吴氵晃阅读 6,178评论 4 8
  • 步骤:发微博01-导航栏内容 -> 发微博02-自定义TextView -> 发微博03-完善TextView和...
    dibadalu阅读 3,125评论 1 3
  • 回这一趟老家,心里多了两个疙瘩。第一是堂姐现在谈了一个有妇之夫,在她的语言中感觉,她不打算跟他有太长远的计划,这让...
    安九阅读 3,498评论 2 4