迭代过程基础

Beginner

This tutorial is from open-source community. Access the source code

简介

本节将探讨迭代的底层过程。

随处可见的迭代

许多不同的对象都支持迭代。

a = 'hello'
for c in a: ## 遍历字符串 a 中的字符
  ...

b = { 'name': 'Dave', 'password':'foo'}
for k in b: ## 遍历字典中的键
  ...

c = [1,2,3,4]
for i in c: ## 遍历列表/元组中的元素
  ...

f = open('foo.txt')
for x in f: ## 遍历文件中的行
  ...

迭代:协议

考虑 for 语句。

for x in obj:
    ## 语句

在底层发生了什么?

_iter = obj.__iter__()        ## 获取迭代器对象
while True:
    try:
        x = _iter.__next__()  ## 获取下一个元素
        ## 语句...
    except StopIteration:     ## 没有更多元素
        break

所有与 for 循环配合使用的对象都实现了这个底层迭代协议。

示例:手动迭代列表。

>>> x = [1,2,3]
>>> it = x.__iter__()
>>> it
<listiterator object at 0x590b0>
>>> it.__next__()
1
>>> it.__next__()
2
>>> it.__next__()
3
>>> it.__next__()
Traceback (most recent call last):
File "<stdin>", line 1, in? StopIteration
>>>

支持迭代

如果你想在自己的对象中添加迭代功能,了解迭代是很有用的。例如,创建一个自定义容器。

class Portfolio:
    def __init__(self):
        self.holdings = []

    def __iter__(self):
        return self.holdings.__iter__()
  ...

port = Portfolio()
for s in port:
  ...

练习 6.1:迭代示例说明

创建以下列表:

a = [1,9,4,25,16]

手动迭代这个列表。调用 __iter__() 获取一个迭代器,并调用 __next__() 方法获取连续的元素。

>>> i = a.__iter__()
>>> i
<listiterator object at 0x64c10>
>>> i.__next__()
1
>>> i.__next__()
9
>>> i.__next__()
4
>>> i.__next__()
25
>>> i.__next__()
16
>>> i.__next__()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration
>>>

内置函数 next() 是调用迭代器的 __next__() 方法的快捷方式。尝试在文件上使用它:

>>> f = open('portfolio.csv')
>>> f.__iter__()    ## 注意:这返回文件本身
<_io.TextIOWrapper name='portfolio.csv' mode='r' encoding='UTF-8'>
>>> next(f)
'name,shares,price\n'
>>> next(f)
'"AA",100,32.20\n'
>>> next(f)
'"IBM",50,91.10\n'
>>>

持续调用 next(f),直到到达文件末尾。观察会发生什么。

练习 6.2:支持迭代

有时,你可能希望自己的某个对象支持迭代——特别是当你的对象围绕现有列表或其他可迭代对象进行包装时。在一个新文件 portfolio.py 中,定义以下类:

## portfolio.py

class Portfolio:

    def __init__(self, holdings):
        self._holdings = holdings

    @property
    def total_cost(self):
        return sum([s.cost for s in self._holdings])

    def tabulate_shares(self):
        from collections import Counter
        total_shares = Counter()
        for s in self._holdings:
            total_shares[s.name] += s.shares
        return total_shares

这个类旨在作为围绕列表的一层包装,但带有一些额外的方法,比如 total_cost 属性。修改 report.py 中的 read_portfolio() 函数,使其像这样创建一个 Portfolio 实例:

## report.py

...

import fileparse
from stock import Stock
from portfolio import Portfolio

def read_portfolio(filename):
    '''
    将股票投资组合文件读取为一个字典列表,字典的键为
    name、shares 和 price。
    '''
    with open(filename) as file:
        portdicts = fileparse.parse_csv(file,
                                        select=['name','shares','price'],
                                        types=[str,int,float])

    portfolio = [ Stock(d['name'], d['shares'], d['price']) for d in portdicts ]
    return Portfolio(portfolio)

...

尝试运行 report.py 程序。你会发现它会因 Portfolio 实例不可迭代而严重失败。

>>> import report
>>> report.portfolio_report('portfolio.csv', 'prices.csv')
... 崩溃...

通过修改 Portfolio 类以支持迭代来修复此问题:

class Portfolio:

    def __init__(self, holdings):
        self._holdings = holdings

    def __iter__(self):
        return self._holdings.__iter__()

    @property
    def total_cost(self):
        return sum([s.shares*s.price for s in self._holdings])

    def tabulate_shares(self):
        from collections import Counter
        total_shares = Counter()
        for s in self._holdings:
            total_shares[s.name] += s.shares
        return total_shares

做出此更改后,你的 report.py 程序应该能再次正常工作。与此同时,修改你的 pcost.py 程序以使用新的 Portfolio 对象。如下所示:

## pcost.py

import report

def portfolio_cost(filename):
    '''
    计算投资组合文件的总成本(股数 * 价格)
    '''
    portfolio = report.read_portfolio(filename)
    return portfolio.total_cost
...

进行测试以确保其正常工作:

>>> import pcost
>>> pcost.portfolio_cost('portfolio.csv')
44671.15
>>>

练习 6.3:创建一个更合适的容器

如果要创建一个容器类,你通常需要做的不仅仅是支持迭代。修改 Portfolio 类,使其具有一些其他特殊方法,如下所示:

class Portfolio:
    def __init__(self, holdings):
        self._holdings = holdings

    def __iter__(self):
        return self._holdings.__iter__()

    def __len__(self):
        return len(self._holdings)

    def __getitem__(self, index):
        return self._holdings[index]

    def __contains__(self, name):
        return any([s.name == name for s in self._holdings])

    @property
    def total_cost(self):
        return sum([s.shares*s.price for s in self._holdings])

    def tabulate_shares(self):
        from collections import Counter
        total_shares = Counter()
        for s in self._holdings:
            total_shares[s.name] += s.shares
        return total_shares

现在,使用这个新类进行一些实验:

>>> import report
>>> portfolio = report.read_portfolio('portfolio.csv')
>>> len(portfolio)
7
>>> portfolio[0]
Stock('AA', 100, 32.2)
>>> portfolio[1]
Stock('IBM', 50, 91.1)
>>> portfolio[0:3]
[Stock('AA', 100, 32.2), Stock('IBM', 50, 91.1), Stock('CAT', 150, 83.44)]
>>> 'IBM' in portfolio
True
>>> 'AAPL' in portfolio
False
>>>

关于这一点有一个重要的观察结果——一般来说,如果代码使用了 Python 其他部分通常使用的通用词汇,那么它就被认为是“Pythonic”的。对于容器对象来说,支持迭代、索引、包含性检查以及其他类型的操作符是其中的一个重要部分。

总结

恭喜你!你已经完成了迭代协议实验。你可以在 LabEx 中练习更多实验来提升你的技能。