# Methods for fetching data in batches (批量获取数据的方法):
def shuffle(*args):
    """Shuffle several mutable sequences in place, in unison.

    Every sequence is shuffled starting from the same NumPy global RNG
    state, so sequences of equal length receive the same permutation and
    corresponding elements stay aligned (e.g. samples and their labels).

    Args:
        *args: Mutable sequences (NumPy arrays or lists) to shuffle in
            place. They should all have the same length; otherwise the
            permutations are not guaranteed to correspond.

    Returns:
        None. The sequences are modified in place.
    """
    # Imported locally so this helper does not depend on a module-level
    # ``np`` alias being defined.
    import numpy as np

    # Snapshot the RNG state once and restore it before each shuffle so
    # every sequence consumes the identical stream of random draws.
    state = np.random.get_state()
    for array in args:
        np.random.set_state(state)
        np.random.shuffle(array)
_NO_FILL = object()  # sentinel meaning "no fillvalue given" (None is a valid fill value)


def grouper(iter_, n, fillvalue=_NO_FILL):
    """Collect items from an iterable into fixed-length tuples.

    Adapted from the ``grouper`` recipe in the Python ``itertools`` docs.

    By default an incomplete trailing chunk is dropped::

        grouper('ABCDEFG', 3)      --> ABC DEF

    If *fillvalue* is given, the last chunk is padded with it instead::

        grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx

    Args:
        iter_: Any iterable.
        n: Number of items per chunk.
        fillvalue: Optional value used to pad the final chunk. When omitted,
            leftover items that do not fill a whole chunk are discarded.

    Returns:
        A lazy iterator of ``n``-tuples.
    """
    from itertools import zip_longest

    # The *same* iterator repeated n times: zip pulls from it n times per
    # output tuple, so consecutive items land in the same tuple.
    args = [iter(iter_)] * n
    if fillvalue is _NO_FILL:
        return zip(*args)
    return zip_longest(*args, fillvalue=fillvalue)
def batches(data, labels, batch_size, randomize=True):
    """Return an iterator of aligned ``(data_batch, label_batch)`` pairs.

    Arguments are validated eagerly, when ``batches(...)`` is called, rather
    than on the first ``next()`` of the returned iterator.

    Args:
        data: Sequence of samples. Must support ``len()``; when *randomize*
            is true it must also be shufflable in place by ``shuffle``
            (e.g. a list or NumPy array).
        labels: Sequence of labels, the same length as *data*.
        batch_size: Number of items per batch; must be between 1 and
            ``len(data)`` inclusive.
        randomize: If true, *data* and *labels* are shuffled in place, in
            unison, before batching. This mutates the caller's sequences.

    Returns:
        A generator of ``(data_batch, label_batch)`` tuples, each element a
        tuple of *batch_size* items. A trailing partial batch is dropped,
        so up to ``batch_size - 1`` items at the end are never yielded.

    Raises:
        ValueError: If the lengths differ or *batch_size* is out of range.
    """
    if len(data) != len(labels):
        raise ValueError('Image data and label data must be same size')
    if batch_size < 1:
        raise ValueError('Batch size must be at least 1')
    if batch_size > len(data):
        raise ValueError('Batch size cannot be larger than size of datasets')
    return _iter_batches(data, labels, batch_size, randomize)


def _iter_batches(data, labels, batch_size, randomize):
    """Generator body of ``batches``; assumes arguments are already validated."""
    # Shuffling is deferred until iteration begins, so merely creating the
    # generator does not mutate the inputs.
    if randomize:
        shuffle(data, labels)
    # Equal lengths (validated by the caller) mean both groupers produce the
    # same number of full batches, so zip pairs them one-to-one.
    yield from zip(grouper(data, batch_size), grouper(labels, batch_size))
if __name__ == "__main__":
    # Demo: 10 samples with labels 100..109, batched in threes. Prints three
    # shuffled (sample, label) batches; the one leftover sample that cannot
    # fill a complete batch is dropped.
    for b in batches(list(range(10)),
                     list(range(100, 110)),
                     3, randomize=True):
        print(b)