前言
最近在做500 lines or less的翻译,进度实在缓慢,趁着病中请假先做个简单的。由于工作中重构的地方还是蛮多,就先把《重构:改善既有代码设计》的第一章示例使用python实现一遍。当作温习和加深印象,至于要不要做后面的,看心情 :)
另外不得不说的一点,这篇经典著作的作者从一开始就强调单元测试对于重构的重要性。另外我在实际编码中的一点体会是,编写可测试的代码本身并不会比重构代码容易。主要的难点体现在单元测试需要隔绝外部的影响,而我写的业务代码很多都是读取一个数据的状态,更新其他很多数据库的值,测试起来还是比较麻烦(当然总有测试的办法,还是懒)。所以最好实在一开始设计的时候就做好单元测试,可以确保安全的修改。
需求
一个简单的程序,计算顾客的消费金额并打印详单。操作者告诉程序:顾客租了哪些影片,程序根据租赁时间的长短和影片类型计算费用。影片分为三类:普通片/儿童片/新片。除计算费用,还要对于常客计算积分,积分会根据租赁的影片是否是新片而不同。
起始代码
在我们的最初的代码中,所有的计算工作都在Customer.statement
中完成。
其时序图如下:
详细代码如下
class Movie:
""" 影片类,一个单纯的数据类
:param title: string, 影片标题
:param price_code: int, 影片的计价类型
"""
CHILDRENS = 2 # 儿童片
REGULAR = 0 # 普通片
NEW_REALEASE = 1 # 新片
def __init__(self, title, price_code):
self.title = title
self.price_code = price_code
def get_price_code(self):
return self.price_code
def set_price_code(self, arg):
self.price_code = arg
def get_title(self):
return self.title
class Rental:
""" 租赁类,表示某个顾客租了一部影片
:param movie: object, Movie类
:days_rented: int, 租期
"""
def __init__(self, movie, days_rented):
self.movie = movie
self.days_rented = days_rented
def get_days_rented(self):
return self.days_rented
def get_movie(self):
return self.movie
class Customer:
""" 顾客类,存放每个顾客租赁信息的列表
:param name: 顾客姓名
"""
def __init__(self, name):
self.name = name
self.rentals = []
def add_rental(self, rental):
""" 增加一条租赁信息
:param rental: object, Rental类
"""
self.rentals.append(rental)
def get_name(self):
return self.name
def statement(self):
""" 生成详单的函数 """
total_amount = 0.0
frequent_renter_points = 0
result = "Rental Record for " + self.get_name() + "\n"
for retal in self.rentals:
this_amount = 0
# 计算每部影片的价格
if retal.get_movie().get_price_code() == Movie.REGULAR:
this_amount += 2
if retal.get_days_rented() > 2:
this_amount += (retal.get_days_rented() - 2) * 1.5
elif retal.get_movie().get_price_code() == Movie.NEW_REALEASE:
this_amount += retal.get_days_rented() * 3
elif retal.get_movie().get_price_code() == Movie.CHILDRENS:
this_amount += 1.5
if retal.get_days_rented() > 3:
this_amount += (retal.get_days_rented() - 3) * 1.5
# 计算常客积分
frequent_renter_points += 1
# 新片租赁两天以上会有额外的积分奖励
if retal.get_movie().get_price_code() == Movie.NEW_REALEASE and \
retal.get_days_rented() > 1:
frequent_renter_points += 1
# 展示一条影片租赁的详情
result += "\t" + retal.get_movie().get_title() + "\t" \
+ str(this_amount) + "\n"
total_amount += this_amount
# 汇总信息
result += "Amount owned is " + str(total_amount) + "\n"
result += "You earned " + str(frequent_renter_points) \
+ " frequent renter points"
return result
以下是人工测试代码,确保修改不迷路。
def test():
c_movie = Movie("CHILDRENS", Movie.CHILDRENS)
r_movie = Movie("REGULAR", Movie.REGULAR)
n_movie = Movie("NEW_REALEASE", Movie.NEW_REALEASE)
c_rental = Rental(c_movie, 20)
r_rental = Rental(r_movie, 20)
n_rental = Rental(n_movie, 20)
customer = Customer("CUSTOMER")
customer.add_rental(c_rental)
customer.add_rental(r_rental)
customer.add_rental(n_rental)
result = customer.statement()
print(result)
if __name__ == "__main__":
test()
这个statement
函数的问题总结如下:
- 一眼看去,就很难找到修改点,从而带来维护的困难。
- 第一点变化,需要用html的方式来打印结果,就不得不重新编写一个
html_statement
函数。 - 第二点变化,计费标准发生了变化如何处理,就不得不同时修改
statement
和html_statement
两个函数。 - 第三点变化,影片的分类规则发生了变化,暂时并未确定会如何修改,但是肯定会改变消费的方案和场客积的计算方式。
分解并重组statement()
找到代码的逻辑泥团并分离方法(Extract Method)。例如本例中的if...elif...语句。
使用一个独立的方法来计算值每部影片的租金。
重构前的方法如下:
for retal in self.rentals:
this_amount = 0
# 计算每部影片的价格
if retal.get_movie().get_price_code() == Movie.REGULAR:
this_amount += 2
if retal.get_days_rented() > 2:
this_amount += (retal.get_days_rented() - 2) * 1.5
elif retal.get_movie().get_price_code() == Movie.NEW_REALEASE:
this_amount += retal.get_days_rented() * 3
elif retal.get_movie().get_price_code() == Movie.CHILDRENS:
this_amount += 1.5
if retal.get_days_rented() > 3:
this_amount += (retal.get_days_rented() - 3) * 1.5
# 计算常客积分
frequent_renter_points += 1
方法分离后如下:
for retal in self.rentals:
this_amount = self._amount_for(retal)
# 计算常客积分
frequent_renter_points += 1
def _amount_for(self, retal):
""" 计算一次租赁的金额
:param retal: object, 租赁对象Rental
"""
this_amount = 0
# 计算每部影片的价格
if retal.get_movie().get_price_code() == Movie.REGULAR:
this_amount += 2
if retal.get_days_rented() > 2:
this_amount += (retal.get_days_rented() - 2) * 1.5
elif retal.get_movie().get_price_code() == Movie.NEW_REALEASE:
this_amount += retal.get_days_rented() * 3
elif retal.get_movie().get_price_code() == Movie.CHILDRENS:
this_amount += 1.5
if retal.get_days_rented() > 3:
this_amount += (retal.get_days_rented() - 3) * 1.5
return this_amount
另外,作者觉得_amount_for
的变量名称不能有效表达意图,另外由于python中没有switch
,我又加了一个中间变量。这么做的主要原因是,好的代码应该清晰表达自己的功能。重构后如下:
def _amount_for(self, retal):
""" 计算一次租赁的金额
:param retal: object, 租赁对象Rental
"""
result = 0
price_code = retal.get_movie().get_price_code()
# 计算每部影片的价格
if price_code == Movie.REGULAR:
result += 2
if retal.get_days_rented() > 2:
result += (retal.get_days_rented() - 2) * 1.5
elif price_code == Movie.NEW_REALEASE:
result += retal.get_days_rented() * 3
elif price_code == Movie.CHILDRENS:
result += 1.5
if retal.get_days_rented() > 3:
result += (retal.get_days_rented() - 3) * 1.5
return result
搬移“金额计算”代码
仔细观察amount_for
函数,发现函数使用了来自的Rental
的信息,却没有使用来自Customer
的信息。
一般来说,函数应该放在使用的数据的对象中,所以_amount_for
方法应该搬移到Rental
类中去:
class Rental...
def get_charge(self):
""" 计算一次租赁的金额 """
result = 0
price_code = self.get_movie().get_price_code()
# 计算每部影片的价格
if price_code == Movie.REGULAR:
result += 2
if self.get_days_rented() > 2:
result += (self.get_days_rented() - 2) * 1.5
elif price_code == Movie.NEW_REALEASE:
result += self.get_days_rented() * 3
elif price_code == Movie.CHILDRENS:
result += 1.5
if self.get_days_rented() > 3:
result += (self.get_days_rented() - 3) * 1.5
return result
并将_amount_for
更改成一个简单的传值函数:
def _amount_for(self, retal):
""" 计算一次租赁的金额
:param retal: object, 租赁对象Rental
"""
return retal.get_charge()
通过测试后,可以移除这个简单传值函数,在调用端做如下修改:
for retal in self.rentals:
this_amount = retal.get_charge()
# 计算常客积分
frequent_renter_points += 1
下一件事情,目前this_amount这个变量变得有点多余了,所以我们可以使用查询替代临时变量(Replace Temp with Query)的方法来改变取值的方式。
替代临时变量的原因如下:临时变量往往引发问题,它们会导致大量的参数被传来传去,而其实完全没有这种必要。你很容易跟丢它们,尤其是在长长的函数中更是如此。但是这导致了一个问题,就是费用计算了两次,但是代码如果有合理的组织和管理就会有很好的效果(主要的思想是不要过早进行优化,对于要求实时响应的,应该小心重构,对于一般的程序直到真的遇到了性能问题再说,重构有助于你快速找到性能问题所在,而且一般来说,重构并非性能问题的瓶颈)。
提炼“常客积分计算”代码
下一步是对常客积分做类似的处理。积分视影片种类而定,有理由将其放在Rental
类中。
重构前代码如下:
for retal in self.rentals:
# 计算常客积分
frequent_renter_points += 1
# 新片租赁两天以上会有额外的积分奖励
if retal.get_movie().get_price_code() == Movie.NEW_REALEASE and \
retal.get_days_rented() > 1:
frequent_renter_points += 1
先对Rental
类进行重构:
class Rental...
def get_frequent_render_point(self):
""" 计算常客积分 """
# 新片租赁两天及以上会有额外的积分奖励
if self.get_movie().get_price_code() == Movie.NEW_REALEASE and \
self.get_days_rented() > 1:
return 2
return 1
然后更改引用点并测试:
for retal in self.rentals:
# 计算常客积分
frequent_renter_points += retal.get_frequent_render_point()
这时我们来看时序图,发现Customer
类不再需要和Movie
类进行交互了。
去除临时变量
如前所述,临时变量可能会是个问题。它们只在自己所属的函数内有效,会助长冗长复杂的函数。所以我们可以换个方式来计算总量。
重构的时候我们在Customer
类中加了两个方法,用来获取总量相关的信息。
class Customer...
def get_total_charge(self):
""" 计算顾客消费总金额
:rtype: float
"""
result = 0.0
for rental in self.rentals:
result += rental.get_charge()
return result
def get_total_frequent_renter_point(self):
""" 计算常客积分的总量
:rtype: int
"""
result = 0
for rental in self.rentals:
result += rental.get_frequent_renter_point()
return result
调用端改成如下的形式:
result += "Amount owned is " + str(self.get_total_charge()) + "\n"
result += "You earned " + str(self.get_total_frequent_renter_point()) \
+ " frequent renter points"
经过这次重构,类的时序图变成了如下形式:
现在,如果我们需要重新来定义一个html版本的输出,那么重新定义一个html_statement
会比刚才容易很多,而且经过有效的分离,如果租赁的规则发生改变,只需要在Rental
类中更新相关的操作。
更加详细的重构可以将表头,表尾和详单的详情的代码都提炼出来。
于是,我们还剩下最后一个问题,如果影片的分类规则,发生改变,我们应该如何适应。所以重构继续。
运用多态取代与价格相关的条件逻辑
(注:python并不存在多态,而是鸭子类型)
最好不要再另一个对象的属性上运用switch语句。如果不得不使用,也应该再对象自己的数据上使用,而不是在别人的数据上使用。
这暗示着我们应该将这段代码搬运到Movie
类中去。
class Movie...
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
result = 0.0
price_code = self.get_price_code()
# 计算每部影片的价格
if price_code == Movie.REGULAR:
result += 2
if days_rented > 2:
result += (days_rented - 2) * 1.5
elif price_code == Movie.NEW_REALEASE:
result += days_rented * 3
elif price_code == Movie.CHILDRENS:
result += 1.5
if days_rented > 3:
result += (days_rented - 3) * 1.5
return result
def get_frequent_renter_point(self, days_rented):
""" 计算租赁一部影片多少天的常客积分
:param days_rented: int, 租赁天数
:rtype: int
"""
# 新片租赁两天及以上会有额外的积分奖励
if self.get_price_code() == Movie.NEW_REALEASE and \
days_rented > 1:
return 2
return 1
Rental
类中保留对于Movie的引用
class Rental:
def get_charge(self):
return self.get_movie().get_charge(self.get_days_rented())
def get_frequent_renter_point(self):
""" 计算常客积分 """
return self.get_movie().get_frequent_renter_point(self.get_days_rented())
终于...我们来到了继承
我们需要回顾一下我们的最后一项变化点:即在影片的生命周期内,如果影片的分类规则发生了变化,我们的代码如何适应这样的变化。
为此引入了间接层,我们加入了一个Price
对象进行子类化的处理。这可能是一个状态模式和策略模式都可以适应这样的需求。
第一个我们会用到的重构方式是将状态码用状态/策略模式替代(Replace Type Code with State/Strategy)。
需要在构造函数中设置函数来访问价格代码。
class Movie...
def __init__(self, title, price_code):
self.title = title
self.set_price_code(price_code)
接着加入一个新的类Price
,这是我们的抽象类,并在子类中加入对应的函数。
class Price:
def get_price_code(self):
""" 抽象函数,获取价格码
:rtype: int
"""
raise NotImplementedError
class ChildrensPrice(Price):
def get_price_code(self):
return Movie.CHILDRENS
class NewReleasePrice(Price):
def get_price_code(self):
return Movie.NEW_REALEASE
class RegularPrice(Price):
def get_price_code(self):
return Movie.REGULAR
接着我们采用搬运方法(Move Method)的方法将get_charge
从Movie
搬运到Price
中去(还记得前面的尽量不要switch别人加的数据吗?)
class Movie...
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
return self.price.get_charge(days_rented)
class Price...
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
result = 0.0
price_code = self.get_price_code()
# 计算每部影片的价格
if price_code == Movie.REGULAR:
result += 2
if days_rented > 2:
result += (days_rented - 2) * 1.5
elif price_code == Movie.NEW_REALEASE:
result += days_rented * 3
elif price_code == Movie.CHILDRENS:
result += 1.5
if days_rented > 3:
result += (days_rented - 3) * 1.5
return result
搬运完了以后,可以使用用多态替代条件变量(Replace Condition with Polymorphism)来进行重构。
我们的做法是依此取出case分支,并在各个价格状态(State模式)中建立相应的取值函数。
class ChildrensPrice...
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
result = 1.5
if days_rented > 3:
result += (days_rented - 3) * 1.5
return result
class NewReleasePrice...
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
return days_rented * 3
class RegularPrice...
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
result = 2
if days_rented > 2:
result += (days_rented - 2) * 1.5
return result
最后我们需要将Price
抽象基类(仅提供接口的类中的代码去掉)。
class Price...
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
raise NotImplementedError
当然,由于python是鸭子类型,我们甚至可以让三个策略类不继承Price并去掉这个类。
也能正常工作。
最后对于get_frequent_renter_point
采取同样的方式。
class Movie...
def get_frequent_renter_point(self, days_rented):
""" 计算租赁一部影片多少天的常客积分
:param days_rented: int, 租赁天数
:rtype: int
"""
return self.price.get_frequent_renter_point(days_rented)
class Price...
""" 基类提供默认行为 """
def get_frequent_renter_point(self, days_rented):
""" 默认租一次加一分 """
return 1
class NewReleasePrice...
def get_frequent_renter_point(self, days_rented):
""" 新片超过一天加两分,覆盖了父类的方法 """
if days_rented > 1:
return 2
return 1
经过最后这次重构,改变影片的分类规则,改变费用计算的规则,改变常客积分的规则,都只需要在相应的类进行更新即可。
以上。
附上最终代码(改日勘误):
class Movie:
""" 影片类,一个单纯的数据类
:param title: string, 影片标题
:param price_code: int, 影片的计价类型
"""
CHILDRENS = 2 # 儿童片
REGULAR = 0 # 普通片
NEW_REALEASE = 1 # 新片
def __init__(self, title, price_code):
self.title = title
self.set_price_code(price_code)
def get_price_code(self):
return self.price.get_price_code()
def set_price_code(self, arg):
if arg == self.REGULAR:
self.price = RegularPrice()
elif arg == self.NEW_REALEASE:
self.price = NewReleasePrice()
elif arg == self.CHILDRENS:
self.price = ChildrensPrice()
else:
raise AttributeError("Incorrect Price Code")
def get_title(self):
return self.title
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
return self.price.get_charge(days_rented)
def get_frequent_renter_point(self, days_rented):
""" 计算租赁一部影片多少天的常客积分
:param days_rented: int, 租赁天数
:rtype: int
"""
return self.price.get_frequent_renter_point(days_rented)
class Price:
""" 抽象类,由于鸭子类型的关系,完全可以让子类不继承这个类 """
def get_price_code(self):
""" 抽象函数,获取价格码
:rtype: int
"""
raise NotImplementedError
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
raise NotImplementedError
def get_frequent_renter_point(self, days_rented):
""" 默认租一次加一分 """
return 1
class ChildrensPrice(Price):
def get_price_code(self):
return Movie.CHILDRENS
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
result = 1.5
if days_rented > 3:
result += (days_rented - 3) * 1.5
return result
class NewReleasePrice(Price):
def get_price_code(self):
return Movie.NEW_REALEASE
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
return days_rented * 3
def get_frequent_renter_point(self, days_rented):
""" 新片超过一天加两分,覆盖了父类的方法 """
if days_rented > 1:
return 2
return 1
class RegularPrice(Price):
def get_price_code(self):
return Movie.REGULAR
def get_charge(self, days_rented):
""" 计算租赁一部影片多少天的价格
:param days_rented: int, 租赁天数
:rtype: float
"""
result = 2
if days_rented > 2:
result += (days_rented - 2) * 1.5
return result
class Rental:
""" 租赁类,表示某个顾客租了一部影片
:param movie: object, Movie类
:days_rented: int, 租期
"""
def __init__(self, movie, days_rented):
self.movie = movie
self.days_rented = days_rented
def get_days_rented(self):
return self.days_rented
def get_movie(self):
return self.movie
def get_charge(self):
return self.get_movie().get_charge(self.get_days_rented())
def get_frequent_renter_point(self):
""" 计算常客积分 """
return self.get_movie().get_frequent_renter_point(self.get_days_rented())
class Customer:
""" 顾客类,存放每个顾客租赁信息的列表
:param name: 顾客姓名
"""
def __init__(self, name):
self.name = name
self.rentals = []
def add_rental(self, rental):
""" 增加一条租赁信息
:param rental: object, Rental类
"""
self.rentals.append(rental)
def get_name(self):
return self.name
def statement(self):
""" 生成详单的函数
rtype: string
"""
frequent_renter_points = 0
result = "Rental Record for " + self.get_name() + "\n"
for retal in self.rentals:
# 展示一条影片租赁的详情
result += "\t" + retal.get_movie().get_title() + "\t" \
+ str(retal.get_charge()) + "\n"
# 汇总信息
result += "Amount owned is " + str(self.get_total_charge()) + "\n"
result += "You earned " + str(self.get_total_frequent_renter_point()) \
+ " frequent renter points"
return result
def get_total_charge(self):
""" 计算顾客消费总金额
:rtype: float
"""
result = 0.0
for rental in self.rentals:
result += rental.get_charge()
return result
def get_total_frequent_renter_point(self):
""" 计算常客积分的总量
:rtype: int
"""
result = 0
for rental in self.rentals:
result += rental.get_frequent_renter_point()
return result
def test():
c_movie = Movie("CHILDRENS", Movie.CHILDRENS)
r_movie = Movie("REGULAR", Movie.REGULAR)
n_movie = Movie("NEW_REALEASE", Movie.NEW_REALEASE)
c_rental = Rental(c_movie, 20)
r_rental = Rental(r_movie, 20)
n_rental = Rental(n_movie, 20)
customer = Customer("CUSTOMER")
customer.add_rental(c_rental)
customer.add_rental(r_rental)
customer.add_rental(n_rental)
result = customer.statement()
print(result)
if __name__ == "__main__":
test()