from fastai.torch_basics import *
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)
# from nbdev.showdoc import *
bs = 4
letters = list(string.ascii_lowercase)
DataLoader helpers
fastai includes a replacement for PyTorch's `DataLoader` which is largely API-compatible, and adds a lot of useful functionality and flexibility. Before we look at the class, there are a couple of helpers we'll need to define.
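For instance, here is a minimal sketch of that drop-in usage plus one of the extra callback hooks (it uses the `DataLoader` class defined later in this notebook):

dl = DataLoader(list(range(8)), bs=4, shuffle=True, after_batch=lambda b: b*2)
for b in dl: print(b)   # two shuffled batches of 4, each doubled by `after_batch`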
def _wif(worker_id):
    # Passed to PyTorch workers as `worker_init_fn`: limit threads, tell the fastai
    # loader which worker it is, seed it, then run its `wif` callback.
    set_num_threads(1)
    info = get_worker_info()
    ds = info.dataset.d
    ds.num_workers,ds.offs = info.num_workers,info.id
    set_seed(info.seed)
    ds.wif()
class _FakeLoader:
    # Mimics just enough of a PyTorch `DataLoader` for PyTorch's single- and
    # multi-process iterator classes to drive it; batch creation is delegated
    # back to the fastai `DataLoader` stored in `self.d`.
    _IterableDataset_len_called,_auto_collation,collate_fn,drop_last = None,False,noops,False
    _index_sampler,generator,prefetch_factor = Inf.count,None,2
    dataset_kind = _dataset_kind = _DatasetKind.Iterable
    def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers):
        self.dataset,self.default,self.worker_init_fn = self,d,_wif
        store_attr('d,pin_memory,num_workers,timeout,persistent_workers')

    def __iter__(self): return iter(self.d.create_batches(self.d.sample()))

    @property
    def multiprocessing_context(self): return (None,multiprocessing)[self.num_workers>0]

    @contextmanager
    def no_multiproc(self):
        old_num_workers = self.num_workers
        try:
            self.num_workers = 0
            yield self.d
        finally: self.num_workers = old_num_workers
_collate_types = (ndarray, Tensor, typing.Mapping, str)
def fa_collate(t):
    "A replacement for PyTorch `default_collate` which maintains types and handles `Sequence`s"
    b = t[0]
    return (default_collate(t) if isinstance(b, _collate_types)
            else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
            else default_collate(t))
# e.g. x is int, y is tuple
t = [(1,(2,3)),(1,(2,3))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])

t
fa_collate(t)
[(1, (2, 3)), (1, (2, 3))]
(tensor([1, 1]), (tensor([2, 2]), tensor([3, 3])))
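The "maintains types" part matters for tuple subclasses: `fa_collate` rebuilds a `Sequence` batch with `type(t[0])`, so the subclass survives collation, while PyTorch's `default_collate` typically returns a plain list here. A minimal sketch (`MyTuple` is a hypothetical class, for illustration only):

class MyTuple(tuple): pass   # hypothetical subclass, for illustration only
t = [MyTuple((1,2)), MyTuple((2,3))]
test_eq(type(fa_collate(t)), MyTuple)   # the Sequence branch rebuilds with type(t[0])
# default_collate(t) would typically come back as a plain list of tensors instead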
t = [(1,(2,(3,4))),(1,(2,(3,4)))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])
test_eq(L(fa_collate(t)[1]).map(type), [Tensor,tuple])

fa_collate(t)[1]
len(fa_collate(t))

(tensor([2, 2]), (tensor([3, 3]), tensor([4, 4])))
2
# How PyTorch assembles samples into batches (intro in Chinese):
# https://zhuanlan.zhihu.com/p/30385675
default_collate??
t
fa_collate(t)
[(1, (2, (3, 4))), (1, (2, (3, 4)))]
(tensor([1, 1]), (tensor([2, 2]), (tensor([3, 3]), tensor([4, 4]))))
def fa_convert(t):
    "A replacement for PyTorch `default_convert` which maintains types and handles `Sequence`s"
    return (default_convert(t) if isinstance(t, _collate_types)
            else type(t)([fa_convert(s) for s in t]) if isinstance(t, Sequence)
            else default_convert(t))
t0 = array([1,2])
t = [t0,(t0,t0)]
test_eq(fa_convert(t), default_convert(t))
test_eq(L(fa_convert(t)).map(type), [Tensor,tuple])

t
fa_convert(t)
[array([1, 2]), (array([1, 2]), array([1, 2]))]
[tensor([1, 2]), (tensor([1, 2]), tensor([1, 2]))]
class SkipItemException(Exception):
    "Raised to notify `DataLoader` to skip an item"
    pass
DataLoader -
@funcs_kwargs
class DataLoader(GetAttr):
    _noop_methods = 'wif before_iter after_item before_batch after_batch after_iter'.split()
    for o in _noop_methods: exec(f"def {o}(self, x=None, *args, **kwargs): return x")
    _methods = _noop_methods + 'create_batches create_item create_batch retain \
        get_idxs sample shuffle_fn do_batch create_batch'.split()
    _default = 'dataset'
    def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,
                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, **kwargs):
        if batch_size is not None: bs = batch_size # PyTorch compatibility
        assert not (bs is None and drop_last)
        if indexed is None: indexed = dataset is not None and hasattr(dataset,'__getitem__')
        if n is None:
            try: n = len(dataset)
            except TypeError: pass
        store_attr('dataset,bs,shuffle,drop_last,indexed,n,pin_memory,timeout,device')
        self.rng,self.num_workers,self.offs = random.Random(random.randint(0,2**32-1)),1,0
        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout, persistent_workers=persistent_workers)

    def __len__(self):
        if self.n is None: raise TypeError
        if self.bs is None: return self.n
        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)

    def get_idxs(self):
        idxs = Inf.count if self.indexed else Inf.nones
        if self.n is not None: idxs = list(itertools.islice(idxs, self.n))
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        return idxs

    def sample(self):
        return (b for i,b in enumerate(self.__idxs) if i//(self.bs or 1)%self.num_workers==self.offs)

    def __iter__(self):
        self.randomize()
        self.before_iter()
        self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
        for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
            if self.device is not None: b = to_device(b, self.device)
            yield self.after_batch(b)
        self.after_iter()
        if hasattr(self, 'it'): del(self.it)

    def create_batches(self, samps):
        self.it = iter(self.dataset) if self.dataset is not None else None
        res = filter(lambda o:o is not None, map(self.do_item, samps))
        yield from map(self.do_batch, self.chunkify(res))

    def new(self, dataset=None, cls=None, **kwargs):
        if dataset is None: dataset = self.dataset
        if cls is None: cls = type(self)
        cur_kwargs = dict(dataset=dataset, num_workers=self.fake_l.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
                          bs=self.bs, shuffle=self.shuffle, drop_last=self.drop_last, indexed=self.indexed, device=self.device)
        for n in self._methods:
            o = getattr(self, n)
            if not isinstance(o, MethodType): cur_kwargs[n] = o
        return cls(**merge(cur_kwargs, kwargs))

    @property
    def prebatched(self): return self.bs is None
    def do_item(self, s):
        try: return self.after_item(self.create_item(s))
        except SkipItemException: return None
    def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)
    def shuffle_fn(self, idxs): return self.rng.sample(idxs, len(idxs))
    def randomize(self): self.rng = random.Random(self.rng.randint(0,2**32-1))
    def retain(self, res, b): return retain_types(res, b[0] if is_listy(b) else b)
    def create_item(self, s): return next(self.it) if s is None else self.dataset[s]
    def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)
    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
    def to(self, device): self.device = device
    def one_batch(self):
        if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
        with self.fake_l.no_multiproc(): res = first(self)
        if hasattr(self, 'it'): delattr(self, 'it')
        return res
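The `new` method above deserves a quick illustration: it clones the loader, carrying over its settings and any non-method callback overrides, with keyword arguments taking precedence. A small sketch:

dl  = DataLoader(range(12), bs=4)
dl2 = dl.new(bs=6)     # same dataset and settings, new batch size
test_eq(len(dl),  3)   # 12//4
test_eq(len(dl2), 2)   # 12//6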
Arguments to DataLoader:

* `dataset`: dataset from which to load the data. Can be either a map-style or an iterable-style dataset (see the sketch after this list).
* `bs` (int): how many samples per batch to load (if `batch_size` is provided then `batch_size` will override `bs`). If `bs=None`, then it is assumed that `dataset.__getitem__` returns a batch.
* `num_workers` (int): how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process.
* `pin_memory` (bool): If `True`, the data loader will copy Tensors into CUDA pinned memory before returning them.
* `timeout` (float>0): the timeout value in seconds for collecting a batch from workers.
* `batch_size` (int): only provided for PyTorch compatibility. Use `bs`.
* `shuffle` (bool): If `True`, then data is shuffled every time the dataloader is fully read/iterated.
* `drop_last` (bool): If `True`, then the last incomplete batch is dropped.
* `indexed` (bool): Set to `False` if you are using an iterable-style dataset. Otherwise it is set to `True` by default.
* `n` (int): Defaults to `len(dataset)`. If you are using an iterable-style dataset, pass `n` to tell the `DataLoader` how many items to draw; it is also what `__len__` reports.
* `device` (torch.device): Defaults to `default_device()` which is CUDA by default. You can specify device as `torch.device('cpu')`.
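A sketch contrasting the two dataset styles from the first bullet (the variable names are illustrative only):

map_dl  = DataLoader(list(range(8)), bs=4)                      # map-style: has __getitem__
iter_dl = DataLoader(iter(range(8)), bs=4, indexed=False, n=8)  # iterable-style: pass n so len() works
test_eq(L(map_dl), L(iter_dl))   # both yield tensor([0,1,2,3]), tensor([4,5,6,7])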
Override `create_item` and use the default infinite sampler to get a stream of unknown length (`stop()` when you want to stop the stream).
class RandDL(DataLoader):
    # create_item controls how long the stream runs: each call either yields a
    # random number or calls stop() to end the epoch
    def create_item(self, s):
        r = random.random()
        return r if r<0.95 else stop()

L(RandDL())
(#2) [0.7845764769109268,0.07663069024469027]
L(RandDL(bs=4, drop_last=True))
(#5) [tensor([0.4496, 0.1020, 0.7749, 0.2346], dtype=torch.float64),tensor([0.3137, 0.0669, 0.2633, 0.6447], dtype=torch.float64),tensor([0.1578, 0.7143, 0.7018, 0.3614], dtype=torch.float64),tensor([0.0818, 0.8804, 0.0260, 0.1141], dtype=torch.float64),tensor([0.8457, 0.4684, 0.6813, 0.5376], dtype=torch.float64)]
aa = L(torch.randn(3,2,2),torch.randn(1,2))
aa
# map(len) gives the len of each element
aa.map(len)
(#2) [tensor([[[-0.1817, 0.8239],
[-1.2745, 0.2690]],
[[-2.4169, -0.0737],
[-0.5183, -0.2426]],
[[-0.5382, -0.8570],
[-0.3183, -1.3729]]]),tensor([[ 0.3762, -0.1435]])]
(#2) [3,1]
# generate batches until stop() is called; with bs=4 and drop_last=True each batch has len 4
L(RandDL(bs=4, drop_last=True)).map(len)
(#19) [4,4,4,4,4,4,4,4,4,4...]
dl = RandDL(bs=4, num_workers=4, drop_last=True)
dl
aa = L(dl)
aa
aa.map(len)
<__main__.RandDL at 0x7ff490c795e0>
(#6) [tensor([7.9808e-01, 3.3119e-04, 6.3444e-01, 4.4250e-01], dtype=torch.float64),tensor([0.3784, 0.7446, 0.4139, 0.4271], dtype=torch.float64),tensor([0.0310, 0.9253, 0.8902, 0.7117], dtype=torch.float64),tensor([0.6363, 0.0280, 0.4431, 0.4497], dtype=torch.float64),tensor([0.2198, 0.9301, 0.2775, 0.5392], dtype=torch.float64),tensor([0.9400, 0.6906, 0.3483, 0.1497], dtype=torch.float64)]
(#6) [4,4,4,4,4,4]
test_eq(dl.fake_l.num_workers, 4)
with dl.fake_l.no_multiproc():
    test_eq(dl.fake_l.num_workers, 0)
    L(dl).map(len)
test_eq(dl.fake_l.num_workers, 4)
(#3) [4,4,4]
def _rand_item(s):
    r = random.random()
    return r if r<0.95 else stop()

L(DataLoader(create_item=_rand_item))
(#19) [0.6349563676454735,0.7146332101602991,0.8141618453401647,0.4520649933251427,0.9361665561726571,0.6025762046797407,0.8542014056058742,0.1619398819056156,0.3453745719035911,0.21838379481215286...]
If you don't set `bs`, then `dataset` is assumed to provide an iterator or a `__getitem__` that returns a batch.
ds1 = DataLoader(letters)
test_eq(L(ds1), letters)
test_eq(len(ds1), 26)

test_shuffled(L(DataLoader(letters, shuffle=True)), letters)

ds1 = DataLoader(letters, indexed=False)
test_eq(L(ds1), letters)
test_eq(len(ds1), 26)
t2 = L(tensor([0,1,2]),tensor([3,4,5]))
ds2 = DataLoader(t2)
test_eq_type(L(ds2), t2)

t3 = L(array([0,1,2]),array([3,4,5]))
ds3 = DataLoader(t3)
test_eq_type(L(ds3), t3.map(tensor))

ds4 = DataLoader(t3, create_batch=noop, after_iter=lambda: setattr(t3, 'f', 1))
test_eq_type(L(ds4), t3)
test_eq(t3.f, 1)
If you do set `bs`, then `dataset` is assumed to provide an iterator or a `__getitem__` that returns a single item of a batch.
def twoepochs(d): return ' '.join(''.join(list(o)) for _ in range(2) for o in d)

ds1 = DataLoader(letters, bs=4, drop_last=True, num_workers=0)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

ds1 = DataLoader(letters, 4, num_workers=2)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')

ds1 = DataLoader(range(12), bs=4, num_workers=3)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

ds1 = DataLoader([str(i) for i in range(11)], bs=4, after_iter=lambda: setattr(t3, 'f', 2))
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))
test_eq(t3.f, 2)

it = iter(DataLoader(map(noop,range(20)), bs=4, num_workers=1))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])
def addone(s):
    s += 1
    return s

addone(6)
7
ds1 = DataLoader(range(12), bs=4)
L(ds1)

ds1 = DataLoader(range(12), bs=4, create_item=addone)
L(ds1)

ds1 = DataLoader(range(12), bs=4, after_item=lambda o: o*2)
L(ds1)

ds1 = DataLoader(range(12), bs=4, after_item=lambda o: o*2, create_item=addone)
L(ds1)

ds1 = DataLoader(range(12), bs=4, create_item=addone, after_item=lambda i: i+2, after_batch=lambda o: o*3)
L(ds1)

# ds1 = DataLoader(range(12), bs=4, before_batch=lambda o: o-1)
# L(ds1)
(#3) [tensor([0, 1, 2, 3]),tensor([4, 5, 6, 7]),tensor([ 8, 9, 10, 11])]
(#3) [tensor([1, 2, 3, 4]),tensor([5, 6, 7, 8]),tensor([ 9, 10, 11, 12])]
(#3) [tensor([0, 2, 4, 6]),tensor([ 8, 10, 12, 14]),tensor([16, 18, 20, 22])]
(#3) [tensor([2, 4, 6, 8]),tensor([10, 12, 14, 16]),tensor([18, 20, 22, 24])]
(#3) [tensor([ 9, 12, 15, 18]),tensor([21, 24, 27, 30]),tensor([33, 36, 39, 42])]
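None of the cells above exercises `SkipItemException`: raising it from `create_item` or `after_item` makes the `DataLoader` silently drop that sample before batching. A minimal sketch (`skip_odd` is an illustrative helper, not part of the library):

def skip_odd(x):
    # drop odd items: do_item catches SkipItemException and filters them out
    if x%2: raise SkipItemException()
    return x

ds1 = DataLoader(range(12), bs=4, after_item=skip_odd)
L(ds1)   # expected: (#2) [tensor([0, 2, 4, 6]),tensor([ 8, 10])]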
class SleepyDL(list):
    def __getitem__(self,i):
        time.sleep(random.random()/50)
        return super().__getitem__(i)

t = SleepyDL(letters)
%time test_eq(DataLoader(t, num_workers=0), letters)
%time test_eq(DataLoader(t, num_workers=2), letters)
%time test_eq(DataLoader(t, num_workers=4), letters)
dl = DataLoader(t, shuffle=True, num_workers=1)
test_shuffled(L(dl), letters)
test_shuffled(L(dl), L(dl))
CPU times: user 913 µs, sys: 5.14 ms, total: 6.05 ms
Wall time: 302 ms
CPU times: user 4.15 ms, sys: 30.2 ms, total: 34.4 ms
Wall time: 177 ms
CPU times: user 8.4 ms, sys: 35.2 ms, total: 43.6 ms
Wall time: 142 ms
class SleepyQueue():
    "Simulate a queue with varying latency"
    def __init__(self, q): self.q=q
    def __iter__(self):
        while True:
            time.sleep(random.random()/100)
            try: yield self.q.get_nowait()
            except queues.Empty: return

q = Queue()
for o in range(30): q.put(o)
it = SleepyQueue(q)
%time test_shuffled(L(DataLoader(it, num_workers=4)), range(30))
CPU times: user 9.43 ms, sys: 36.3 ms, total: 45.7 ms
Wall time: 118 ms
class A(TensorBase): pass

for nw in (0,2):
    t = A(tensor([1,2]))
    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)
    b = first(dl)
    print(len(b))
    print(b)
    print(b[0])
    test_eq(type(b), A)

    t = (A(tensor([1,2])),)
    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)
    b = first(dl)
    test_eq(type(b[0]), A)
4
A([[1, 2],
[1, 2],
[1, 2],
[1, 2]])
A([1, 2])
4
A([[1, 2],
[1, 2],
[1, 2],
[1, 2]])
A([1, 2])
list(DataLoader(list(range(50)),bs=32,shuffle=True,num_workers=3))
[tensor([30, 18, 29, 38, 43, 25, 23, 1, 0, 22, 13, 9, 27, 47, 16, 3, 15, 7,
19, 32, 45, 42, 48, 41, 10, 11, 6, 14, 20, 31, 39, 26]),
tensor([34, 35, 33, 24, 5, 28, 36, 4, 40, 49, 8, 21, 37, 17, 44, 2, 12, 46])]
class A(TensorBase): pass
t = A(tensor(1,2))

tdl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=2, after_batch=to_device)
b = first(tdl)
test_eq(type(b), A)
# Unknown attributes are delegated to `dataset`
test_eq(tdl.pop(), tensor(1,2))
Override `get_idxs` to return the same index until consumption of the DL. This is intended to test consistent sampling behavior when `num_workers`>1.
class AdamantDL(DataLoader):
    def get_idxs(self):
        r = random.randint(0,self.n-1)
        return [r] * self.n

test_eq(torch.cat(tuple(AdamantDL((list(range(50))),bs=16,num_workers=4))).unique().numel(),1)