背景
今日在工作中阅读业务代码时,发现用到了OrderedDict类,遂思考如何实现OrderedDict,此文记录思考与实现过程。
OrderedDict意思上为有序字典,何为有序字典?python中普通的dict在插入数据时,最终数据存储的顺序非插入顺序。例如:
a = dict()
a['a'] = 1
a['b'] = 2
a['c'] = 3
当我们在使用for
循环遍历时,会发现打印出来的结果和我们插入顺序不同。
当然,这和字典底层使用hash表进行存储关系密切,此文暂不深入展开字典的底层实现。
有序字典通常在需要记录数据插入顺序,且能够快速定位到元素的场景使用。例如:记账。
记账这个功能中,我们需要按照时间顺序存储每一笔消费,且能够通过日期快速定位到当天的消费记录。因此,选择有序字典进行数据存储是比较理想的实现方法。
实现过程
实现OrderedDict过程中,首先想到的即是数组,因为数组的append
操作可以保证插入顺序性,而需要使用高效查找,则需要结合字典共同实现,在字典中存储插入项在数组中的下标,这样就可以通过key
快速定位到数组中的元素。
第一版实现如下,完成通过[]
set和get值的功能:
class OrderedDict():
def __init__(self):
self.elements = []
self.index = {}
def __setitem__(self, key, value):
if key not in self.index:
self.elements.append((key, value))
self.index[key] = len(self.elements) - 1
else:
pos = self.index[key]
self.elements[pos] = (key, value)
def __getitem__(self, key):
if key not in self.index:
raise KeyError(key)
pos = self.index[key]
return self.elements[pos][1]
第二版希望实现keys()
和items()
方法,并且item()
方法返回生成器,代码如下:
class OrderedDict():
def __init__(self):
self.elements = []
self.index = {}
def __setitem__(self, key, value):
if key not in self.index:
self.elements.append((key, value))
self.index[key] = len(self.elements) - 1
else:
pos = self.index[key]
self.elements[pos] = (key, value)
def __getitem__(self, key):
if key not in self.index:
raise KeyError(key)
pos = self.index[key]
return self.elements[pos][1]
def items(self):
for element in self.elements:
yield element[0], element[1]
def keys(self):
all_key = []
for element in self.elements:
all_key.append(element[0])
return all_key
第三版希望实现通过for
循环遍历该有序字典,遂初次写入如下代码:
class OrderedDict():
def __init__(self):
self.elements = []
self.index = {}
def __setitem__(self, key, value):
if key not in self.index:
self.elements.append((key, value))
self.index[key] = len(self.elements) - 1
else:
pos = self.index[key]
self.elements[pos] = (key, value)
def __getitem__(self, key):
if key not in self.index:
raise KeyError(key)
pos = self.index[key]
return self.elements[pos][1]
def __iter__(self):
return self
def next(self):
for element in self.elements:
yield element[0]
def items(self):
for element in self.elements:
yield element[0], element[1]
def keys(self):
all_key = []
for element in self.elements:
all_key.append(element[0])
return all_key
执行后发现出现无限循环:
再次分析for
循环的执行原理:通过iter()
拿到可迭代对象,再通过next()
方法不断获取元素,直到捕获StopIteration
异常。
于是发现自己写的next()
方法每次执行都会返回生成器,而非元素,故想:需要在next()
方法外部保存一个计数器,指向当前遍历的元素位置,每次执行next()
方法,都获取到该计数器指向的元素。
思考一番,将可迭代对象的实现移到另一个OrderedDictIterator
类中,在iter()
方法返回该类对象。
完整代码:
import unittest
class OrderedDictIterator():
def __init__(self, iter_cols):
self.iter_cols = iter_cols
self.iter_start = 0
def next(self):
if self.iter_start < len(self.iter_cols):
value = self.iter_cols[self.iter_start]
self.iter_start += 1
return value
else:
raise StopIteration()
__next__ = next
class OrderedDict():
def __init__(self):
self.elements = []
self.index = {}
def __setitem__(self, key, value):
if key not in self.index:
self.elements.append((key, value))
self.index[key] = len(self.elements) - 1
else:
pos = self.index[key]
self.elements[pos] = (key, value)
def __getitem__(self, key):
if key not in self.index:
raise KeyError(key)
pos = self.index[key]
return self.elements[pos][1]
def __iter__(self):
return OrderedDictIterator(self.elements)
def items(self):
for element in self.elements:
yield element[0], element[1]
def keys(self):
all_key = []
for element in self.elements:
all_key.append(element[0])
return all_key
class TestOrderedDict(unittest.TestCase):
def test_add_element(self):
d = OrderedDict()
d['a'] = 1
d['b'] = 2
def test_get_element(self):
d = OrderedDict()
d['a'] = 1
self.assertEqual(d['a'], 1)
def test_set_element(self):
d = OrderedDict()
d['a'] = 1
d['a'] = 2
self.assertEqual(d['a'], 2)
def test_for(self):
d = OrderedDict()
d['a'] = 1
d['b'] = 2
for key, value in d.items():
print('key: {}, value: {}'.format(key, value))
def test_keys(self):
d = OrderedDict()
d['a'] = 1
self.assertEqual(d.keys(), ['a'])
def test_for2(self):
d = OrderedDict()
d['a'] = 1
for i in d:
print(i)
if __name__ == '__main__':
unittest.main()
后续
上述方法采用外部类的next()
方法实现迭代,从stackoverflow
评论中看到他人评论,可直接在__iter__()
方法中yield
出元素,实现迭代。
此处给出简单实现:
def __iter__(self):
for element in self.elements:
yield element[0]
总结
1、实现某个对象可以被for循环输出时,在python2中使用next()
,而在python3中使用__next__()
方法,如果代码需要兼容2和3,则可以在类中实现next()
方法,并使用__next__ = next
的方式完成2和3的兼容
2、当for循环的目的是为了遍历对象中的存储的集合数据类型时,最好再实现一个类相关的迭代器,而不是在当前类中实现next()
方法。
3、next()
方法在每次for循环时都会被调用,需要有计数器来控制迭代次数,否则会出现无限循环。
4、next()
方法不一定需要实现,在__iter__()
方法中通过yield返回元素也可以实现for循环。
参考
2、https://stackoverflow.com/questions/5982817/problems-using-next-method-in-python