1. 介绍
本文主要介绍如何使用TorchText
处理文本数据集。
Torchtext
是一种为pytorch提供文本数据处理能力的库, 类似于图像处理库Torchvision
。
2. 安装
pip install torchtext
3. 概览
使用torchtext的目的是将文本转换成Batch,方便后面训练模型时使用。过程如下:
- 使用
Field
对象进行文本预处理, 生成example - 使用
Dataset
类生成数据集dataset - 使用
Iterator
生成迭代器
4. 常用的类
import torch
from torchtext.data import Field, Example, TabularDataset, BucketIterator
-
Field
:用来定义字段以及文本预处理方法 -
Example
: 用来表示一个样本,通常为“数据+标签” -
TabularDataset
: 用来从文件中读取数据,生成Dataset
,Dataset
是Example
实例的集合 -
BucketIterator
:迭代器,用来生成batch
, 类似的有Iterator
,BucketIterator
的功能较强大点,支持排序,动态padding等
5. 使用步骤
5.1 创建Field对象
# tokenize = lambda x: x.split()
def x_tokenize(x):
# 如果加载进来的是已经转成id的文本
# 此处必须将字符串转换成整型
# 否则必须将use_vocab设为True
return [w for w in x.split()]
def y_tokenize(y):
return y
TEXT = Field(sequential=True, tokenize=x_tokenize,
use_vocab=True, fix_length=None,
eos_token=None, init_token=None,
include_lengths=True)
LABEL = Field(sequential=False, tokenize=y_tokenize,
use_vocab=False)
参数说明:
-
sequential
: 是否把数据表示成序列,如果是False, 不能使用分词 默认值: True. -
use_vocab
: 是否使用词典对象. 如果是False 数据的类型必须已经是数值类型. 默认值: True. -
init_token
: 每一条数据的起始字符 默认值: None. -
eos_token
: 每条数据的结尾字符 默认值: None. -
fix_length
: 修改每条数据的长度为该值,不够的用pad_token补全. 默认值: None. -
tensor_type
: 把数据转换成的tensor类型 默认值: torch.LongTensor. -
preprocessing
:在分词之后和数值化之前使用的管道 默认值: None. -
postprocessing
: 数值化之后和转化成tensor之前使用的管道默认值: None. -
lower
: 是否把数据转化为小写 默认值: False. -
tokenize
: 分词函数. 默认值: str.split. -
include_lengths
: 是否返回一个已经补全的最小batch的元组和和一个包含每条数据长度的列表 . 默认值: False. -
batch_first
: Whether to produce tensors with the batch dimension first. 默认值: False. -
pad_token
: 用于补全的字符. 默认值: “”. -
unk_token
: 不存在词典里的字符. 默认值: “”. -
pad_first
: 是否补全第一个字符. 默认值: False.
重要的几个方法:
-
pad(minibatch)
: 在一个batch对齐每条数据 -
build_vocab()
: 建立词典 -
numericalize()
: 把文本数据数值化,返回tensor
5.2 读取文件生成数据集
torchtext的Dataset是继承自pytorch的Dataset,提供了一个可以下载压缩数据并解压的方法(支持.zip, .gz, .tgz)
splits方法可以同时读取训练集,验证集,测试集
TabularDataset可以很方便的读取CSV, TSV, or JSON格式的文件,例子如下:
# 读取文件生成数据集
fields = [('PhraseId', None), ('SentenceId', None), ('Phrase', TEXT), ('Sentiment', LABEL)]
train, test = TabularDataset.splits(
path='datasets/', format='tsv',
train='train.tsv', test='test.tsv',
skip_header=True, fields=fields)
# 构建词表
TEXT.build_vocab(train)
# print(train[0].__dict__.keys())
print(vars(train.examples[0]))
print(vars(test.examples[0]))
# 结果
{'Phrase': ['A', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'], 'Sentiment': '1'}
{'Phrase': ['An', 'intermittently', 'pleasing', 'but', 'mostly', 'routine', 'effort', '.']}
5.3 生成迭代器
Iterator是torchtext到模型的输出,它提供了我们对数据的一般处理方式,比如打乱,排序,等等,可以动态修改batch大小,这里也有splits方法 可以同时输出训练集,验证集,测试集
参数如下:
-
dataset
: 加载的数据集 -
batch_size
: Batch 大小. -
batch_size_fn
: 产生动态的batch大小的函数 -
sort_key
: 排序的key -
train
: 是否是一个训练集 -
repeat
: 是否在不同epoch中重复迭代 -
shuffle
: 是否打乱数据 -
sort
: 是否对数据进行排序 -
sort_within_batch
: batch内部是否排序 -
device
: 建立batch的设备 -1:CPU ;0,1 …:对应的GPU
这里要注意的是sort_with_batch要设置为True,并指定排序的key为文本长度,方便后面pytorch RNN进行pack和pad。
使用方式如下:
# 生成迭代器
train_iter, test_iter = BucketIterator.splits((train, test),
batch_sizes=(3, 4),
device = torch.device("cpu"),
sort_key=lambda x: len(x.Phrase), # field sorted by len
sort_within_batch=True)
batch = next(iter(train_iter))
print(batch)
print(batch.Phrase)
print(batch.Sentiment)
print(TEXT.vocab.freqs['A'])
print(TEXT.vocab.stoi['<pad>'])
print(TEXT.vocab.itos[1])
for i, v in enumerate(TEXT.vocab.stoi):
if i == 5:
break
print(v)
# 结果
[torchtext.data.batch.Batch of size 3]
[.Phrase]:('[torch.LongTensor of size 24x3]', '[torch.LongTensor of size 3]')
[.Sentiment]:[torch.LongTensor of size 3]
(tensor([[ 2220, 13, 2],
[ 311, 3760, 51],
[ 5, 149, 1818],
[ 663, 5, 6233],
[ 15, 4, 109],
[ 17, 328, 65],
[13022, 10, 84],
[ 735, 964, 11],
[ 1028, 15, 12],
[ 6, 219, 1977],
[ 17, 3, 1696],
[14071, 15, 269],
[ 2148, 2, 455],
[ 7, 650, 10],
[ 321, 4957, 4],
[ 54, 5, 297],
[ 4, 11676, 633],
[ 8328, 12, 161],
[ 727, 4209, 14],
[ 160, 185, 4],
[ 56, 4482, 7474],
[ 6, 3, 1380],
[ 989, 9820, 704],
[ 8, 1839, 1]]), tensor([24, 24, 23]))
tensor([4, 2, 2])
2680
1
<pad>
<unk>
<pad>
the
,
a