requestsで長時間Sessionを使う場合はidle_timeoutに注意

Pythonで一番人気のあるHTTPクライアントライブラリはrequestsですが、requestsやその低レイヤーであるurllib3はidle_timeoutの設定を持っていないので、長時間アイドルが続いた接続を再利用した時に Connection Reset by Peer エラーが発生することがあります。

このエラーを避けるためにurllib3はリクエストを送信する前に0バイトのreadを行って接続が生きているか確認しているのですが、サーバー側が接続を切断するのと同時にリクエストを送信してしまう場合にはその確認をすり抜けるので、ごく低頻度にエラーが起こってしまいます。

意図的にこのエラーを再現させてみます。Goを使ってidle_timeoutが1秒のサーバーを作ります。

package main

import (
    "net/http"
    "time"
    "fmt"
    "log"
)

func myHandler(w http.ResponseWriter, r *http.Request) {
    time.Sleep(100 * time.Millisecond)
    w.Header().Set("Content-Type", "text/plain; charset=utf-8")
    fmt.Fprintf(w, "Hello, world!")
}

func main() {
    s := &http.Server{
        Addr:           ":8080",
        Handler:        http.HandlerFunc(myHandler),
        IdleTimeout:    1 * time.Second,
    }
    log.Fatal(s.ListenAndServe())
}

このサーバーに対して1秒弱の間隔でリクエストを送信します。

import requests
import threading
import time
import random

url = "http://127.0.0.1:8080"  # app.go

def get():
    session = requests.Session()
    last_sleep = 0.0
    for i in range(100):
        try:
            response = session.get(url)
            response.raise_for_status()
        except requests.exceptions.ConnectionError as e:
            print(e)
            print(f"{last_sleep=}sec")
        # keep-alive timeout が 1s のサーバーに対してギリギリのタイミングでリクエストを投げる
        last_sleep = random.uniform(0.99, 1.0)
        time.sleep(last_sleep)

workers = []

for i in range(10):
    th = threading.Thread(target=get)
    th.start()
    workers.append(th)

for worker in workers:
    worker.join()

実行結果:

('Connection aborted.', ConnectionResetError(54, 'Connection reset by peer'))
last_sleep=0.9935746861945842sec
('Connection aborted.', ConnectionResetError(54, 'Connection reset by peer'))
last_sleep=0.9940835216122245sec
('Connection aborted.', ConnectionResetError(54, 'Connection reset by peer'))
last_sleep=0.9905876213085897sec
('Connection aborted.', ConnectionResetError(54, 'Connection reset by peer'))
last_sleep=0.9944683230422835sec
('Connection aborted.', ConnectionResetError(54, 'Connection reset by peer'))
last_sleep=0.9906204586986777sec
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
last_sleep=0.9910259433567449sec
('Connection aborted.', ConnectionResetError(54, 'Connection reset by peer'))
last_sleep=0.9903877980940736sec
...

たとえば外部APIを利用するWebアプリケーションで、外部APIの呼び出しが低頻度だとかマルチスレッドを使っていると問題が発生しやすいです。低頻度の場合はSessionを使わないのが一番簡単な解決策ですが、アクセスが高頻度でもマルチスレッドを利用している場合は稀な頻度で起こる同時接続のためにkeep_aliveされる接続が増えて、一部の接続が長時間アイドルになることがあるので、同時接続数を減らすのが良いでしょう。

urllib3 では PoolManager の maxsize で同時接続数を制限できてデフォルトで1なのですが、 requests ではこれを10に置き換えてしまっており、これはほとんどのアプリケーションにとっては過剰でしょう。最大接続数を超えてもデフォルトの設定ではブロックせずに新規接続してくれるので、基本的には maxsize=1 を使い、それで足りないような場合にだけ増やすのがいいと思います。 requestsでSessionを作成する時にmaxsizeを指定できないので、カスタマイズするためにはこのようにします。

import requests
from requests.adapters import HTTPAdapter

session = requests.Session()
session.mount("http://", HTTPAdapter(pool_maxsize=1))
session.mount("https://", HTTPAdapter(pool_maxsize=1))

urllib3にidle_timeoutを設定できるようにするPRがあるので、これがマージされればもっと良い解決ができるようになるでしょう。

または、 httpx への置き換えも検討してみてください。 httpx は高水準APIはrequestsとよく似ており、ほとんど Session を Client に置き換えるだけで使えます。 httpx.Limits.keepalive_expiry で idle_timeout を指定可能で、デフォルトでは5秒になっています。先ほどの再現コードで session 変数を作る部分を次のように書き換えるだけでエラーなしに動作するようになります。

    # import requests の代わりに import httpx
    session = httpx.Client(limits=httpx.Limits(keepalive_expiry=0.5))

functools.cacheをメソッドに使う

functools.cache は便利ですが、メソッドに対して使う時には注意が必要です。

from functools import cache

class A:
    @cache
    def f(self, x):
        return x * 2

for i in range(1000):
    a = A()
    a.f(42)

print(A.f.cache_info())
# CacheInfo(hits=0, misses=1000, maxsize=None, currsize=1000)

このコードでは A.f() メソッドの第一引数 self がキャッシュキーに含まれるためキャッシュが効いていません。単に効かないどころか Aインスタンスが無限にキャッシュに残り続けるのでメモリリークになります。

この問題を回避するには f() を staticmethod にするか、Aの外で通常の関数として定義する必要があります。

from functools import cache

class A:
    # デコレーターの順番に注意
    @staticmethod
    @cache
    def f(x):
        return x * 2

for i in range(1000):
    a = A()
    a.f(42)

print(A.f.cache_info())
# CacheInfo(hits=999, misses=1, maxsize=None, currsize=1)

これでちゃんとキャッシュヒットするようになり、メモリリークも防いでいます。 @staticmethod@cache の順番は逆になってはいけません。 cache が返すラッパー関数が staticmethod でないので、 A.f により self がバインドされてしまいます。 (staticmethod, cache, staticmethod の順番で f 本体とラッパー関数両方を staticmethod にしても良いですが、冗長です。)

一方でインスタンス間でキャッシュをシェアしたくない場合は一工夫必要になります。デコレーターとして使ってしまうとcacheがクラスに所属してしまうので、インスタンスと生存期間が同じになるようにcacheを __init__ で生成してあげます。

from functools import cache

class A:
    def __init__(self, x):
        self._x = x
        self.f = cache(self._f)

    def _f(self, y):
        return self._x * y

for i in range(10):
    a = A(i)
    for j in range(10):
        a.f(42)
    print(i, a.f.cache_info())
# 0 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 1 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 2 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 3 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 4 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 5 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 6 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 7 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 8 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
# 9 CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)

インスタンスごとにキャッシュが分かれて、それぞれが9回ヒットしていることが判ります。

コネクションプールなしでhttpxを使う場合の高速化

httpx をコネクションプールありで使う場合は client = httpx.Client() して client.get() などを使いますが、コネクションプールなしで使う場合は httpx.get() などを使います。

この httpx.get() のような単発でHTTPリクエストを実行するAPIは実際には内部で httpx.Client()インスタンスを生成して破棄しています。実はこの Client()インスタンス生成が遅いのです。

# x1.py
from httpx import Client
for _ in range(100):
    c = Client()
$ hyperfine '.venv/bin/python x1.py'
Benchmark 1: .venv/bin/python x1.py
  Time (mean ± σ):      1.484 s ±  0.018 s    [User: 1.421 s, System: 0.041 s]
  Range (min … max):    1.455 s …  1.522 s    10 runs

100回クライアントを作るのに1.5秒かかっています。1回あたり15msです。もうちょっとなんとかしたい。

実はこのクライアントの生成時間のほとんどは、SSLContext()を作るのに使われています。コネクションプールを使わない場合もSSLContextだけを使い回すことでhttpxを高速化できます。

from httpx import Client, create_ssl_context
ssl_context = create_ssl_context()
for _ in range(100):
    c = Client(verify=ssl_context)
$ hyperfine '.venv/bin/python x1.py' '.venv/bin/python x2.py'
Benchmark 1: .venv/bin/python x1.py
  Time (mean ± σ):      1.477 s ±  0.024 s    [User: 1.412 s, System: 0.040 s]
  Range (min … max):    1.453 s …  1.532 s    10 runs

Benchmark 2: .venv/bin/python x2.py
  Time (mean ± σ):     120.7 ms ±   5.3 ms    [User: 85.6 ms, System: 15.2 ms]
  Range (min … max):   104.2 ms … 129.6 ms    22 runs

Summary
  .venv/bin/python x2.py ran
   12.23 ± 0.58 times faster than .venv/bin/python x1.py

10倍以上高速化できました。

実際に使う場合には get() などのメソッドが verify キーワード引数を受け取るのでそこに SSLContext を渡します。

import httpx

ssl_context = httpx.create_ssl_context()
res = httpx.get('https://google.com', verify=ssl_context)
print("Status:", res.status_code)
print(res.text[:100])

ThreadPoolExecutorの終了処理

しかし、たとえば標準ライブラリの concurrent.futures.ThreadPoolExecutor はdamonスレッドを使っていません。 そのため、 executor.shutdown(wait=True) を atexit から呼び出すことができません。

atexitで終了させるスレッドはdaemonにしよう - methaneのブログ

と言っていましたが、この話題でThreadPoolExecutorを使うのはちょっとミスリードだった気がしたのと私もThreadPoolExecutorの終了処理をちゃんと把握していなかったので補足します。

まず、ThreadPoolExecutorはデフォルトで終了前にjoinされます。つまりThreadPoolExecutorにsubmitされたタスク全てが終了するのを待ってから終了します。なので自前で atexit を使って ThreadPoolExecutor.shutdown() を呼び出す必要はありませんし、やったとしても特に問題ありません。

この終了前の join はどう実装されているのか確認します。 [_Py_Finalize()] で _PyAtExit_Call() が呼ばれる前に、非daemonスレッドを待つ関数を読んでいます。

cpython/Python/pylifecycle.c at a2ba0a7552580f616f74091f8976410f8a310313 · python/cpython · GitHub

    // Wrap up existing "threading"-module-created, non-daemon threads.
    wait_for_thread_shutdown(tstate);

この関数は threading モジュールの _shutdown 関数を読んでいます。

cpython/Python/pylifecycle.c at a2ba0a7552580f616f74091f8976410f8a310313 · python/cpython · GitHub

wait_for_thread_shutdown(PyThreadState *tstate)
{
    PyObject *result;
    PyObject *threading = PyImport_GetModule(&_Py_ID(threading));
    if (threading == NULL) {
        ...
    }
    result = PyObject_CallMethodNoArgs(threading, &_Py_ID(_shutdown));

threading._shutdown() は非daemonスレッドの終了を待つ前に、 atexit モジュールではなく、 threadingモジュールの _threading_atexits に登録された関数を実行します。 _threading_atexits に登録するための関数は _register_atexit() です。

cpython/Lib/threading.py at a2ba0a7552580f616f74091f8976410f8a310313 · python/cpython · GitHub

_threading_atexits = []
_SHUTTING_DOWN = False


def _register_atexit(func, *arg, **kwargs):
    """CPython internal: register *func* to be called before joining threads.

    The registered *func* is called with its arguments just before all
    non-daemon threads are joined in `_shutdown()`. It provides a similar
    purpose to `atexit.register()`, but its functions are called prior to
    threading shutdown instead of interpreter shutdown.

    For similarity to atexit, the registered functions are called in reverse.
    """
    if _SHUTTING_DOWN:
        raise RuntimeError("can't register atexit after shutdown")

    _threading_atexits.append(lambda: func(*arg, **kwargs))

...

def _shutdown():
    """
    Wait until the Python thread state of all non-daemon threads get deleted.
    """
    # Obscure: other threads may be waiting to join _main_thread.  That's
    # dubious, but some code does it. We can't wait for it to be marked as done
    # normally - that won't happen until the interpreter is nearly dead. So
    # mark it done here.
    if _main_thread._os_thread_handle.is_done() and _is_main_interpreter():
        # _shutdown() was already called
        return

    global _SHUTTING_DOWN
    _SHUTTING_DOWN = True

    # Call registered threading atexit functions before threads are joined.
    # Order is reversed, similar to atexit.
    for atexit_call in reversed(_threading_atexits):
        atexit_call()

    if _is_main_interpreter():
        _main_thread._os_thread_handle._set_done()

    # Wait for all non-daemon threads to exit.
    _thread_shutdown()

ThreadPoolExecutorはこの threading._register_atexit() を使って全てのスレッドプールの終了を待ちます。

cpython/Lib/concurrent/futures/thread.py at a2ba0a7552580f616f74091f8976410f8a310313 · python/cpython · GitHub

def _python_exit():
    global _shutdown
    with _global_shutdown_lock:
        _shutdown = True
    items = list(_threads_queues.items())
    for t, q in items:
        q.put(None)
    for t, q in items:
        t.join()


# Register for `_python_exit()` to be called just before joining all
# non-daemon threads. This is used instead of `atexit.register()` for
# compatibility with subinterpreters, which no longer support daemon threads.
# See bpo-39812 for context.
threading._register_atexit(_python_exit)

ということで、uwsgiの中でグローバル変数executor = ThreadPoolExecutor() とかしていても特別なケアなしに graceful shutdown は実現できます。 uwsgi.atexit とかに頼る必要はありません。確かめてみましょう。

#wsgi.py
from concurrent.futures import ThreadPoolExecutor
import time

executor = ThreadPoolExecutor(max_workers=16)

def background(i):
    print(f"starting {i}")
    time.sleep(5)
    print(f"ending {i}")

counter = 0

def application(environ, start_response):
    global counter
    start_response("200 OK", [("Content-type", "text/plain; charset=utf-8")])
    count = counter
    counter += 1
    executor.submit(background, count)
    return [b"Hello, world"]
#!/bin/bash

# .venv ディレクトリがなかったら作る
if [ ! -d .venv ]; then
    uv venv
    uv pip install uwsgi
fi

# uwsgiをバックグラウンドで起動
.venv/bin/uwsgi --http-socket :4321 --enable-threads --module wsgi --callable application --lazy-app --threads=4 --die-on-term --master --pidfile uwsgi.pid -d uwsgi.log

for i in {0..50}
do
    curl http://127.0.0.1:4321/
    echo
done

kill -TERM $(cat uwsgi.pid)

uwsgi.log を tail すると、ちゃんと50番目のタスクを実行してから終了していることがわかります。

ending 43
ending 44
ending 45
ending 46
ending 47
ending 48
ending 49
ending 50
worker 1 buried after 20 seconds
goodbye to uWSGI.

めでたしめでたし。

タイプヒントには「実装の最小要件」ではなく「想定範囲」を表す型を書く

リストを受け取ってループで処理する関数を実装するとき、引数のタイプヒントに list ではなく最小の要求として Iterable を書くことを好む人がいる。コードの実装が引数に対して必要としている最小要件(必要十分条件)を表すためだ。

def func(arg: Iterable[int]) -> None:
    for a in arg:
        do_work(a)

しかし、その関数でログかトレースにその引数の中身を追加したくなった場合にどうしたらいいだろうか? OpenTelemetryのAttributeValue型Sequenceには対応しているがIterableには対応していない。 また、Iterableを一度巡回してしまうと再び巡回できる保証はないので、 arg の中身を複数回使うことができない。

引数のタイプヒントをlistSequenceに修正しようと思っても、他のコードも「最小要件原則」で書かれていると大量の呼び出し元のコードのtype hintも次々に修正しないといけなくなる。もしこの関数のユーザーがチーム外で、後方互換性を保つ必要があるのであれば、そもそもこの修正はできない。

そこで諦めてタイプヒントを修正せずに対応すると次のようになる。

def func(arg: Iterable[int]) -> None:
    # arg : list[int] = list(arg) # Mypyは再定義をエラーにする。
    arg = list(arg)
    with tracer.start_span("func") as span:
        span.set_attribute("arg", arg)
        for a in arg:
            do_work(a)

ここで3つのコストが発生した。

  • 一度修正を試みてから、影響範囲が広いからという理由で修正を断念するまでの作業コスト
  • 引数を毎回 list(arg) する実行コスト
  • argの型が途中で Iterable から list に変わることによる認知負荷。(list型に別の名前をつけても2つの変数の認知負荷になるだけである。)

もし、この関数が最初からlistを受け取る用途しか想定していないなら型ヒントには list を使うべきだったし、listかtupleのどちらかを受け取ることを想定していたならlist|tupleSequenceを使うべきだった。このように、タイプヒントには「(今の)実装が求める最小要件」ではなく「想定している引数の型の範囲」を表すべきである。

しかし、想定する引数の型を最初から完全に決めるのは難しい。試しにこの関数のtype hintがlistだったのにtupleを渡したくなった場合を考えてみよう。上の例と逆に「具体的すぎた」ケースだ。

  • タイプヒントを list | tupleSequence に修正する場合、既存の呼び出し元は list 型の値を渡しているので芋づる式に大量の呼び出し元の修正は必要ない。
  • タイプヒントを変えずに呼び出し元で list(arg) に変換する場合も、変換コストがかかるのはその1箇所だけで済む。

このように、「実装の最小要件」を使うポリシーよりも「必要になるまで具体型を使う」ポリシーの方が対応コストが低くなることが多い。どこまでの抽象度の型を受け取るべきか判断を後回しにしたい場合は、とりあえず具体型を使うことにしよう。変更の必要が生じた時は、その関数のあるべき仕様をより正しく理解してタイプヒントを書けるはずだ。

戻り値についても考えてみよう。戻り値を list から Sequence に変更するのは破壊的変更になるので、特にライブラリの公開APIのように利用側コードが別チームで開発されている場合は簡単に変更できない。ただしこれはタイプヒントだけの問題ではない。タイプヒントのないコードでも、 list を返していた関数が tuple 型を返すようになったら破壊的変更になる。だから破壊的変更を恐れてなるべく抽象度の高い型を選ぶ必要性は薄い。そもそも戻り値の型をlistから別のシーケンス型に変えるケースなんてどれくらいあるだろうか?稀に内部処理を変更して処理結果が tuple になることがあったとして、後方互換性を保つ必要があるならlistに変換して返せばいいだけだ。なので、引数よりは少し気を遣うとはいえ、最初から list と書いてしまって良い場合が多い。

結論:

  • タイプヒントには「実装の最小要件」ではなく「想定範囲」を表す型を書く。
  • 迷ったらとりあえず具体型を書いて、必要になってから抽象型に変える。
このブログに乗せているコードは引用を除き CC0 1.0 で提供します。