Pythonのcontextlibについてまとめたいが、その前に通常のデコレータに関して整理をしておく。
- 例1 デコレータの実行タイミング
- 例2 デコレータの入力(引数)
- 例3 デコレータからの出力(返り値)
- 例4 関数定義の書き換え
- 例5 関数定義の書き換えにおける注意点
- 例6 functoolsによる属性保持
- 例7 簡単な時間計測
- 例8 引数を持つデコレータ
- 例9 デコレータのクラス
- 例10 コンテキストマネージャ
- 例11 デコレータ兼コンテキストマネージャのタイマ
- 例12 デコレータ兼コンテキストマネージャのロガー
例1 デコレータの実行タイミング
デコレータの簡単な例。
#!/usr/bin/env python def hoge(*a, **k): print('hoge') @hoge def foo(): ...
上のコードを実行すると(関数hogeも関数fooも呼び出していないにも関わらず)'hoge'が出力される。ここから、デコレータ構文@hoge
は、関数foo
の実行時ではなく、定義時に実行されることがわかる。
例2 デコレータの入力(引数)
上の例では関数hogeの引数をごまかしていた。実際に何が入ってくるかprintしてみると、関数fooが入っていることがわかる。
#!/usr/bin/env python def hoge(*a, **k): print(f'arguments given: {a!r}, {k!r}') @hoge def foo(): ...
出力:
arguments given: (<function foo at 0x7f2a45fbae50>,), {}
したがって、@hoge
という構文は、次の行において定義を開始する関数fooを引数として関数hogeを呼び出している。
例3 デコレータからの出力(返り値)
少し変える。
#!/usr/bin/env python def hoge(func): print(func.__name__) @hoge def foo(): ... print(foo)
出力:
foo None
デコレータが呼び出されると、その後ろにある関数定義は上書きされる。上の例ではデコレータは何もreturnしていないため、fooにはNoneが代入される。
例4 関数定義の書き換え
#!/usr/bin/env python def hoge(func): def new_func(*a, **k): result = func(*a, **k) return result return new_func @hoge def foo(): print('Hello') foo()
出力:
Hello
この形式ですぐに作れる形は、外側をwrapする処理を追加するデコレータが挙げられる。
例5 関数定義の書き換えにおける注意点
#!/usr/bin/env python def hoge(func): def new_func(*a, **k): result = func(*a, **k) return result return new_func @hoge def foo(): """print greeting message.""" print('Hello') print(foo.__name__) print(foo.__doc__)
出力:
new_func None
変数fooに保存された関数の定義はデコレータによって書き換わっている。書き換えの際に、関数の名前やドキュメントも置き換わってしまう。
例6 functoolsによる属性保持
関数の属性が変更されてしまう不都合を避けるために、functoolsを用いる。
#!/usr/bin/env python import functools def hoge(func): @functools.wraps(func) def new_func(*a, **k): result = func(*a, **k) return result return new_func @hoge def foo(): """print greeting message.""" print('Hello') print(foo.__name__) print(foo.__doc__)
出力:
foo print greeting message.
例外が発生した場合のTracebackは、in <module> → in new_func → in fooというスタックになってくれる。なんとなくclassを継承する際にsuper()で継承元のメソッドを呼び出している時の有様を想起する。
例7 簡単な時間計測
もう少し実用的な例。まず時間計測。
#!/usr/bin/env python import functools import numpy as np import time def timer(func): @functools.wraps(func) def new_func(*a, **k): start_time = time.perf_counter() result = func(*a, **k) end_time = time.perf_counter() delta_time = end_time - start_time print(f'{func.__name__} in delta time {delta_time} sec') return result return new_func @timer def some_heavy_task(data): data_formatted = data.astype(np.float64) data_final = 2.3 * data_formatted return data_final rng = np.random.default_rng(4693) data = rng.standard_normal(10_000) some_heavy_task(data)
出力:
some_heavy_task in delta time 0.0005468999734148383 sec
例8 引数を持つデコレータ
デコレータの汎用性を上げたい場合がある。単純な方法として、関数の階層を増やして引数を与える形式にする方法がある。 下の例ではtimeitのように同じ処理をn_iter回数だけ反復して処理時間を計測する。
#!/usr/bin/env python import functools import numpy as np import time def timer_iter(n_iter=10): def timer(func): @functools.wraps(func) def new_func(*a, **k): delta_time_list = [] for _ in range(n_iter): start_time = time.perf_counter() result = func(*a, **k) end_time = time.perf_counter() delta_time = end_time - start_time delta_time_list.append(delta_time) average_delta_time = np.mean(delta_time_list) std_delta_time = np.std(delta_time_list) print(f'{func.__name__}: {average_delta_time:.7f}' + f' +- {std_delta_time:.7f} sec') return result return new_func return timer @timer_iter(n_iter=10) def some_heavy_task(data): data_formatted = data.astype(np.float64) data_final = 2.3 * data_formatted return data_final
出力:
some_heavy_task: 0.0000126 +- 0.0000138 sec
例9 デコレータのクラス
デコレータに記述する関数部分は、関数でなく一般にcallableなオブジェクトを指定することができる。
#!/usr/bin/env python import collections import functools import numpy as np import time class Timer: database = collections.defaultdict(list) def __call__(self, func): @functools.wraps(func) def wrapper(*a, **k): start_time = time.perf_counter() result = func(*a, **k) end_time = time.perf_counter() delta_time = end_time - start_time self.database[func.__name__].append(delta_time) return result return wrapper def __str__(self): return ', '.join([f'{k}\t{np.mean(v):.7f} +- {np.std(v):.7f}' for k, v in self.database.items()]) @Timer() def some_heavy_task(data): data_formatted = data.astype(np.float64) data_final = 2.3 * data_formatted return data_final rng = np.random.default_rng(4693) data = rng.standard_normal(10_000) [some_heavy_task(data) for _ in range(10)] print(Timer())
出力:
some_heavy_task 0.0001874 +- 0.0001034
デコレータを作っていると3階層4階層と深いものを当たり前のように作っていることがある。クラスのメソッドに対するデコレータだったり、さらにクラスに対するデコレータを定義することもある。しかし、ネストした構文はどうしてもややこしくなる。コードをきれいに保つ上で、階層を深くしすぎないということは重要であると考えている。
例10 コンテキストマネージャ
処理に対して前処理と後処理を実装したい場合、with文を使う方法がある。with文を使うためには、コンテキストマネージャを作成する。コンテキストマネージャとして機能するオブジェクトのクラスは、次のように構成する。
#!/usr/bin/env python import time class Timer: def __enter__(self): self.start_time = time.perf_counter() return self def __exit__(self, exc_type, exc_value, traceback): self.end_time = time.perf_counter() self.delta_time = self.end_time - self.start_time timer = Timer() with timer: [f'{i}' for i in range(10_000)] print(timer.delta_time)
この実装の場合、__enter__
と__exit__
に分けて実装するため、可読性が高くなる。
しかし関数を実行する箇所を直接制御する形にはなっていないため、引数や返り値をいじりにくいように思える。
例11 デコレータ兼コンテキストマネージャのタイマ
クラス定義の仕掛けをうまく組み合わせる。__call__
と__enter__
と__exit__
を定義すれば、デコレータとしてもコンテキストマネージャとしても使用することができる。
#!/usr/bin/env python import functools import time class Timer: def __enter__(self): self.start_time = time.perf_counter() return self def __exit__(self, exc_type, exc_value, traceback): end_time = time.perf_counter() delta_time = end_time - self.start_time print(f'time: {delta_time}') def __call__(self, func): @functools.wraps(func) def wrapper(*a, **k): with self: result = func(*a, **k) return result return wrapper timer = Timer() @timer def some_work(): [f'{i}' for i in range(10_000)] some_work() with timer: [f'{i}' for i in range(10_000)]
出力:
time: 0.0009147999808192253 time: 0.0008865999989211559
例12 デコレータ兼コンテキストマネージャのロガー
#!/usr/bin/env python import functools import logging class TraceLoggerError(Exception): ... class TraceLoggerInvalidFunctionArgumentError(TraceLoggerError): ... class TraceLogger: def __init__(self, logger=None): self.logger = logger def __enter__(self): self.logger.info('entering with statement') return self def __exit__(self, exc_type, exc_value, traceback): self.logger.info('exiting with statement') def __call__(self, func): if ( (func.__kwdefaults__ is None) or ('trace_logger' not in func.__kwdefaults__.keys()) ): # 関数funcがキーワード引数trace_loggerを持つことを仮定する raise TraceLoggerInvalidFunctionArgumentError( f'function {func.__name__} does not have ' + "keyword argument named 'trace_logger'" ) @functools.wraps(func) def wrapper(*a, trace_logger=None, **k): if trace_logger is not None: trace_logger.logger.info(f'entering function {func.__name__}') result = func(*a, trace_logger=trace_logger, **k) if trace_logger is not None: trace_logger.logger.info(f'exiting function {func.__name__}') return result return wrapper @TraceLogger() def foo(*, trace_logger=None): print('Hello') @TraceLogger() def bar(*, trace_logger=None): print('World') @TraceLogger() def baz(*, trace_logger=None): # 例外を発生させる例 x = 1 / 0 # キーワード引数trace_loggerのない関数をデコレートしようとするとエラー # @TraceLogger() def bad_1(): print('good morning') # キーワード引数trace_loggerのない関数をデコレートしようとするとエラー # @TraceLogger() def bad_2(trace_logger=None): print('good morning') @TraceLogger() def hoge(*, trace_logger=None): foo(trace_logger=trace_logger) bar(trace_logger=trace_logger) baz(trace_logger=trace_logger) def test_trace_logger(): logger = logging.getLogger('tracelogger') logger.setLevel(logging.DEBUG) handler1 = logging.StreamHandler() handler1.setLevel(logging.DEBUG) formatter1 = logging.Formatter( '{levelname:s} {asctime:s},{msecs:.3f} {message:s}', style='{', datefmt='%Y-%m-%d %H:%M:%S', ) handler1.setFormatter(formatter1) logger.addHandler(handler1) trace_logger = TraceLogger(logger) with trace_logger: print('greeting') hoge(trace_logger=trace_logger) if __name__ == '__main__': test_trace_logger()
出力:
INFO 2021-11-23 18:07:55,795.227 entering with statement greeting INFO 2021-11-23 18:07:55,795.372 exiting with statement INFO 2021-11-23 18:07:55,795.437 entering function hoge INFO 2021-11-23 18:07:55,795.495 entering function foo Hello INFO 2021-11-23 18:07:55,795.582 exiting function foo INFO 2021-11-23 18:07:55,795.634 entering function bar World INFO 2021-11-23 18:07:55,795.722 exiting function bar INFO 2021-11-23 18:07:55,795.791 entering function baz Traceback (most recent call last): File "/root/work/tests/./test_deco.py", line 102, in <module> test_trace_logger() File "/root/work/tests/./test_deco.py", line 98, in test_trace_logger hoge(trace_logger=trace_logger) File "/root/work/tests/./test_deco.py", line 40, in wrapper result = func(*a, trace_logger=trace_logger, **k) File "/root/work/tests/./test_deco.py", line 79, in hoge baz(trace_logger=trace_logger) File "/root/work/tests/./test_deco.py", line 40, in wrapper result = func(*a, trace_logger=trace_logger, **k) File "/root/work/tests/./test_deco.py", line 60, in baz x = 1 / 0 ZeroDivisionError: division by zero