functools.cacheをメソッドに使う

functools.cache は便利ですが、メソッドに対して使う時には注意が必要です。

from functools import cache

class A:
    @cache
    def f(self, x):
        return x * 2

for i in range(1000):
    a = A()
    a.f(42)

print(A.f.cache_info())
# CacheInfo(hits=0, misses=1000, maxsize=None, currsize=1000)

このコードでは A.f() メソッドの第一引数 self がキャッシュキーに含まれるためキャッシュが効いていません。単に効かないどころか Aインスタンスが無限にキャッシュに残り続けるのでメモリリークになります。

この問題を回避するには f() を staticmethod にするか、Aの外で通常の関数として定義する必要があります。

from functools import cache

class A:
    # デコレーターの順番に注意
    @staticmethod
    @cache
    def f(x):
        return x * 2

for i in range(1000):
    a = A()
    a.f(42)

print(A.f.cache_info())
# CacheInfo(hits=999, misses=1, maxsize=None, currsize=1)

これでちゃんとキャッシュヒットするようになり、メモリリークも防いでいます。 @staticmethod@cache の順番は逆になってはいけません。 cache が返すラッパー関数が staticmethod でないので、 A.f により self がバインドされてしまいます。 (staticmethod, cache, staticmethod の順番で f 本体とラッパー関数両方を staticmethod にしても良いですが、冗長です。)

一方でインスタンス間でキャッシュをシェアしたくない場合は一工夫必要になります。デコレーターとして使ってしまうとcacheがクラスに所属してしまうので、インスタンスと生存期間が同じになるようにcacheを __init__ で生成してあげます。

from functools import cache

class A:
    def __init__(self, x):
        self._x = x
        self.f = cache(self._f)

    def _f(self, y):
        return self._x * y

for i in range(10):
    a = A(i)
    for j in range(10):
        a.f(42)
    print(i, a.f.cache_info())
# 0 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 1 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 2 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 3 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 4 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 5 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 6 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 7 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 8 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 9 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)

インスタンスごとにキャッシュが分かれて、それぞれが9回ヒットしていることが判ります。

このブログに乗せているコードは引用を除き CC0 1.0 で提供します。