通过继承实现可扩展程序

Beginner

This tutorial is from open-source community. Access the source code

简介

继承是编写可扩展程序时常用的工具。本节将探讨这一概念。

继承

继承用于特化现有对象:

class Parent:
  ...

class Child(Parent):
  ...

新类 Child 被称为派生类或子类。Parent 类被称为基类或超类。在类名 class Child(Parent): 后的 () 中指定 Parent

扩展

通过继承,你可以基于一个现有的类并:

  • 添加新方法
  • 重新定义一些现有方法
  • 为实例添加新属性

最终,你是在 扩展现有代码

示例

假设这是你的起始类:

class Stock:
    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

    def cost(self):
        return self.shares * self.price

    def sell(self, nshares):
        self.shares -= nshares

你可以通过继承来更改其中的任何部分。

添加一个新方法

class MyStock(Stock):
    def panic(self):
        self.sell(self.shares)

使用示例。

>>> s = MyStock('GOOG', 100, 490.1)
>>> s.sell(25)
>>> s.shares
75
>>> s.panic()
>>> s.shares
0
>>>

重新定义现有方法

class MyStock(Stock):
    def cost(self):
        return 1.25 * self.shares * self.price

使用示例。

>>> s = MyStock('GOOG', 100, 490.1)
>>> s.cost()
61262.5
>>>

新方法取代了旧方法。其他方法不受影响。这太棒了。

重写

有时一个类扩展了一个现有方法,但它希望在重新定义时使用原始实现。为此,请使用 super()

class Stock:
  ...
    def cost(self):
        return self.shares * self.price
  ...

class MyStock(Stock):
    def cost(self):
        ## 检查对 `super` 的调用
        actual_cost = super().cost()
        return 1.25 * actual_cost

使用 super() 来调用上一个版本。

注意:在 Python 2 中,语法更冗长。

actual_cost = super(MyStock, self).cost()

__init__ 与继承

如果重新定义了 __init__,那么初始化父类是至关重要的。

class Stock:
    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

class MyStock(Stock):
    def __init__(self, name, shares, price, factor):
        ## 检查对 `super` 和 `__init__` 的调用
        super().__init__(name, shares, price)
        self.factor = factor

    def cost(self):
        return self.factor * super().cost()

你应该在 super 上调用 __init__() 方法,这是如前所示调用上一个版本的方式。

使用继承

继承有时用于组织相关的对象。

class Shape:
 ...

class Circle(Shape):
 ...

class Rectangle(Shape):
 ...

思考一个逻辑层次结构或分类法。然而,更常见(且实用)的用途与创建可复用或可扩展的代码有关。例如,一个框架可能定义一个基类,并指示你对其进行定制。

class CustomHandler(TCPHandler):
    def handle_request(self):
     ...
        ## 自定义处理

基类包含一些通用代码。你的类继承并定制特定部分。

“是一个”关系

继承建立了一种类型关系。

class Shape:
  ...

class Circle(Shape):
  ...

检查对象实例。

>>> c = Circle(4.0)
>>> isinstance(c, Shape)
True
>>>

重要提示:理想情况下,任何适用于父类实例的代码也将适用于子类实例。

object 基类

如果一个类没有父类,你有时会看到使用 object 作为基类。

class Shape(object):
...

object 是 Python 中所有对象的父类。

*注意:从技术上讲这不是必需的,但你经常会看到它被指定,这是从 Python 2 中它的必需使用方式延续下来的。如果省略,该类仍然会隐式地从 object 继承。

多重继承

你可以通过在类的定义中指定多个类来实现多重继承。

class Mother:
 ...

class Father:
 ...

class Child(Mother, Father):
 ...

Child 继承了父母双方的特性。这里有一些相当棘手的细节。除非你清楚自己在做什么,否则不要这样做。下一节会给出一些更多信息,但在本课程中我们不会进一步使用多重继承。

继承的一个主要用途是编写旨在以各种方式进行扩展或定制的代码 —— 特别是在库或框架中。为了说明这一点,看看你 report.py 程序中的 print_report() 函数。它应该如下所示:

def print_report(reportdata):
    '''
    从 (name, shares, price, change) 元组列表中打印一个格式良好的表格。
    '''
    headers = ('Name','Shares','Price','Change')
    print('%10s %10s %10s %10s' % headers)
    print(('-'*10 + ' ')*len(headers))
    for row in reportdata:
        print('%10s %10d %10.2f %10.2f' % row)

当你运行你的报告程序时,你应该会得到如下输出:

>>> import report
>>> report.portfolio_report('portfolio.csv', 'prices.csv')
      Name     Shares      Price     Change
---------- ---------- ---------- ----------
        AA        100       9.22     -22.98
       IBM         50     106.28      15.18
       CAT        150      35.46     -47.98
      MSFT        200      20.89     -30.34
        GE         95      13.48     -26.89
      MSFT         50      20.89     -44.21
       IBM        100     106.28      35.84

练习 4.5:一个可扩展性问题

假设你想要修改 print_report() 函数,以支持多种不同的输出格式,如纯文本、HTML、CSV 或 XML。要做到这一点,你可能会尝试编写一个庞大的函数来处理所有事情。然而,这样做很可能会导致代码难以维护,变得一团糟。相反,这是一个使用继承的绝佳机会。

首先,关注创建表格所涉及的步骤。表格顶部是一组表头。之后是表格数据行。让我们把这些步骤提取出来,放到它们自己的类中。创建一个名为 tableformat.py 的文件,并定义以下类:

## tableformat.py

class TableFormatter:
    def headings(self, headers):
        '''
        输出表格表头。
        '''
        raise NotImplementedError()

    def row(self, rowdata):
        '''
        输出一行表格数据。
        '''
        raise NotImplementedError()

这个类什么也不做,但它为即将定义的其他类提供了一种设计规范。这样的类有时被称为“抽象基类”。

修改 print_report() 函数,使其接受一个 TableFormatter 对象作为输入,并调用它的方法来生成输出。例如,如下所示:

## report.py
...

def print_report(reportdata, formatter):
    '''
    从 (name, shares, price, change) 元组列表中打印一个格式良好的表格。
    '''
    formatter.headings(['Name','Shares','Price','Change'])
    for name, shares, price, change in reportdata:
        rowdata = [ name, str(shares), f'{price:0.2f}', f'{change:0.2f}' ]
        formatter.row(rowdata)

由于你在 print_report() 中添加了一个参数,你还需要修改 portfolio_report() 函数。将其修改为如下所示,以便它创建一个 TableFormatter

## report.py

import tableformat

...
def portfolio_report(portfoliofile, pricefile):
    '''
    根据投资组合和价格数据文件生成股票报告。
    '''
    ## 读取数据文件
    portfolio = read_portfolio(portfoliofile)
    prices = read_prices(pricefile)

    ## 创建报告数据
    report = make_report_data(portfolio, prices)

    ## 打印输出
    formatter = tableformat.TableFormatter()
    print_report(report, formatter)

运行这段新代码:

>>> ================================ RESTART ================================
>>> import report
>>> report.portfolio_report('portfolio.csv', 'prices.csv')
... 程序崩溃...

它应该会立即因 NotImplementedError 异常而崩溃。这并不太令人兴奋,但这正是我们所期望的。继续下一部分。

练习 4.6:使用继承生成不同输出

你在(a)部分定义的 TableFormatter 类旨在通过继承进行扩展。实际上,这就是整个思路。为了说明这一点,定义一个如下的 TextTableFormatter 类:

## tableformat.py
...
class TextTableFormatter(TableFormatter):
    '''
    以纯文本格式输出表格
    '''
    def headings(self, headers):
        for h in headers:
            print(f'{h:>10s}', end=' ')
        print()
        print(('-'*10 + ' ')*len(headers))

    def row(self, rowdata):
        for d in rowdata:
            print(f'{d:>10s}', end=' ')
        print()

portfolio_report() 函数修改如下并进行尝试:

## report.py
...
def portfolio_report(portfoliofile, pricefile):
    '''
    根据投资组合和价格数据文件生成股票报告。
    '''
    ## 读取数据文件
    portfolio = read_portfolio(portfoliofile)
    prices = read_prices(pricefile)

    ## 创建报告数据
    report = make_report_data(portfolio, prices)

    ## 打印输出
    formatter = tableformat.TextTableFormatter()
    print_report(report, formatter)

这应该会产生与之前相同的输出:

>>> ================================ RESTART ================================
>>> import report
>>> report.portfolio_report('portfolio.csv', 'prices.csv')
      Name     Shares      Price     Change
---------- ---------- ---------- ----------
        AA        100       9.22     -22.98
       IBM         50     106.28      15.18
       CAT        150      35.46     -47.98
      MSFT        200      20.89     -30.34
        GE         95      13.48     -26.89
      MSFT         50      20.89     -44.21
       IBM        100     106.28      35.84
>>>

然而,让我们将输出改为其他格式。定义一个新的 CSVTableFormatter 类,以 CSV 格式输出:

## tableformat.py
...
class CSVTableFormatter(TableFormatter):
    '''
    以 CSV 格式输出投资组合数据。
    '''
    def headings(self, headers):
        print(','.join(headers))

    def row(self, rowdata):
        print(','.join(rowdata))

将你的主程序修改如下:

def portfolio_report(portfoliofile, pricefile):
    '''
    根据投资组合和价格数据文件生成股票报告。
    '''
    ## 读取数据文件
    portfolio = read_portfolio(portfoliofile)
    prices = read_prices(pricefile)

    ## 创建报告数据
    report = make_report_data(portfolio, prices)

    ## 打印输出
    formatter = tableformat.CSVTableFormatter()
    print_report(report, formatter)

现在你应该会看到如下的 CSV 输出:

>>> ================================ RESTART ================================
>>> import report
>>> report.portfolio_report('portfolio.csv', 'prices.csv')
Name,Shares,Price,Change
AA,100,9.22,-22.98
IBM,50,106.28,15.18
CAT,150,35.46,-47.98
MSFT,200,20.89,-30.34
GE,95,13.48,-26.89
MSFT,50,20.89,-44.21
IBM,100,106.28,35.84

使用类似的思路,定义一个 HTMLTableFormatter 类,生成一个具有以下输出的表格:

<tr><th>Name</th><th>Shares</th><th>Price</th><th>Change</th></tr>
<tr><td>AA</td><td>100</td><td>9.22</td><td>-22.98</td></tr>
<tr><td>IBM</td><td>50</td><td>106.28</td><td>15.18</td></tr>
<tr><td>CAT</td><td>150</td><td>35.46</td><td>-47.98</td></tr>
<tr><td>MSFT</td><td>200</td><td>20.89</td><td>-30.34</td></tr>
<tr><td>GE</td><td>95</td><td>13.48</td><td>-26.89</td></tr>
<tr><td>MSFT</td><td>50</td><td>20.89</td><td>-44.21</td></tr>
<tr><td>IBM</td><td>100</td><td>106.28</td><td>35.84</td></tr>

通过修改主程序以创建一个 HTMLTableFormatter 对象而不是 CSVTableFormatter 对象来测试你的代码。

练习 4.7:多态性的实际应用

面向对象编程的一个主要特性是,你可以将一个对象插入到程序中,它就能正常工作,而无需更改任何现有代码。例如,如果你编写了一个期望使用 TableFormatter 对象的程序,无论你实际提供的是哪种 TableFormatter,它都能正常运行。这种行为有时被称为“多态性”。

一个潜在的问题是弄清楚如何允许用户选择他们想要的格式化器。直接使用类名,如 TextTableFormatter,通常很麻烦。因此,你可能会考虑一些简化的方法。也许你可以在代码中嵌入一个 if 语句,如下所示:

def portfolio_report(portfoliofile, pricefile, fmt='txt'):
    '''
    根据投资组合和价格数据文件生成股票报告。
    '''
    ## 读取数据文件
    portfolio = read_portfolio(portfoliofile)
    prices = read_prices(pricefile)

    ## 创建报告数据
    report = make_report_data(portfolio, prices)

    ## 打印输出
    if fmt == 'txt':
        formatter = tableformat.TextTableFormatter()
    elif fmt == 'csv':
        formatter = tableformat.CSVTableFormatter()
    elif fmt == 'html':
        formatter = tableformat.HTMLTableFormatter()
    else:
        raise RuntimeError(f'未知格式 {fmt}')
    print_report(report, formatter)

在这段代码中,用户指定一个简化的名称,如 'txt''csv' 来选择一种格式。然而,像那样在 portfolio_report() 函数中放置一个大型 if 语句是最好的主意吗?也许将那段代码移到其他地方的一个通用函数中会更好。

tableformat.py 文件中,添加一个函数 create_formatter(name),它允许用户根据输出名称,如 'txt''csv''html' 创建一个格式化器。修改 portfolio_report(),使其如下所示:

def portfolio_report(portfoliofile, pricefile, fmt='txt'):
    '''
    根据投资组合和价格数据文件生成股票报告。
    '''
    ## 读取数据文件
    portfolio = read_portfolio(portfoliofile)
    prices = read_prices(pricefile)

    ## 创建报告数据
    report = make_report_data(portfolio, prices)

    ## 打印输出
    formatter = tableformat.create_formatter(fmt)
    print_report(report, formatter)

尝试使用不同的格式调用该函数,以确保它能正常工作。

练习 4.8:整合所有内容

修改 report.py 程序,使 portfolio_report() 函数接受一个可选参数来指定输出格式。例如:

>>> report.portfolio_report('portfolio.csv', 'prices.csv', 'txt')
      Name     Shares      Price     Change
---------- ---------- ---------- ----------
        AA        100       9.22     -22.98
       IBM         50     106.28      15.18
       CAT        150      35.46     -47.98
      MSFT        200      20.89     -30.34
        GE         95      13.48     -26.89
      MSFT         50      20.89     -44.21
       IBM        100     106.28      35.84
>>>

修改主程序,以便可以在命令行中给出格式:

$ python3 report.py portfolio.csv prices.csv csv
Name,Shares,Price,Change
AA,100,9.22,-22.98
IBM,50,106.28,15.18
CAT,150,35.46,-47.98
MSFT,200,20.89,-30.34
GE,95,13.48,-26.89
MSFT,50,20.89,-44.21
IBM,100,106.28,35.84
$

讨论

编写可扩展代码是库和框架中继承的最常见用途之一。例如,一个框架可能会指示你定义自己的对象,该对象继承自提供的基类。然后,你需要填充各种方法来实现各种功能。

另一个稍微深入一点的概念是“拥有自己的抽象”。在练习中,我们为格式化表格定义了我们自己的类。你可能看着自己的代码并告诉自己:“我应该只使用一个格式化库,或者其他人已经编写好的东西!”不,你应该同时使用你自己的类和一个库。使用你自己的类可以促进松耦合,并且更灵活。只要你的应用程序使用你类的编程接口,你就可以以任何你想要的方式更改内部实现。你可以编写全自定义代码。你可以使用别人的第三方包。当你找到更好的包时,你可以将一个第三方包换成另一个。这都没关系——只要你保留接口,你的应用程序代码就不会中断。这是一个强大的概念,也是你可能考虑为此类事情使用继承的原因之一。

话虽如此,设计面向对象程序可能极其困难。如需更多信息,你可能应该查找有关设计模式主题的书籍(尽管理解本练习中发生的事情将使你在以实际有用的方式使用对象方面走得很远)。

总结

恭喜你!你已经完成了实验继承。你可以在 LabEx 中练习更多实验来提高你的技能。