在工作中,需要在训练模型的过程中,读入大规模稀疏矩阵,因此考虑用tfrecord进行加载
1.生TFRecord
import tensorflow as tf
import numpy as np
"""
txt文件中保存的是矩阵每一行的行坐标,列坐标,以及元素值
数据格式为:‘行坐标’ + ‘[对应所有列坐标]’ + ‘[对应所有元素值]’
"""
def write_TFRecord(srcpath, dstpath):
writer = tf.python_io.TFRecordWriter(dstpath)
f = open(srcpath)
line = f.readline()
while line:
line_ = line.strip().split('\t')
cols = eval(line[1])
vals = eval(line[2])
rows = [int(line[0])]
features = tf.train.Features(
feature={'rows': tf.train.Feature(int64_list=tf.train.Int64List(value=rows)),
'photos': tf.train.Feature(int64_list=tf.train.Int64List(value=cols)),
'vals': tf.train.Feature(float_list=tf.train.FloatList(value=vals))
})
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
writer.close()
f.close()
2. 利用tf.data.TFRecordDataset接口进行解析
2.1 将每一行的值解析为稠密张量
def parser(example):
dicts = {
'rows': tf.FixedLenFeature(shape=[],dtype=tf.int64),
'cols': tf.VarLenFeature(dtype=tf.int64), #由于cols为变长,需要使用 tf.VarLenFeature
'vals': tf.VarLenFeature(dtype=tf.float32)
}
parsed_example = tf.parse_single_example(example, dicts)
rows = parsed_example['rows']
cols = parsed_example['cols']
vals = parsed_example['vals']
return rows, tf.sparse_tensor_to_dense(rows), tf.sparse_tensor_to_dense(vals)
# 采用这种方式,返回的是稀疏张量,需要用tf.sparse_tensor_to_dense转化为稠密张量
def get_batch_dataset(recordfile, parser):
dataset = tf.data.TFRecordDataset(recordfile).map(parser).padded_batch(2, padded_shapes=([],[None],[None]))
# 由于row_index跟vals均不为定长,无法进行batch,所以需要对其进行填充,将短的张量用0填充,直到其长度与batch中最长的张量相等
return dataset
dataset = get_batch_dataset('tfrecord', parser)
2.2 直接读取为稀疏张量
def parser1(example):
my_example_features = {'sparse': tf.SparseFeature(index_key=['rows', 'cols'],
value_key='vals',
dtype=tf.float32,
size=[1,max_col])} #size[0]表示一行,size[1]表示稀疏矩阵的列数
parsed_example = tf.parse_single_example(example, my_example_features)
return parsed_example['sparse']
def get_batch_dataset(recordfile, parser):
dataset = tf.data.TFRecordDataset(recordfile).map(parser).repeat(2).batch(20000)
return dataset
dataset = get_batch_dataset('tfrecord', parser)