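1 itertools Functions
The following module-level functions all construct and return iterators. Some provide streams of infinite length, so they should only be accessed by functions or loops that truncate the stream.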
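1.1 itertools.accumulate(iterable[, func])
New in version 3.2.
Changed in version 3.3: Added the optional *func* parameter.
Make an iterator that returns accumulated sums, or accumulated results of other binary functions (specified via the optional func argument). If func is supplied, it should be a function of two arguments. Elements of the input iterable may be any type that can be accepted as arguments to func. (For example, with the default operation of addition, elements may be any addable type including Decimal or Fraction.) If the input iterable is empty, the output iterable will also be empty.
Roughly equivalent to: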
import operator

def accumulate(iterable, func=operator.add):
    'Return running totals'
    # accumulate([1, 2, 3, 4, 5]) --> 1 3 6 10 15
    # accumulate([1, 2, 3, 4, 5], operator.mul) --> 1 2 6 24 120
    it = iter(iterable)
    try:
        total = next(it)          # the first element is the first total
    except StopIteration:
        return                    # empty input gives empty output
    yield total
    for element in it:
        total = func(total, element)
        yield total
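There are a number of uses for the func argument. It can be set to min() for a running minimum, max() for a running maximum, or operator.mul() for a running product. Amortization tables can be built by accumulating interest and applying payments. First-order recurrence relations can be modeled by supplying the initial value in the iterable and using only the accumulated total in the func argument: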
>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, operator.mul))
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
>>> list(accumulate(data, max))
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
# Amortize a 5% loan of 1000 with 4 annual payments of 90
>>> cashflows = [1000, -90, -90, -90, -90]
>>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05+pmt))
[1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001]
# Chaotic recurrence relation https://en.wikipedia.org/wiki/Logistic_map
>>> logistic_map = lambda x, _: r * x * (1 - x)
>>> r = 3.8
>>> x0 = 0.4
>>> inputs = repeat(x0, 36) # only the initial value is used
>>> [format(x, '.2f') for x in accumulate(inputs, logistic_map)]
['0.40', '0.91', '0.30', '0.81', '0.60', '0.92', '0.29', '0.79', '0.63',
'0.88', '0.39', '0.90', '0.33', '0.84', '0.52', '0.95', '0.18', '0.57',
'0.93', '0.25', '0.71', '0.79', '0.63', '0.88', '0.39', '0.91', '0.32',
'0.83', '0.54', '0.95', '0.20', '0.60', '0.91', '0.30', '0.80', '0.60']
See functools.reduce() for a similar function that returns only the final accumulated value.
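A small sketch comparing the two on the same illustrative data: the last value produced by accumulate() should match what functools.reduce() returns.

```python
import functools
import operator
from itertools import accumulate

data = [3, 4, 6, 2, 1]

# accumulate() yields every intermediate product ...
running = list(accumulate(data, operator.mul))
# ... while functools.reduce() returns only the final one.
final = functools.reduce(operator.mul, data)

print(running)  # [3, 12, 72, 144, 144]
print(final)    # 144
```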
1.2 itertools.chain(*iterables)
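Make an iterator that returns elements from the first iterable until it is exhausted, then proceeds to the next iterable, until all of the iterables are exhausted. Used for treating consecutive sequences as a single sequence. Roughly equivalent to: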
def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element
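1.3 Class method: chain.from_iterable(iterable)
Alternate constructor for chain(). Gets chained inputs from a single iterable argument that is evaluated lazily. Roughly equivalent to: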
def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
    for it in iterables:
        for element in it:
            yield element
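Because the outer iterable is evaluated lazily, from_iterable() can flatten even a never-ending stream of iterables, whereas chain(*blocks()) would first try to unpack the whole (infinite) generator. A minimal sketch with a hypothetical producer:

```python
from itertools import chain, count, islice

def blocks():
    # Hypothetical infinite producer of small lists: [0], [0, 1], [0, 1, 2], ...
    for n in count(1):
        yield list(range(n))

# from_iterable() pulls inner iterables only as they are needed,
# so it works even though blocks() never terminates.
flat = chain.from_iterable(blocks())
first_ten = list(islice(flat, 10))
print(first_ten)  # [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
```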
1.4 itertools.combinations(iterable, r)
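Return r length subsequences of elements from the input iterable.
Combinations are emitted in lexicographic sort order. So, if the input iterable is sorted, the combination tuples will be produced in sorted order.
Elements are treated as unique based on their position, not on their value. So if the input elements are unique, there will be no repeat values in each combination.
Roughly equivalent to: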
def combinations(iterable, r):
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        # find the rightmost index that can still be incremented
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        # reset every index to its right to the smallest increasing values
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
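The code for combinations() can also be expressed as a subsequence of permutations() after filtering entries where the elements are not in sorted order (according to their position in the input pool):
from itertools import permutations

def combinations(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    for indices in permutations(range(n), r):
        if sorted(indices) == list(indices):
            yield tuple(pool[i] for i in indices)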
The number of items returned is n! / r! / (n-r)! when 0 <= r <= n, or zero when r > n.
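A quick check of that count formula against combinations() on a small illustrative pool, including one r > n case:

```python
import math
from itertools import combinations

pool = 'ABCDE'
n = len(pool)
counts = []
for r in range(n + 2):  # include r = n + 1 to exercise the zero case
    produced = sum(1 for _ in combinations(pool, r))
    expected = (math.factorial(n) // math.factorial(r) // math.factorial(n - r)
                if r <= n else 0)
    assert produced == expected
    counts.append(produced)

print(counts)  # [1, 5, 10, 10, 5, 1, 0]
```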
1.5 itertools.combinations_with_replacement(iterable, r)
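New in version 3.1.
Return r length subsequences of elements from the input iterable, allowing individual elements to be repeated more than once.
Combinations are emitted in lexicographic sort order. So, if the input iterable is sorted, the combination tuples will be produced in sorted order.
Elements are treated as unique based on their position, not on their value. So if the input elements are unique, the generated combinations will also be unique.
Roughly equivalent to: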
def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC
    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r
    yield tuple(pool[i] for i in indices)
    while True:
        # find the rightmost index that has not reached the last pool position
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)
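The code for combinations_with_replacement() can also be expressed as a subsequence of product() after filtering entries where the elements are not in sorted order (according to their position in the input pool):
from itertools import product

def combinations_with_replacement(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    for indices in product(range(n), repeat=r):
        if sorted(indices) == list(indices):
            yield tuple(pool[i] for i in indices)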
The number of items returned is (n+r-1)! / r! / (n-1)! when n > 0.
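Unlike combinations(), this function still produces items when r > n, because elements may repeat. A small sketch checking the count formula on such a case (the helper cwr_count is hypothetical, written only for this check):

```python
from itertools import combinations_with_replacement
from math import factorial

def cwr_count(n, r):
    # (n+r-1)! / r! / (n-1)!, valid for n > 0
    return factorial(n + r - 1) // factorial(r) // factorial(n - 1)

pool = 'AB'                                              # n = 2
combos = list(combinations_with_replacement(pool, 3))    # r = 3 > n
print(combos)  # [('A', 'A', 'A'), ('A', 'A', 'B'), ('A', 'B', 'B'), ('B', 'B', 'B')]
assert len(combos) == cwr_count(len(pool), 3) == 4
```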
1.6 itertools.compress(data, selectors)
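New in version 3.1.
Make an iterator that filters elements from data, returning only those that have a corresponding element in selectors that evaluates to True. Stops when either the data or selectors iterables has been exhausted. Roughly equivalent to:
def compress(data, selectors):
    # compress('ABCDEF', [1, 0, 1, 0, 1, 1]) --> A C E F
    return (d for d, s in zip(data, selectors) if s)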
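A sketch with an illustrative validity mask computed separately from the data, plus a case where the selectors run out first:

```python
from itertools import compress

readings = [12.5, -1.0, 13.1, 99.9, 12.8]
# Hypothetical validity rule, evaluated into a separate selector list
valid = [0.0 <= x <= 50.0 for x in readings]
kept = list(compress(readings, valid))
print(kept)   # [12.5, 13.1, 12.8]

# Stops at the shorter input: only three selectors, so at most three data items are examined.
short = list(compress('ABCDEF', [1, 0, 1]))
print(short)  # ['A', 'C']
```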
1.7 itertools.count(start=0, step=1)
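Changed in version 3.1: Added the step argument and allowed non-integer arguments.
Make an iterator that returns evenly spaced values starting with number start. Often used as an argument to map() to generate consecutive data points. Also used with zip() to add sequence numbers. Roughly equivalent to:
def count(start=0, step=1):
    # count(10) --> 10 11 12 13 14 ...
    # count(2.5, 0.5) --> 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step
When counting with floating point numbers, better accuracy can sometimes be achieved by substituting multiplicative code such as: (start + step * i for i in count()).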
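A small illustration of that accuracy difference, using step 0.1 as an assumed example: repeated addition accumulates rounding error, while the multiplicative form rounds once per value.

```python
from itertools import count, islice

step = 0.1
# count() adds step repeatedly, so rounding errors accumulate ...
added = list(islice(count(0, step), 11))
# ... while start + step * i computes each value with a single rounding.
multiplied = [0 + step * i for i in islice(count(), 11)]

print(added[10])       # 0.9999999999999999
print(multiplied[10])  # 1.0
```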
1.8 itertools.cycle(iterable)
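Make an iterator returning elements from the iterable and saving a copy of each. When the iterable is exhausted, return elements from the saved copy. Repeats indefinitely. Roughly equivalent to:
def cycle(iterable):
    # cycle('ABCD') --> A B C D A B C D ...
    saved = []
    for element in iterable:
        yield element
        saved.append(element)
    while saved:
        for element in saved:
            yield element
Note, this member of the toolkit may require significant auxiliary storage (depending on the length of the iterable).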
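Since cycle() never ends, it is usually paired with something finite. A sketch of round-robin assignment, with hypothetical task and worker names, where zip() stops at the finite task list:

```python
from itertools import cycle

tasks = ['t1', 't2', 't3', 't4', 't5']
workers = cycle(['w1', 'w2'])   # hypothetical worker names, repeated forever
# zip() stops when tasks is exhausted, which truncates the infinite cycle.
assignment = list(zip(tasks, workers))
print(assignment)
# [('t1', 'w1'), ('t2', 'w2'), ('t3', 'w1'), ('t4', 'w2'), ('t5', 'w1')]
```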
1.9 itertools.dropwhile(predicate, iterable)
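Make an iterator that drops elements from the iterable as long as the predicate is true; afterwards, it returns every element. Note, the iterator does not produce any output until the predicate first becomes false, so it may have a lengthy start-up time. Roughly equivalent to:
def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x < 5, [1,4,6,4,1]) --> 6 4 1
    iterable = iter(iterable)
    for x in iterable:
        if not predicate(x):
            yield x          # first element that fails the predicate
            break
    for x in iterable:
        yield x              # everything after it, unfiltered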
1.10 itertools.filterfalse(predicate, iterable)
Make an iterator that filters elements from the iterable, returning only those for which the predicate is False. If predicate is None, return the items that are false. Roughly equivalent to:
def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x%2, range(10)) --> 0 2 4 6 8
    if predicate is None:
        predicate = bool
    for x in iterable:
        if not predicate(x):
            yield x
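A sketch contrasting filterfalse() with the built-in filter() on the same illustrative values, including the predicate=None case:

```python
from itertools import filterfalse

values = [0, 1, '', 'a', None, [], [2]]
falsy = list(filterfalse(None, values))   # items whose truth value is False
truthy = list(filter(None, values))       # the complementary items
print(falsy)   # [0, '', None, []]
print(truthy)  # [1, 'a', [2]]

evens = list(filterfalse(lambda x: x % 2, range(10)))
print(evens)   # [0, 2, 4, 6, 8]
```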
1.11 itertools.groupby(iterable, key=None)
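Make an iterator that returns consecutive keys and groups from the iterable. The key is a function computing a key value for each element. If not specified or is None, key defaults to an identity function and returns the element unchanged. Generally, the iterable needs to already be sorted on the same key function.
The returned group is itself an iterator that shares the underlying iterable with groupby(). Because the source is shared, when the groupby() object is advanced, the previous group is no longer visible. So, if that data is needed later, it should be stored as a list:
groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # store the group iterator as a list
    uniquekeys.append(k)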
groupby() is roughly equivalent to:
class groupby:
    # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D
    def __init__(self, iterable, key=None):
        if key is None:
            key = lambda x: x                 # identity function
        self.keyfunc = key
        self.it = iter(iterable)
        # one shared sentinel, so the first comparison in __next__ is equal
        self.tgtkey = self.currkey = self.currvalue = object()
    def __iter__(self):
        return self
    def __next__(self):
        # skip whatever is left of the current group
        while self.currkey == self.tgtkey:
            self.currvalue = next(self.it)    # StopIteration ends groupby
            self.currkey = self.keyfunc(self.currvalue)
        self.tgtkey = self.currkey
        return (self.currkey, self.__grouper(self.tgtkey))
    def __grouper(self, tgtkey):
        while self.currkey == tgtkey:
            yield self.currvalue
            try:
                self.currvalue = next(self.it)
            except StopIteration:
                return
            self.currkey = self.keyfunc(self.currvalue)
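Two common uses, sketched with illustrative data: run-length encoding, which relies on groupby() merging only adjacent equal keys, and grouping by key after sorting on that same key.

```python
from itertools import groupby

s = 'AAAABBBCCDAABBB'
# Run-length encoding: each run of adjacent equal items becomes one group.
rle = [(k, len(list(g))) for k, g in groupby(s)]
print(rle)  # [('A', 4), ('B', 3), ('C', 2), ('D', 1), ('A', 2), ('B', 3)]

# Sorting first by the same key gives exactly one group per distinct key.
first_letter = lambda w: w[0]
words = ['apple', 'bob', 'cat', 'avocado', 'banana']
by_initial = {k: list(g)
              for k, g in groupby(sorted(words, key=first_letter), key=first_letter)}
print(by_initial)  # {'a': ['apple', 'avocado'], 'b': ['bob', 'banana'], 'c': ['cat']}
```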
1.12 itertools.islice(iterable, stop) & itertools.islice(iterable, start, stop[, step])
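Make an iterator that returns selected elements from the iterable. If start is non-zero, then elements from the iterable are skipped until start is reached. Afterward, elements are returned consecutively unless step is set higher than one, which results in items being skipped. If stop is None, then iteration continues until the iterator is exhausted, if at all; otherwise, it stops at the specified position. Unlike regular slicing, islice() does not support negative values for start, stop, or step. It can be used to extract related fields from data where the internal structure has been flattened (for example, a multi-line report may list a name field on every third line). Roughly equivalent to:
import sys

def islice(iterable, *args):
    # islice('ABCDEFG', 2) --> A B
    # islice('ABCDEFG', 2, 4) --> C D
    # islice('ABCDEFG', 2, None) --> C D E F G
    # islice('ABCDEFG', 0, None, 2) --> A C E G
    s = slice(*args)
    stop = sys.maxsize if s.stop is None else s.stop   # keep an explicit stop of 0
    it = iter(range(s.start or 0, stop, s.step or 1))
    try:
        nexti = next(it)
    except StopIteration:
        return
    for i, element in enumerate(iterable):
        if i == nexti:
            yield element
            try:
                nexti = next(it)
            except StopIteration:
                return       # don't let StopIteration escape the generator
If start is None, then iteration starts at zero. If step is None, then the step defaults to one.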
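The "name field on every third line" case mentioned above, sketched with a made-up flattened report. The second call passes a one-shot iterator, where a regular slice such as report[1::3] would not be available:

```python
from itertools import islice

report = [
    'name: Alice', 'dept: R&D', 'ext: 101',
    'name: Bob',   'dept: Ops', 'ext: 102',
    'name: Carol', 'dept: QA',  'ext: 103',
]
# Assumed layout: a name field every third line, starting at line 0.
names = list(islice(report, 0, None, 3))
print(names)  # ['name: Alice', 'name: Bob', 'name: Carol']

# Also works on a plain iterator, which does not support slicing.
depts = list(islice(iter(report), 1, None, 3))
print(depts)  # ['dept: R&D', 'dept: Ops', 'dept: QA']
```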
1.13 itertools.permutations(iterable, r=None)
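Return successive r length permutations of elements in the iterable.
If r is not specified or is None, then r defaults to the length of the iterable and all possible full-length permutations are generated.
Permutations are emitted in lexicographic sort order. So, if the input iterable is sorted, the permutation tuples will be produced in sorted order.
Elements are treated as unique based on their position, not on their value. So if the input elements are unique, there will be no repeat values in each permutation.
Roughly equivalent to: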
def permutations(iterable, r=None):
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])    # first permutation: identity order
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # rotate indices[i:] left by one and reset this position's counter
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return
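The code for permutations() can also be expressed as a subsequence of product(), filtered to exclude entries with repeated elements (those from the same position in the input pool):
from itertools import product

def permutations(iterable, r=None):
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    for indices in product(range(n), repeat=r):
        if len(set(indices)) == r:
            yield tuple(pool[i] for i in indices)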
The number of items returned is n! / (n-r)! when 0 <= r <= n, or zero when r > n.
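A quick check of the permutation count on an illustrative four-element pool. Here math.perm (Python 3.8+) is used only as an independent reference value; it also returns 0 when r > n.

```python
import math
from itertools import permutations

pool = 'ABCD'
n = len(pool)
perm_counts = [sum(1 for _ in permutations(pool, r)) for r in range(n + 2)]
print(perm_counts)  # [1, 4, 12, 24, 24, 0]

# n! / (n-r)! for r <= n, and zero for r > n
assert perm_counts == [math.perm(n, r) for r in range(n + 2)]
```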
1.14 itertools.product(*iterables, repeat=1)
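Cartesian product of input iterables.
Roughly equivalent to nested for-loops in a generator expression. For example, product(A, B) returns the same as ((x, y) for x in A for y in B).
The nested loops cycle like an odometer, with the rightmost element advancing on every iteration. This pattern creates a lexicographic ordering so that if the input's iterables are sorted, the product tuples are emitted in sorted order.
To compute the product of an iterable with itself, specify the number of repetitions with the optional repeat keyword argument. For example, product(A, repeat=4) means the same as product(A, A, A, A).
This function is roughly equivalent to the following code, except that the actual implementation does not build up intermediate results in memory: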
def product(*args, repeat=1):
    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
    pools = [tuple(pool) for pool in args] * repeat
    result = [[]]
    for pool in pools:
        # extend every partial result with every element of the next pool
        result = [x+[y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)
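A short sketch checking the two equivalences stated above on illustrative inputs, plus the odometer ordering:

```python
from itertools import product

A, B = 'ab', [0, 1]
# product(A, B) matches the nested generator expression
assert list(product(A, B)) == [(x, y) for x in A for y in B]
# repeat=2 is the same as passing A twice
assert list(product(A, repeat=2)) == list(product(A, A))

# The rightmost position advances fastest, like an odometer.
odometer = [''.join(p) for p in product('01', repeat=3)]
print(odometer)  # ['000', '001', '010', '011', '100', '101', '110', '111']
```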
1.15 itertools.repeat(object[, times])
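Make an iterator that returns object over and over again. Runs indefinitely unless the times argument is specified. Used as an argument to map() for invariant parameters to the called function. Also used with zip() to create an invariant part of a tuple record.
Roughly equivalent to:
def repeat(object, times=None):
    # repeat(10, 3) --> 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object
A common use for repeat is to supply a stream of constant values to map or zip: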
>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
1.16 itertools.starmap(function, iterable)
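Make an iterator that computes the function using arguments obtained from the iterable. Used instead of map() when argument parameters are already grouped in tuples from a single iterable (the data has been "pre-zipped"). The difference between map() and starmap() parallels the distinction between function(a, b) and function(*c). Roughly equivalent to:
def starmap(function, iterable):
    # starmap(pow, [(2, 5), (3, 2), (10, 3)]) --> 32 9 1000
    for args in iterable:
        yield function(*args)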
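To make the map() versus starmap() distinction concrete, here is a sketch on the same pre-zipped pairs: starmap() unpacks each tuple, whereas map() needs the arguments regrouped into parallel iterables (done here with zip(*pairs)).

```python
from itertools import starmap

pairs = [(2, 5), (3, 2), (10, 3)]
# starmap: function(*args) for each tuple
via_starmap = list(starmap(pow, pairs))
# map: function(a, b) with a and b drawn from separate iterables
via_map = list(map(pow, *zip(*pairs)))
print(via_starmap)  # [32, 9, 1000]
assert via_starmap == via_map
```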
1.17 itertools.takewhile(predicate, iterable)
Make an iterator that returns elements from the iterable as long as the predicate is true. Roughly equivalent to:
def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1, 4, 6, 4, 1]) --> 1 4
    for x in iterable:
        if predicate(x):
            yield x
        else:
            break
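takewhile() and dropwhile() split a sequence at the first element that fails the predicate. One consequence of the code above, sketched with illustrative data: when both share a single iterator, the element that stopped takewhile() has already been consumed and is lost.

```python
from itertools import dropwhile, takewhile

data = [1, 4, 6, 4, 1]
pred = lambda x: x < 5

head = list(takewhile(pred, data))   # [1, 4]
rest = list(dropwhile(pred, data))   # [6, 4, 1]
assert head + rest == data           # two passes over a list: nothing lost

it = iter(data)
head2 = list(takewhile(pred, it))    # [1, 4]
remaining = list(it)                 # [4, 1]: the 6 was consumed by takewhile
print(head2, remaining)
```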
1.18 itertools.tee(iterable, n=2)
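Return n independent iterators from a single iterable.
The following Python code helps explain what tee does (although the actual implementation is more complex and uses only a single underlying FIFO queue). Roughly equivalent to: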
import collections

def tee(iterable, n=2):
    it = iter(iterable)
    deques = [collections.deque() for i in range(n)]
    def gen(mydeque):
        while True:
            if not mydeque:             # when the local deque is empty
                try:
                    newval = next(it)   # fetch a new value and
                except StopIteration:
                    return
                for d in deques:        # load it to all the deques
                    d.append(newval)
            yield mydeque.popleft()
    return tuple(gen(d) for d in deques)
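Once tee() has made a split, the original iterable should not be used anywhere else; otherwise, the iterable could get advanced without the tee objects being informed.
This itertool may require significant auxiliary storage (depending on how much temporary data needs to be stored). In general, if one iterator uses most or all of the data before another iterator starts, it is faster to use list() instead of tee().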
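A minimal sketch of a typical tee() use with illustrative data: one copy is advanced by a single step so that neighbouring values can be compared, and the original iterator is not touched after the split.

```python
from itertools import tee

source = iter([1, 4, 9, 16])
a, b = tee(source)          # from here on, use only a and b, never source
next(b)                     # advance b by one; a is unaffected
diffs = [y - x for x, y in zip(a, b)]
print(diffs)  # [3, 5, 7]
```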
1.19 itertools.zip_longest(*iterables, fillvalue=None)
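Make an iterator that aggregates elements from each of the iterables. If the iterables are of uneven length, missing values are filled in with fillvalue. Iteration continues until the longest iterable is exhausted. Roughly equivalent to: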
from itertools import chain, repeat

class ZipExhausted(Exception):
    pass

def zip_longest(*args, **kwds):
    # zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D-
    fillvalue = kwds.get('fillvalue')
    counter = len(args) - 1
    def sentinel():
        # runs once per input when that input is exhausted;
        # raises only when the last remaining input runs out
        nonlocal counter
        if not counter:
            raise ZipExhausted
        counter -= 1
        yield fillvalue
    fillers = repeat(fillvalue)
    iterators = [chain(it, sentinel(), fillers) for it in args]
    try:
        while iterators:
            yield tuple(map(next, iterators))
    except ZipExhausted:
        pass
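If one of the iterables is potentially infinite, then the zip_longest() function should be wrapped with something that limits the number of calls (for example islice() or takewhile()). If not specified, fillvalue defaults to None.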
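A sketch of that advice with an infinite count() as one input: zip_longest() alone would never stop, so islice() bounds the number of tuples taken.

```python
from itertools import count, islice, zip_longest

short = ['x', 'y']
rows = list(islice(zip_longest(count(), short, fillvalue='-'), 4))
print(rows)  # [(0, 'x'), (1, 'y'), (2, '-'), (3, '-')]
```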
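2 itertools Recipes
This section shows recipes for creating an extended toolset using the existing itertools as building blocks.
The extended tools offer the same high performance as the underlying toolset. The superior memory performance is kept by processing elements one at a time rather than bringing the whole iterable into memory all at once. Code volume is kept small by linking the tools together in a functional style, which helps eliminate temporary variables. High speed is retained by preferring "vectorized" building blocks over the use of for-loops and generators, which incur interpreter overhead.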
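import collections
import operator
import random
from itertools import (chain, combinations, count, cycle, filterfalse,
                       groupby, islice, repeat, starmap, tee, zip_longest)
from operator import itemgetter

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def tail(n, iterable):
    "Return an iterator over the last n items"
    # tail(3, 'ABCDEFG') --> E F G
    return iter(collections.deque(iterable, maxlen=n))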
def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        # feed the entire iterator into a zero-length deque
        collections.deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value"
    return next(islice(iterable, n, None), default)

def all_equal(iterable):
    "Returns True if all the elements are equal to each other"
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

def quantify(iterable, pred=bool):
    "Count how many times the predicate is true"
    return sum(map(pred, iterable))
def padnone(iterable):
    """Returns the sequence elements and then returns None indefinitely.
    Useful for emulating the behavior of the built-in map() function.
    """
    return chain(iterable, repeat(None))

def ncycles(iterable, n):
    "Returns the sequence elements n times"
    return chain.from_iterable(repeat(tuple(iterable), n))

def dotproduct(vec1, vec2):
    return sum(map(operator.mul, vec1, vec2))

def flatten(listOfLists):
    "Flatten one level of nesting"
    return chain.from_iterable(listOfLists)

def repeatfunc(func, times=None, *args):
    """Repeat calls to func with specified arguments.
    Example:  repeatfunc(random.random)
    """
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))
def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

def roundrobin(*iterables):
    "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
    # Recipe credited to George Sakkis
    pending = len(iterables)
    nexts = cycle(iter(it).__next__ for it in iterables)
    while pending:
        try:
            for next in nexts:
                yield next()
        except StopIteration:
            pending -= 1
            nexts = cycle(islice(nexts, pending))
def partition(pred, iterable):
    'Use a predicate to partition entries into false entries and true entries'
    # partition(is_odd, range(10)) --> 0 2 4 6 8   and  1 3 5 7 9
    t1, t2 = tee(iterable)
    return filterfalse(pred, t1), filter(pred, t2)

def powerset(iterable):
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def unique_everseen(iterable, key=None):
    "List unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') --> A B C D
    # unique_everseen('ABBCcAD', str.lower) --> A B C D
    seen = set()
    seen_add = seen.add
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen_add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen_add(k)
                yield element

def unique_justseen(iterable, key=None):
    "List unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    # unique_justseen('ABBCcAD', str.lower) --> A B C A D
    return map(next, map(itemgetter(1), groupby(iterable, key)))
def iter_except(func, exception, first=None):
    """ Call a function repeatedly until an exception is raised.
    Converts a call-until-exception interface to an iterator interface.
    Like builtins.iter(func, sentinel) but uses an exception instead
    of a sentinel to end the loop.
    Examples:
        iter_except(functools.partial(heappop, h), IndexError)   # priority queue iterator
        iter_except(d.popitem, KeyError)                         # non-blocking dict iterator
        iter_except(d.popleft, IndexError)                       # non-blocking deque iterator
        iter_except(q.get_nowait, queue.Empty)                   # loop over a producer Queue
        iter_except(s.pop, KeyError)                             # non-blocking set iterator
    """
    try:
        if first is not None:
            yield first()            # For database APIs needing an initial cast to db.first()
        while True:
            yield func()
    except exception:
        pass

def first_true(iterable, default=False, pred=None):
    """Returns the first true value in the iterable.
    If no true value is found, returns *default*
    If *pred* is not None, returns the first item
    for which pred(item) is true.
    """
    # first_true([a,b,c], x) --> a or b or c or x
    # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
    return next(filter(pred, iterable), default)
def random_product(*args, repeat=1):
    "Random selection from itertools.product(*args, **kwds)"
    pools = [tuple(pool) for pool in args] * repeat
    return tuple(random.choice(pool) for pool in pools)

def random_permutation(iterable, r=None):
    "Random selection from itertools.permutations(iterable, r)"
    pool = tuple(iterable)
    r = len(pool) if r is None else r
    return tuple(random.sample(pool, r))

def random_combination(iterable, r):
    "Random selection from itertools.combinations(iterable, r)"
    pool = tuple(iterable)
    n = len(pool)
    indices = sorted(random.sample(range(n), r))
    return tuple(pool[i] for i in indices)

def random_combination_with_replacement(iterable, r):
    "Random selection from itertools.combinations_with_replacement(iterable, r)"
    pool = tuple(iterable)
    n = len(pool)
    indices = sorted(random.randrange(n) for i in range(r))
    return tuple(pool[i] for i in indices)
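Note, many of the above recipes can be optimized by replacing global lookups with local variables defined as default values. For example, the dotproduct recipe can be written as: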
def dotproduct(vec1, vec2, sum=sum, map=map, mul=operator.mul):
    return sum(map(mul, vec1, vec2))
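The grouper recipe relies on a subtle trick: [iter(iterable)] * n builds a list of n references to one and the same iterator, so each tuple that zip_longest() assembles draws n consecutive items. The sketch below repeats that recipe so it runs on its own, then contrasts it with plain zip(), which silently drops an incomplete final chunk.

```python
from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks (same as the recipe)"
    args = [iter(iterable)] * n        # n references to ONE iterator
    return zip_longest(*args, fillvalue=fillvalue)

chunks = [''.join(c) for c in grouper('ABCDEFG', 3, 'x')]
print(chunks)  # ['ABC', 'DEF', 'Gxx']

# With plain zip() the same trick discards the incomplete tail.
it = iter('ABCDEFG')
plain = list(zip(it, it, it))
print(plain)   # [('A', 'B', 'C'), ('D', 'E', 'F')]
```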