every Tech Blog

株式会社エブリーのTech Blogです。

MLflow Tracing時に並行処理するとTracingが分かれる問題への対処法

はじめに

こんにちは。デリッシュキッチンでデータサイエンティストをしている古濵です。 今回はニッチな内容ですが、タイトル通りの問題が発生したため、その対処法について備忘録的にまとめます。

動作環境は以下になります。

  • Databricks Runtime: 15.4LTS for ML
  • Python: 3.11.11
  • ライブラリはDatabricks Runtimeのバージョンから以下にアップグレードしています
    openai==1.65.2
    mlflow==2.20.3
    pydantic==2.10.6
    databricks-agents==0.16.0
    databricks-sdk==0.50.0

MLflow Tracingに関するドキュメントは以下になります。

mlflow.org

問題

準備

まず、具体的にどんな問題が発生するかを説明するために、以下のようなコードを用意します。

やりたいこととしては、ユーザーのクエリからフィルタリング条件を抽出するタスクをLLMにさせます。 これは以前書いたテックブログでのフィルタリング処理をベースとしています。

今回は、調理時間と調理費用のみをフィルタリング条件として抽出するタスクとします。

import mlflow
from enum import Enum
from pydantic import BaseModel
import os

os.environ["OPENAI_API_KEY"] = dbutils.secrets.get(...) # your scope and key

# mlflow.traceを自前で対応するためopenaiのautologを無効にする
mlflow.openai.autolog(disable=True)

# 調理時間のフィルタリング条件
class CookingTimeColumn(str, Enum):
    cooking_time = "cooking_time_min"

class CookingTimeOperator(str, Enum):
    greater_than = ">"
    less_than = "<"
    greater_than_or_equal_to = ">="
    less_than_or_equal_to = "<="

class CookingTimeFilter(BaseModel):
    column: CookingTimeColumn
    operator: CookingTimeOperator
    value: float

class CookingTimeFilters(BaseModel):
    cooking_time_filters: list[CookingTimeFilter]


# 調理費用のフィルタリング条件
class CookingCostColumn(str, Enum):
    cooking_cost = "cooking_cost_yen"

class CookingCostOperator(str, Enum):
    greater_than = ">"
    less_than = "<"
    greater_than_or_equal_to = ">="
    less_than_or_equal_to = "<="

class CookingCostFilter(BaseModel):
    column: CookingCostColumn
    operator: CookingCostOperator
    value: float

class CookingCostFilters(BaseModel):
    cooking_cost_filters: list[CookingCostFilter]

ユーザーの入力クエリは以下を例として使用します。

user_query = "10分以内に500円未満で作れる副菜教えて"

MLflow Tracingを使用した関数を定義

MLflow Tracingは、関数に対して@mlflow.traceデコレータを付与することで、関数内の処理を簡単にTracingすることができます。

そのような関数を以下に3つ定義しました。 create_metadata_filter_from_user_query()から、create_cooking_time_filter()とcreate_cooking_cost_filter()を呼び出します。

from openai import OpenAI
from mlflow.entities import SpanType

@mlflow.trace(span_type=SpanType.LLM)
def create_cooking_time_filter(user_query: str) -> list[CookingTimeFilters | CookingCostFilters]:
    client = OpenAI()

    system_prompt = f"""
    あなたは料理の知識が豊富なレシピ検索AIです。
    ユーザーがレシピ検索のために入力したクエリを解読し、ユーザが**調理時間**でフィルタリングして検索したい場合は、フィルタリング条件を返してください。

    ## 出力形式
    * json形式で出力してください
    * columnにカラム名、operatorに不等号、valueにフィルタリング対象を入れてください
    """

    completion = client.beta.chat.completions.parse(
        model = "gpt-4o-mini",
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_query},
        ],
        response_format = CookingTimeFilters,
    )
    structured_outputs = completion.choices[0].message.content
    filters = CookingTimeFilters.model_validate_json(structured_outputs).cooking_time_filters
    return filters

@mlflow.trace(span_type=SpanType.LLM)
def create_cooking_cost_filter(user_query: str) -> list[CookingTimeFilters | CookingCostFilters]:
    client = OpenAI()

    system_prompt = f"""
    あなたは料理の知識が豊富なレシピ検索AIです。
    ユーザーがレシピ検索のために入力したクエリを解読し、ユーザが**調理費用**でフィルタリングして検索したい場合は、フィルタリング条件を返してください。

    ## 出力形式
    * json形式で出力してください
    * columnにカラム名、operatorに不等号、valueにフィルタリング対象を入れてください
    """

    completion = client.beta.chat.completions.parse(
        model = "gpt-4o-mini",
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_query},
        ],
        response_format = CookingCostFilters,
    )
    structured_outputs = completion.choices[0].message.content
    filters = CookingCostFilters.model_validate_json(structured_outputs).cooking_cost_filters
    return filters

@mlflow.trace(span_type=SpanType.CHAIN)
def create_metadata_filter_from_user_query(user_query: str) -> list[CookingTimeFilters | CookingCostFilters]:
    filter_functions = [
        create_cooking_time_filter,
        create_cooking_cost_filter,
    ]
    metadata_filters = []

    for func in filter_functions:
        filters = func(user_query)
        metadata_filters.extend(filters)

    return metadata_filters

実行結果

- create_metadata_filter_from_user_query
    - create_cooking_time_filter
    - create_cooking_cost_filter

のような構造になっており、cooking_timeとcooking_costのフィルタリング条件をそれぞれ1秒ほど(合計約2秒)で抽出できていることがわかります。

ただ、これでは直列に処理しているため、フィルタリング条件が増えれば増えるほど処理時間が長くなってしまいます。

LLMの処理時間が長いのはAPIの待機時間が原因のため、ここを並行処理にすることで処理時間を短縮することができます。

並行処理内でMLflow Tracingを使用

並行処理を行うために、concurrent.futures.ThreadPoolExecutorを使用しました。 create_metadata_filter_from_user_query()内の処理を以下のように変更します。

from concurrent.futures import ThreadPoolExecutor

@mlflow.trace(span_type=SpanType.CHAIN)
def create_metadata_filter_from_user_query(user_query: str) -> list[CookingTimeFilters | CookingCostFilters]:
    filter_functions = [
        create_cooking_time_filter,
        create_cooking_cost_filter,
    ]
    metadata_filters = []

    # 各関数を並行処理で実行するよう修正
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(func, user_query) for func in filter_functions]
        for future in futures:
            filters = future.result()
            metadata_filters.extend(filters)

    return metadata_filters

実行結果

図中の赤枠が示す通り、1、2、3とタブができており、それぞれのTracingが関数単位になっています。

前置きが長くなりましたが、この問題に対処します。

対処

この問題は、OpenTelemetryを使って、親のTracingのContextを子に渡すことで解決できました。

OpenTelemetryはオブザーバビリティ用途で使用されるOSSです。 MLflow Tracingは内部的にはOpenTelemetryを使用しており、OpenTelemetryのContextを使うことで、親のTracingを子に渡すことができます。

以下のように、呼び出し元のcreate_metadata_filter_from_user_query()内のContextを親のTracingとして、子の関数に渡します。 次に、子の関数の処理ではContextのattachをすることで親と子を関連付けることができました(クリーンアップするためにdetachもしています)。

from opentelemetry import context
from opentelemetry.context import Context

def create_filter_with_trace_parent_context(user_query: str, func: callable, trace_parent_context: Context) -> list[CookingTimeFilters | CookingCostFilters]:
    context_token = context.attach(trace_parent_context)
    filters = func(user_query)
    context.detach(context_token)

    return filters

@mlflow.trace(span_type=SpanType.CHAIN)
def create_metadata_filter_from_user_query(user_query: str) -> list[CookingTimeFilters | CookingCostFilters]:
    filter_functions = [
        create_cooking_time_filter,
        create_cooking_cost_filter,
    ]
    metadata_filters = []
    parent_context = context.get_current()

    # 各関数を並行処理で実行するよう修正
    with ThreadPoolExecutor() as executor:
        # create_filter_with_trace_parent_context()で各関数を呼び出すよう修正
        futures = [executor.submit(create_filter_with_trace_parent_context, user_query, func, parent_context) for func in filter_functions]
        for future in futures:
            filters = future.result()
            metadata_filters.extend(filters)

    return metadata_filters

実行結果

MLflow Tracingが分かれずに1つのタブの中にまとまっていることがわかります。 また、全体の処理時間として約1.3秒で終えており、直列にLLMを呼び出すときに比べて高速化できていることもわかるかと思います。

おわりに

MLflow Tracingを使用した際に並行処理でTracingが分かれてしまう問題と、その対処法について紹介しました。 OpenTelemetryのContextを使用することで、親子関係を保ったままのTracingが可能になり、並行処理による高速化とトレーサビリティの両方を実現できました。 同様の問題に遭遇した方の参考になれば幸いです。

なお、LangGraph(とLangChain)を使えば並行処理をしたとしても、MLflow Tracingが分かれる問題は発生しませんでした。 今回の対処法はあくまで応急処置的な側面もあることは補足しておきます。

また、並行処理内でさらに並行処理をするなど処理が複雑化した場合、ワークフローのどの関数が並行に処理されているのかわかりにくくなるかと思います。 そういう意味でも、要件に合わせて処理が複雑化していくにつれ、LangGraphなどのフレームワークを使用する方が可読性や保守性の観点で有効だと考えられます。