文档
Python 进阶示例:装饰器、生成器与上下文管理器
目标
掌握 Python 三大进阶特性:装饰器、生成器、上下文管理器,理解它们在真实项目中的应用。
1. 装饰器(Decorator)
装饰器在不修改原函数的情况下,为其增加额外功能。
基础装饰器
import time
import functools
from typing import Callable, Any
def timer(func: Callable) -> Callable:
"""测量函数执行时间的装饰器"""
@functools.wraps(func) # 保留原函数的元信息
def wrapper(*args: Any, **kwargs: Any) -> Any:
start = time.perf_counter()
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
print(f"[timer] {func.__name__} 执行耗时: {elapsed:.4f} 秒")
return result
return wrapper
@timer
def compute_factorial(n: int) -> int:
"""计算阶乘(故意用慢方法演示)"""
if n <= 1:
return 1
time.sleep(0.1) # 模拟耗时操作
return n * compute_factorial(n - 1)
print(f"5! = {compute_factorial(5)}")
带参数的装饰器
def retry(max_attempts: int = 3, delay: float = 1.0):
"""失败自动重试的装饰器(装饰器工厂)"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
for attempt in range(1, max_attempts + 1):
try:
return func(*args, **kwargs)
except Exception as e:
if attempt == max_attempts:
raise
print(f"[retry] {func.__name__} 第 {attempt} 次失败: {e},{delay}秒后重试...")
time.sleep(delay)
return None
return wrapper
return decorator
@retry(max_attempts=3, delay=0.5)
def unstable_network_call() -> str:
"""模拟不稳定的网络请求"""
import random
if random.random() < 0.7:
raise ConnectionError("网络超时")
return "数据获取成功!"
# 多次运行观察重试行为
for _ in range(3):
try:
print(unstable_network_call())
except ConnectionError:
print("最终失败")
print("---")
类装饰器:注册模式
# 实际用途:插件系统、命令注册
handlers: dict[str, Callable] = {}
def register_command(name: str):
"""将函数注册为命令处理器"""
def decorator(func: Callable) -> Callable:
handlers[name] = func
return func
return decorator
@register_command("greet")
def handle_greet(user: str) -> str:
return f"你好,{user}!"
@register_command("bye")
def handle_bye(user: str) -> str:
return f"再见,{user}!"
# 动态分发
command = "greet"
if command in handlers:
print(handlers[command]("张三"))
2. 生成器(Generator)
生成器用 yield 惰性产生值,节省内存。
def fibonacci(n: int):
"""生成前 n 个斐波那契数"""
a, b = 0, 1
for _ in range(n):
yield a
a, b = b, a + b
# 惰性求值:不会一次性计算所有值
fib = fibonacci(50) # 没问题,还没计算
print(list(fib)) # 现在才计算前50个
def read_large_file(filepath: str):
"""逐行读取大文件(不会撑爆内存)"""
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
yield line.strip()
# 生成器表达式(类似列表推导式但惰性)
squares = (x**2 for x in range(10**9)) # 立即返回,几乎不占内存
print(next(squares)) # 0
print(next(squares)) # 1
print(next(squares)) # 4
def pipeline(data):
"""生成器管道:链式处理数据流"""
# 步骤1:过滤偶数
evens = (x for x in data if x % 2 == 0)
# 步骤2:平方
squared = (x**2 for x in evens)
# 步骤3:转字符串
return (f"结果: {x}" for x in squared)
data = range(20)
for item in pipeline(data):
print(item)
yield from:委托子生成器
def flatten(nested):
"""递归展平嵌套列表"""
for item in nested:
if isinstance(item, (list, tuple)):
yield from flatten(item) # 委托给子生成器
else:
yield item
nested_list = [1, [2, [3, 4], 5], 6, [7, 8]]
print(list(flatten(nested_list))) # [1, 2, 3, 4, 5, 6, 7, 8]
3. 上下文管理器(Context Manager)
用 with 语句确保资源正确获取和释放。
基于类
class DatabaseConnection:
"""模拟数据库连接(上下文管理器)"""
def __init__(self, db_url: str):
self.db_url = db_url
self.connected = False
def __enter__(self):
print(f" 连接数据库: {self.db_url}")
self.connected = True
return self # 返回给 as 子句
def query(self, sql: str) -> str:
if not self.connected:
raise RuntimeError("未连接数据库")
return f"执行: {sql} → 结果集"
def __exit__(self, exc_type, exc_val, exc_tb):
print(f" 断开数据库连接")
self.connected = False
# 返回 True 会吞掉异常(一般返回 False)
return False
# 使用
with DatabaseConnection("postgresql://localhost/mydb") as db:
print(db.query("SELECT * FROM users"))
# 离开 with 块时自动断开连接
基于生成器(contextlib)
from contextlib import contextmanager
@contextmanager
def temp_dir(path: str = "/tmp/python_demo"):
"""临时目录上下文管理器"""
import shutil
import os
# __enter__ 部分
os.makedirs(path, exist_ok=True)
print(f"创建临时目录: {path}")
try:
yield path # 交给 with 块
finally:
# __exit__ 部分(无论如何都会执行)
shutil.rmtree(path, ignore_errors=True)
print(f"清理临时目录: {path}")
with temp_dir() as tmp:
# 在临时目录中做操作
print(f" 在 {tmp} 中工作...")
实用组合:计时 + 临时环境
@contextmanager
def timer_context(label: str = "操作"):
"""测量代码块耗时的上下文管理器"""
start = time.perf_counter()
yield
elapsed = time.perf_counter() - start
print(f"[{label}] 耗时: {elapsed:.4f}s")
with timer_context("数据处理"):
# 任何需要计时的代码块
result = sum(range(10_000_000))
print(f" 计算结果: {result}")
运行步骤
python3 advanced_features.py
预期输出(节选)
5! = 120
[timer] compute_factorial 执行耗时: 0.5032 秒
[retry] unstable_network_call 第 1 次失败: 网络超时,0.5秒后重试...
数据获取成功!
你好,张三!
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, ...]
结果: 0
结果: 4
结果: 16
...
[1, 2, 3, 4, 5, 6, 7, 8]
连接数据库: postgresql://localhost/mydb
执行: SELECT * FROM users → 结果集
断开数据库连接
创建临时目录: /tmp/python_demo
在 /tmp/python_demo 中工作...
清理临时目录: /tmp/python_demo
关键要点
| 特性 | 使用场景 |
|---|---|
| 装饰器 | 日志、计时、权限校验、缓存、重试 |
| 生成器 | 大文件处理、无限序列、数据管道 |
| 上下文管理器 | 文件/连接/锁管理、临时状态、计时 |