import sys
import threading
#import queue
import random
import collections
import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
_remove_worker_pids, _error_if_any_worker_fails
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data.dataloader import ExceptionWrapper
from torch.utils.data.dataloader import _use_shared_memory
from torch.utils.data.dataloader import _worker_manager_loop
from torch.utils.data.dataloader import numpy_type_map
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import pin_memory_batch
from torch.utils.data.dataloader import _SIGCHLD_handler_set
from torch.utils.data.dataloader import _set_SIGCHLD_handler
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    """Run the worker-side loop that builds multi-scale batches.

    Requests of the form ``(idx, batch_indices)`` are taken from
    ``index_queue`` until a ``None`` request is received. For every request
    the selected samples are collated, the chosen scale index is appended to
    the collated batch, and ``(idx, batch)`` is put on ``data_queue``. If
    anything raises while building the batch, ``(idx, ExceptionWrapper)`` is
    put on ``data_queue`` instead and the loop keeps running.

    Args:
        dataset: Indexable dataset exposing a ``train`` attribute and a
            ``set_scale(idx_scale)`` method.
        index_queue: Queue yielding ``(idx, batch_indices)`` or ``None``.
        data_queue: Queue receiving the results.
        collate_fn: Callable turning a list of samples into a batch; the
            returned object must support ``append``.
        scale: Sequence of available scales (only its length is used here).
        seed: Seed applied to both the ``random`` module and torch.
        init_fn: Optional callable run once as ``init_fn(worker_id)`` before
            the first request is processed.
        worker_id: Identifier of this worker, forwarded to ``init_fn``.
    """
    # NOTE(review): this rebinds the name imported into *this* module; whether
    # the collate code in torch.utils.data.dataloader observes the flag should
    # be confirmed against the installed torch version.
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    # The scale index below is drawn from the ``random`` module, so it must be
    # seeded per worker as well; otherwise workers forked from the same parent
    # start from the same ``random`` state and pick identical scale sequences.
    # ``int()`` also accepts a 0-dim tensor seed.
    random.seed(int(seed))
    torch.manual_seed(seed)

    # Honour the user-supplied per-worker initialisation hook.
    if init_fn is not None:
        init_fn(worker_id)

    while True:
        r = index_queue.get()
        if r is None:
            # Sentinel: the owner asked this worker to stop.
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                # NOTE(review): set_scale is assumed to belong to this branch
                # (source indentation was lost) — confirm.
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)
            samples = collate_fn([dataset[i] for i in batch_indices])
            # The consumer reads the scale index as the last batch element.
            samples.append(idx_scale)
        except Exception:
            # Ship the failure to the consumer instead of killing the worker.
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator over an ``MSDataLoader`` whose workers run ``_ms_loop``.

    The constructor copies the loader's settings (including ``scale``) onto
    the iterator and, when ``num_workers`` is positive, starts one worker
    process per index queue, optionally a manager thread, and then issues
    the first index requests through the inherited ``_put_indices``.
    """

    def __init__(self, loader):
        # Settings copied from the loader.
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()
        self.sample_iter = iter(self.batch_sampler)

        # Without workers there is nothing else to set up here.
        if self.num_workers <= 0:
            return

        self.worker_init_fn = loader.worker_init_fn
        self.index_queues = []
        for _ in range(self.num_workers):
            self.index_queues.append(multiprocessing.Queue())
        self.worker_queue_idx = 0
        self.worker_result_queue = multiprocessing.SimpleQueue()

        # Bookkeeping counters and flags, all starting from a clean state.
        self.batches_outstanding = 0
        self.worker_pids_set = False
        self.shutdown = False
        self.send_idx = 0
        self.rcvd_idx = 0
        self.reorder_dict = {}

        # Every worker gets its own index queue and the seed base_seed + id.
        base_seed = torch.LongTensor(1).random_()[0]
        self.workers = []
        for worker_id, index_queue in enumerate(self.index_queues):
            worker_args = (
                self.dataset,
                index_queue,
                self.worker_result_queue,
                self.collate_fn,
                self.scale,
                base_seed + worker_id,
                self.worker_init_fn,
                worker_id,
            )
            self.workers.append(
                multiprocessing.Process(target=_ms_loop, args=worker_args))

        if not (self.pin_memory or self.timeout > 0):
            # Results are consumed straight from the workers' queue.
            self.data_queue = self.worker_result_queue
        else:
            # A manager thread relays worker results into a thread-level queue.
            self.data_queue = queue.Queue()
            # Only touch CUDA when pinning is requested, so no CUDA context
            # is initialised unnecessarily.
            maybe_device_id = (
                torch.cuda.current_device() if self.pin_memory else None)
            manager_args = (
                self.worker_result_queue,
                self.data_queue,
                self.done_event,
                self.pin_memory,
                maybe_device_id,
            )
            self.worker_manager_thread = threading.Thread(
                target=_worker_manager_loop, args=manager_args)
            self.worker_manager_thread.daemon = True
            self.worker_manager_thread.start()

        for proc in self.workers:
            # Daemon processes exit together with the parent process.
            proc.daemon = True
            proc.start()

        worker_pids = tuple(proc.pid for proc in self.workers)
        _update_worker_pids(id(self), worker_pids)
        _set_SIGCHLD_handler()
        self.worker_pids_set = True

        # Prime the prefetch loop: two _put_indices calls per worker.
        for _ in range(2 * self.num_workers):
            self._put_indices()
class MSDataLoader(DataLoader):
    """``DataLoader`` subclass that carries a scale list.

    The number of workers is read from ``args.n_threads`` and the scale list
    from ``args.scale``; all remaining options are forwarded unchanged to
    ``DataLoader``. Iterating yields a ``_MSDataLoaderIter``.
    """

    def __init__(
            self, args, dataset, batch_size=1, shuffle=False,
            sampler=None, batch_sampler=None,
            collate_fn=default_collate, pin_memory=False, drop_last=False,
            timeout=0, worker_init_fn=None):
        # Gather the options for the base class in one place.
        base_options = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=args.n_threads,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
        super(MSDataLoader, self).__init__(dataset, **base_options)
        # Stored after base initialisation; read by the iterator.
        self.scale = args.scale

    def __iter__(self):
        """Return a new multi-scale iterator bound to this loader."""
        return _MSDataLoaderIter(self)
dataloader.py
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 问题描述 最近在用 vscode 做 python 开发,在执行“转到定义”(Mac: f12, windows:...
- 原文: http://zengrong.net/post/2192.htm 本站文章除注明转载外,均为本站原创或...
- django-admin是用于管理Django的命令行工具集,当我们成功安装Django后,在操作系统中就会有这个...
- 若A.py和B.py位于不同的目录下,可以用以下方法 (假设B.py位于D盘的根目录下) 1.引用所在路径 imp...
- 之前修改了settings.py,换了别的文件名,主要是为了方便生产环境,测试环境和开发环境的三套设置自动切换 然...