Dataset类分批加载数据集

在做NLP任务的时候,需要分批加载数据集进行训练,这个时候可以继承pytorch.utils.data中的Dataset类,就可以进行分批加载数据,并且可以将数据转换成tensor对象数据.
处理流程:


image.png

1.自定义Dataset类

这个类要配合的torch.utils.data 中的DataLoader类才可以发挥作用

# 因为我在数据预处理的时候将转换成id的数据集全部持久化处理了,所以需要这个方法加载数据
# 获取文件
def load_pkl(path,obj_name):
    print(f'get{obj_name} in {path}')
    with codecs.open(path,'rb')as f:
        data=pkl.load(f)
    return data

# 第三方库
import torch
from torch.utils.data import Dataset

# 自定义库
from BruceNRE.utils import load_pkl
# 数据集的加载
class CustomDataset(Dataset):
    def __init__(self,file_path,obj_name):
        self.file=load_pkl(file_path,obj_name)

    def __getitem__(self, item):
        sample=self.file[item]
        return sample

    def __len__(self):
        return len(self.file)

# 这个方法负责将数据进行填充,并且转换成tensor对象
def collate_fn(batch):
# 把这个批次中的数据按照list长度由高到低排序
    batch.sort(key=lambda data: len(data[0]),reverse=True)
# 将这个批次中数据长度放到len集合中
    lens=[len(data[0])for data in batch]
# 获得最大的长度
    max_len=max(lens)

    sent_list=[]
    head_pos_list=[]
    tail_pos_list=[]
    mask_pos_list=[]
    relation_list=[]

    # 填充数据,都用0来填充
    def _padding(x,max_len):
        return x+[0]*(max_len-len(x))
# 把数据集转换成tensor对象,然后封装到对应的list中
    for data in batch:
        sent,head_pos,tail_pos,mask_pos,relation=data
        sent_list.append(_padding(sent,max_len))
        head_pos_list.append(_padding(tail_pos,max_len))
        tail_pos_list.append(_padding(tail_pos,max_len))
        mask_pos_list.append(_padding(mask_pos,max_len))
        relation_list.append(relation)

    # 将numpy转换为tensor
    return torch.tensor(sent_list),torch.tensor(head_pos_list),torch.tensor(tail_pos_list),torch.tensor(mask_pos_list),torch.tensor(relation_list)

这个类解释一下作用:

  • init方法:把所有数据集加载进来
  • getitem:如果设置suffle为True就会打乱数据,传递数据的索引给getitem,就是item,然后根据索引加载数据.
  • len:获取数据集的索引长度
  • collate_fn:因为使用DataLoader这个类要求每一个批次中的数据的长度必须要一样,所以这个方法有两个作用,第一个作用就是把数据集全部用0填充到相同的长度,然后将数据集(是转换成字典标志位的数据集)转换成tensor对象

2.使用Dataset类

# 调用Dataset实现类
train_dataset=CustomDataset(train_data_path,'train-data')
# 将train_dataset放到DataLoader中,才可以使用
train_dataloader=DataLoader(
        dataset=train_dataset,
        batch_size=128,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn
    )

    for batch_idx,batch in enumerate(train_dataloader):
        *x,y=[data.to(device) for data in batch]
    print('dataloader测试完成')

参数解析:
dataset:Dataset类封装的数据集
batch_size:每个批次处理的数据量,一般128或者64
shuffle:是否打乱顺序
drop_last:丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
collate_fn:处理数据集成一样的长度,并且转换成tensor对象的方法

==============================================================

3.再看个例子:

我的原始数据格式:

体验2D巅峰 倚天屠龙记十大创新概览  8
60年铁树开花形状似玉米芯(组图)   5
同步A股首秀:港股缩量回调   2
中青宝sg现场抓拍 兔子舞热辣表演   8
锌价难续去年辉煌    0
2岁男童爬窗台不慎7楼坠下获救(图)  5
布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击  7
金科西府 名墅天成   1
状元心经:考前一周重点是回顾和整理   3
发改委治理涉企收费每年为企业减负超百亿 6
一年网事扫荡10年纷扰开心网李鬼之争和平落幕  4
2010英国新政府“三把火”或影响留学业    3
俄达吉斯坦共和国一名区长被枪杀 6
朝鲜要求日本对过去罪行道歉和赔偿    6
《口袋妖怪 黑白》日本首周贩售255万 8
图文:借贷成本上涨致俄罗斯铝业净利下滑21%  2
组图:新《三国》再曝海量剧照 火战场面极震撼  9
麻辣点评:如何走出“被留学”的尴尬   3
  • 创建一个Dataset的子类来处理数据
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer,BertConfig,BertModel
bert_model='./bert-base-chinese'
myconfig = BertConfig.from_pretrained("./bert-base-chinese")
tokenizer=BertTokenizer.from_pretrained(bert_model)
MAX_LEN = 256 - 2

class ElementDataset(Dataset):
    def __init__(self, f_path):
        sents, label_li = [], []  # list of lists
        with open(f_path, 'r', encoding='utf-8') as fr:
            for line in fr:
                if len(line) < 10:
                    continue
                entries = line.strip().split('\t')
                words = entries[0]
                label = entries[1:]
                label = list(map(int, label))
                sents.append(words)
                label_li.append(label)
        self.sents, self.label_li = sents, label_li

    def __getitem__(self, item):
        words,tags=self.sents[item],self.label_li[item]
        inputs=tokenizer.encode_plus(words)
        label=tags
        seqlen = len(inputs['input_ids'])
        sample=(inputs,label,seqlen)
        return sample

    def __len__(self):
        print('sents')
        return len(self.sents)

    # 填充
def collate_fn(batch):
    all_input_ids=[]
    all_attention_mask=[]
    all_token_type_ids=[]
    all_labels=[]
    lens=[data[2] for data in batch]
    max_len=max(lens)
    def padding(input,max_len,pad_token):
        return input+[pad_token]*(max_len-len(input))

    for data in batch:
        input,tags,_=data
        all_input_ids.append(padding(input['input_ids'],max_len,1))
        all_token_type_ids.append(padding(input['token_type_ids'],max_len,0))
        all_attention_mask.append(padding(input['attention_mask'],max_len,0))
        all_labels.append(tags)
    return torch.tensor(all_input_ids),torch.tensor(all_token_type_ids),torch.tensor(all_attention_mask),all_labels
  • 然后再调用的时候使用DataLoader加载数据
train_data=ElementDataset(args.Train)
    test_data=ElementDataset(args.Test)

    train_iter=DataLoader(dataset=train_data,
                               batch_size=10,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=collate_fn)

    test_iter =DataLoader(dataset=test_data,
                                 batch_size=10,
                                 shuffle=True,
                                 drop_last=True,
                                 collate_fn=collate_fn)
# 可以使用一个for循环查看数据
    for i, batch in enumerate(iterator):
        input_ids,token_type_ids,attention_mask,labels= batch

batch就是每一个批次的数据,我设置的这个批次的数据是10个,则这个10个的数据的长度就是一样的长度,并且都是tensor格式.

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,366评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,521评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,689评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,925评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,942评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,727评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,447评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,349评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,820评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,990评论 3 337
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,127评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,812评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,471评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,017评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,142评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,388评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,066评论 2 355

推荐阅读更多精彩内容