デコレータとか

Pythonのcontextlibについてまとめたいが、その前に通常のデコレータに関して整理をしておく。

例1 デコレータの実行タイミング

デコレータの簡単な例。

#!/usr/bin/env python


def hoge(*a, **k):
    print('hoge')


@hoge
def foo():
    ...

上のコードを実行すると(関数hogeも関数fooも呼び出していないにも関わらず)'hoge'が出力される。ここから、デコレータ構文@hogeは、関数fooの実行時ではなく、定義時に実行されることがわかる。

例2 デコレータの入力(引数)

上の例では関数hogeの引数をごまかしていた。実際に何が入ってくるかprintしてみると、関数fooが入っていることがわかる。

#!/usr/bin/env python


def hoge(*a, **k):
    print(f'arguments given: {a!r}, {k!r}')


@hoge
def foo():
    ...

出力:

arguments given: (<function foo at 0x7f2a45fbae50>,), {}

したがって、@hogeという構文は、次の行において定義を開始する関数fooを引数として関数hogeを呼び出している。

例3 デコレータからの出力(返り値)

少し変える。

#!/usr/bin/env python


def hoge(func):
    print(func.__name__)


@hoge
def foo():
    ...


print(foo)

出力:

foo
None

デコレータが呼び出されると、その後ろにある関数定義は上書きされる。上の例ではデコレータは何もreturnしていないため、fooにはNoneが代入される。

例4 関数定義の書き換え

#!/usr/bin/env python


def hoge(func):
    def new_func(*a, **k):
        result = func(*a, **k)
        return result
    return new_func


@hoge
def foo():
    print('Hello')


foo()

出力:

Hello

この形式ですぐに作れる形は、外側をwrapする処理を追加するデコレータが挙げられる。

例5 関数定義の書き換えにおける注意点

#!/usr/bin/env python


def hoge(func):
    def new_func(*a, **k):
        result = func(*a, **k)
        return result
    return new_func


@hoge
def foo():
    """print greeting message."""
    print('Hello')


print(foo.__name__)
print(foo.__doc__)

出力:

new_func
None

変数fooに保存された関数の定義はデコレータによって書き換わっている。書き換えの際に、関数の名前やドキュメントも置き換わってしまう。

例6 functoolsによる属性保持

関数の属性が変更されてしまう不都合を避けるために、functoolsを用いる。

#!/usr/bin/env python

import functools

def hoge(func):
    @functools.wraps(func)
    def new_func(*a, **k):
        result = func(*a, **k)
        return result
    return new_func


@hoge
def foo():
    """print greeting message."""
    print('Hello')


print(foo.__name__)
print(foo.__doc__)

出力:

foo
print greeting message.

例外が発生した場合のTracebackは、in <module> → in new_func → in fooというスタックになってくれる。なんとなくclassを継承する際にsuper()で継承元のメソッドを呼び出している時の有様を想起する。

例7 簡単な時間計測

もう少し実用的な例。まず時間計測。

#!/usr/bin/env python

import functools
import numpy as np
import time


def timer(func):
    @functools.wraps(func)
    def new_func(*a, **k):
        start_time = time.perf_counter()
        result = func(*a, **k)
        end_time = time.perf_counter()
        delta_time = end_time - start_time
        print(f'{func.__name__} in delta time {delta_time} sec')
        return result
    return new_func


@timer
def some_heavy_task(data):
    data_formatted = data.astype(np.float64)
    data_final = 2.3 * data_formatted
    return data_final


rng = np.random.default_rng(4693)
data = rng.standard_normal(10_000)
some_heavy_task(data)

出力:

some_heavy_task in delta time 0.0005468999734148383 sec

例8 引数を持つデコレータ

デコレータの汎用性を上げたい場合がある。単純な方法として、関数の階層を増やして引数を与える形式にする方法がある。 下の例ではtimeitのように同じ処理をn_iter回数だけ反復して処理時間を計測する。

#!/usr/bin/env python

import functools
import numpy as np
import time


def timer_iter(n_iter=10):
    def timer(func):
        @functools.wraps(func)
        def new_func(*a, **k):
            delta_time_list = []
            for _ in range(n_iter):
                start_time = time.perf_counter()
                result = func(*a, **k)
                end_time = time.perf_counter()
                delta_time = end_time - start_time
                delta_time_list.append(delta_time)
            average_delta_time = np.mean(delta_time_list)
            std_delta_time = np.std(delta_time_list)
            print(f'{func.__name__}: {average_delta_time:.7f}' +
                  f' +- {std_delta_time:.7f} sec')
            return result
        return new_func
    return timer


@timer_iter(n_iter=10)
def some_heavy_task(data):
    data_formatted = data.astype(np.float64)
    data_final = 2.3 * data_formatted
    return data_final

出力:

some_heavy_task: 0.0000126 +- 0.0000138 sec

例9 デコレータのクラス

デコレータに記述する関数部分は、関数でなく一般にcallableなオブジェクトを指定することができる。

#!/usr/bin/env python

import collections
import functools
import numpy as np
import time


class Timer:
    database = collections.defaultdict(list)

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*a, **k):
            start_time = time.perf_counter()
            result = func(*a, **k)
            end_time = time.perf_counter()
            delta_time = end_time - start_time
            self.database[func.__name__].append(delta_time)
            return result
        return wrapper

    def __str__(self):
        return ', '.join([f'{k}\t{np.mean(v):.7f} +- {np.std(v):.7f}'
                          for k, v in self.database.items()])



@Timer()
def some_heavy_task(data):
    data_formatted = data.astype(np.float64)
    data_final = 2.3 * data_formatted
    return data_final


rng = np.random.default_rng(4693)
data = rng.standard_normal(10_000)
[some_heavy_task(data) for _ in range(10)]

print(Timer())

出力:

some_heavy_task 0.0001874 +- 0.0001034

デコレータを作っていると3階層4階層と深いものを当たり前のように作っていることがある。クラスのメソッドに対するデコレータだったり、さらにクラスに対するデコレータを定義することもある。しかし、ネストした構文はどうしてもややこしくなる。コードをきれいに保つ上で、階層を深くしすぎないということは重要であると考えている。

例10 コンテキストマネージャ

処理に対して前処理と後処理を実装したい場合、with文を使う方法がある。with文を使うためには、コンテキストマネージャを作成する。コンテキストマネージャとして機能するオブジェクトのクラスは、次のように構成する。

#!/usr/bin/env python

import time


class Timer:
    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.delta_time = self.end_time - self.start_time



timer = Timer()


with timer:
    [f'{i}' for i in range(10_000)]


print(timer.delta_time)

この実装の場合、__enter____exit__に分けて実装するため、可読性が高くなる。 しかし関数を実行する箇所を直接制御する形にはなっていないため、引数や返り値をいじりにくいように思える。

例11 デコレータ兼コンテキストマネージャのタイマ

クラス定義の仕掛けをうまく組み合わせる。__call____enter____exit__を定義すれば、デコレータとしてもコンテキストマネージャとしても使用することができる。

#!/usr/bin/env python

import functools
import time


class Timer:
    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        end_time = time.perf_counter()
        delta_time = end_time - self.start_time
        print(f'time: {delta_time}')

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*a, **k):
            with self:
                result = func(*a, **k)
            return result
        return wrapper


timer = Timer()


@timer
def some_work():
    [f'{i}' for i in range(10_000)]


some_work()


with timer:
    [f'{i}' for i in range(10_000)]

出力:

time: 0.0009147999808192253
time: 0.0008865999989211559

例12 デコレータ兼コンテキストマネージャのロガー

#!/usr/bin/env python

import functools
import logging


class TraceLoggerError(Exception):
    ...


class TraceLoggerInvalidFunctionArgumentError(TraceLoggerError):
    ...


class TraceLogger:
    def __init__(self, logger=None):
        self.logger = logger

    def __enter__(self):
        self.logger.info('entering with statement')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.logger.info('exiting with statement')

    def __call__(self, func):
        if (
            (func.__kwdefaults__ is None) or
            ('trace_logger' not in func.__kwdefaults__.keys())
            ):
            # 関数funcがキーワード引数trace_loggerを持つことを仮定する
            raise TraceLoggerInvalidFunctionArgumentError(
                f'function {func.__name__} does not have ' +
                "keyword argument named 'trace_logger'"
                )
        @functools.wraps(func)
        def wrapper(*a, trace_logger=None, **k):
            if trace_logger is not None:
                trace_logger.logger.info(f'entering function {func.__name__}')
            result = func(*a, trace_logger=trace_logger, **k)
            if trace_logger is not None:
                trace_logger.logger.info(f'exiting function {func.__name__}')
            return result
        return wrapper


@TraceLogger()
def foo(*, trace_logger=None):
    print('Hello')


@TraceLogger()
def bar(*, trace_logger=None):
    print('World')


@TraceLogger()
def baz(*, trace_logger=None):
    # 例外を発生させる例
    x = 1 / 0


# キーワード引数trace_loggerのない関数をデコレートしようとするとエラー
# @TraceLogger()
def bad_1():
    print('good morning')


# キーワード引数trace_loggerのない関数をデコレートしようとするとエラー
# @TraceLogger()
def bad_2(trace_logger=None):
    print('good morning')


@TraceLogger()
def hoge(*, trace_logger=None):
    foo(trace_logger=trace_logger)
    bar(trace_logger=trace_logger)
    baz(trace_logger=trace_logger)


def test_trace_logger():
    logger = logging.getLogger('tracelogger')
    logger.setLevel(logging.DEBUG)
    handler1 = logging.StreamHandler()
    handler1.setLevel(logging.DEBUG)
    formatter1 = logging.Formatter(
        '{levelname:s} {asctime:s},{msecs:.3f} {message:s}',
        style='{',
        datefmt='%Y-%m-%d %H:%M:%S',
        )
    handler1.setFormatter(formatter1)
    logger.addHandler(handler1)

    trace_logger = TraceLogger(logger)
    with trace_logger:
        print('greeting')
    hoge(trace_logger=trace_logger)


if __name__ == '__main__':
    test_trace_logger()

出力:

INFO 2021-11-23 18:07:55,795.227 entering with statement
greeting
INFO 2021-11-23 18:07:55,795.372 exiting with statement
INFO 2021-11-23 18:07:55,795.437 entering function hoge
INFO 2021-11-23 18:07:55,795.495 entering function foo
Hello
INFO 2021-11-23 18:07:55,795.582 exiting function foo
INFO 2021-11-23 18:07:55,795.634 entering function bar
World
INFO 2021-11-23 18:07:55,795.722 exiting function bar
INFO 2021-11-23 18:07:55,795.791 entering function baz
Traceback (most recent call last):
  File "/root/work/tests/./test_deco.py", line 102, in <module>
    test_trace_logger()
  File "/root/work/tests/./test_deco.py", line 98, in test_trace_logger
    hoge(trace_logger=trace_logger)
  File "/root/work/tests/./test_deco.py", line 40, in wrapper
    result = func(*a, trace_logger=trace_logger, **k)
  File "/root/work/tests/./test_deco.py", line 79, in hoge
    baz(trace_logger=trace_logger)
  File "/root/work/tests/./test_deco.py", line 40, in wrapper
    result = func(*a, trace_logger=trace_logger, **k)
  File "/root/work/tests/./test_deco.py", line 60, in baz
    x = 1 / 0
ZeroDivisionError: division by zero