Serializing TensorFlow GPU Inference with NATS JetStream and a Distributed Lock


While operating a large TensorFlow model, we ran into a thorny resource-contention problem. The model must be loaded into GPU memory for inference, and a single inference run consumes close to 90% of available VRAM. In our architecture, API requests are routed through a Tyk gateway to the backend inference service. The original design was a stateless, horizontally scaled Python service, but when two concurrent requests landed on the same GPU node at nearly the same time, the second request's attempt to allocate VRAM immediately triggered CUDA_ERROR_OUT_OF_MEMORY and crashed the entire service instance.

The most direct fix would be to make the inference service single-threaded, but that sacrifices high availability and leaves a severe single point of failure. What we needed was a mechanism that allows multiple inference service instances (workers) to exist for availability, while guaranteeing that at any given moment only one worker can use the shared GPU. This is fundamentally a cross-process, cross-node mutual exclusion problem: a classic distributed-lock scenario.

The initial idea was to bring in an external coordination service such as Redis or etcd. But the team wanted to keep the technology stack as lean as possible. We were already running NATS JetStream as our internal event and message bus, which raised a question: could we implement the distributed lock with NATS alone, and use it to distribute the inference jobs at the same time?

The answer is yes. Not only does this avoid a new dependency, it also puts both the job queue and the lock mechanism on the same highly available messaging system.

Architecture and Trade-offs

The final architecture works as follows:

  1. Entry point: Client requests pass through the Tyk API gateway. Tyk handles authentication and rate limiting, then forwards the request to a lightweight "Dispatcher" service.
  2. Job dispatch: The Dispatcher is a stateless Go service. It accepts the HTTP request, converts it into a structured JSON message, and publishes it to a NATS JetStream stream named TENSORFLOW_JOBS.
  3. Job queue: The TENSORFLOW_JOBS stream is a durable job queue. Multiple TensorFlow worker instances act as a consumer group and pull jobs from it.
  4. Resource locking and execution:
    • After successfully pulling a job, a worker first tries to acquire a distributed lock representing the "GPU-0" resource.
    • We implement this lock with the NATS Key-Value (KV) Store, whose revision-based atomic updates are exactly what a lock needs.
    • Once the lock is acquired, the worker runs TensorFlow inference.
    • When inference finishes, successfully or not, the worker must release the lock.
  5. Result delivery: The inference result is published to another stream named INFERENCE_RESULTS for downstream services to consume.
sequenceDiagram
    participant Client
    participant Tyk as Tyk API Gateway
    participant Dispatcher
    participant NATS as NATS JetStream
    participant TF_Worker_1 as TensorFlow Worker 1
    participant TF_Worker_2 as TensorFlow Worker 2

    Client->>+Tyk: POST /v1/predict
    Tyk->>+Dispatcher: Forward Request
    Dispatcher->>+NATS: Publish(subject: JOBS.new, data: {...})
    NATS-->>-Dispatcher: Ack
    Dispatcher-->>-Tyk: 202 Accepted
    Tyk-->>-Client: 202 Accepted

    loop Job Processing
        TF_Worker_1->>+NATS: Fetch(subject: JOBS.new)
        NATS-->>-TF_Worker_1: Job #1
        TF_Worker_1->>+NATS: AcquireLock(key: 'gpu-lock')
        NATS-->>-TF_Worker_1: Lock Acquired
        TF_Worker_1->>TF_Worker_1: Run TensorFlow Inference...
        TF_Worker_1->>+NATS: ReleaseLock(key: 'gpu-lock')
        NATS-->>-TF_Worker_1: Lock Released
        TF_Worker_1->>+NATS: Publish(subject: RESULTS.new, data: {...})
        NATS-->>-TF_Worker_1: Ack

        Note right of TF_Worker_2: Tries to get lock but fails, waits or retries.
        TF_Worker_2->>+NATS: AcquireLock(key: 'gpu-lock')
        NATS-->>-TF_Worker_2: Lock Held by Another Worker
    end

The heart of this design is an advisory lock with a TTL, built on the NATS KV Store. Workers follow a gentlemen's agreement: whoever holds the lock gets the GPU.

Infrastructure and NATS Configuration

We use Docker Compose to simulate the environment.

docker-compose.yml:

version: '3.8'

services:
  nats:
    image: nats:2.9-alpine
    ports:
      - "4222:4222"
      - "8222:8222" # for monitoring
    command: "-js -sd /data" # Enable JetStream and specify storage directory
    volumes:
      - nats_data:/data

volumes:
  nats_data:

After startup, we need to create the JetStream streams and the KV bucket manually. In a real project this would be done by Terraform or a bootstrap script.

# Install nats-cli
# go install github.com/nats-io/natscli/nats@latest

# Create the job stream
nats stream add TENSORFLOW_JOBS --subjects "JOBS.*" --storage file --retention workq --ack

# Create the results stream
nats stream add INFERENCE_RESULTS --subjects "RESULTS.*" --storage file --retention limits --max-age 1h

# Create a KV bucket for our distributed lock
# We set a TTL of 60 seconds. If a worker crashes while holding the lock,
# it will automatically expire.
nats kv add gpu_locks --ttl 60s

The workq retention policy is essential for the job queue: it guarantees that each message is processed by exactly one consumer. The KV bucket's 60-second TTL is our fuse; it prevents a crashed worker from leaving the system deadlocked.
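If you would rather not depend on nats-cli, the same bootstrap can also be done programmatically at service startup. Below is a minimal sketch using nats-py; the stream and bucket names match the CLI commands above, but the exact config fields (StreamConfig, KeyValueConfig) should be verified against the nats-py version in use.

import asyncio

import nats
from nats.js.api import KeyValueConfig, RetentionPolicy, StorageType, StreamConfig


async def bootstrap(nats_url: str = "nats://localhost:4222"):
    """Create the job/result streams and the lock bucket used in this article."""
    nc = await nats.connect(nats_url)
    js = nc.jetstream()

    # Work-queue stream: a job message is removed once a consumer acks it.
    await js.add_stream(StreamConfig(
        name="TENSORFLOW_JOBS",
        subjects=["JOBS.*"],
        retention=RetentionPolicy.WORK_QUEUE,
        storage=StorageType.FILE,
    ))

    # Results stream: limits-based retention, kept for one hour (max_age is in seconds).
    await js.add_stream(StreamConfig(
        name="INFERENCE_RESULTS",
        subjects=["RESULTS.*"],
        storage=StorageType.FILE,
        max_age=3600,
    ))

    # KV bucket backing the GPU lock; the 60s TTL expires locks held by crashed workers.
    await js.create_key_value(KeyValueConfig(bucket="gpu_locks", ttl=60))

    await nc.close()


if __name__ == "__main__":
    asyncio.run(bootstrap())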

Implementing the Distributed Lock in Python

This is the technical core of the whole solution. We wrap a NatsDistLock class around the nats-py library.

distributed_lock.py:

import asyncio
import logging
import uuid
from typing import Optional

from nats.aio.client import Client as NATS
from nats.js.client import JetStreamContext
from nats.js.kv import KeyValue
from nats.errors import TimeoutError as NatsTimeoutError

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class NatsDistLock:
    """
    An advisory distributed lock implementation using NATS Key-Value Store.

    This lock is atomic by leveraging the `create` operation of the KV store,
    which fails if the key already exists.
    """

    def __init__(self, nc: NATS, bucket_name: str, lock_key: str, ttl_seconds: int = 60):
        """
        Initializes the distributed lock handler.

        :param nc: An active NATS client connection.
        :param bucket_name: The name of the KV bucket used for locking.
        :param lock_key: The specific key representing the resource to be locked.
        :param ttl_seconds: Time-to-live for the lock key in seconds.
        """
        self.js: JetStreamContext = nc.jetstream()
        self.bucket_name = bucket_name
        self.lock_key = lock_key
        self.ttl_seconds = ttl_seconds
        self.kv_store: Optional[KeyValue] = None
        self.lock_owner_id = f"owner-{uuid.uuid4()}"
        self._is_initialized = False

    async def _initialize(self):
        """Lazy initialization of the KV store."""
        if not self._is_initialized:
            try:
                self.kv_store = await self.js.key_value(self.bucket_name)
                self._is_initialized = True
                logging.info(f"Successfully connected to KV bucket '{self.bucket_name}'")
            except Exception as e:
                logging.error(f"Failed to connect to KV bucket '{self.bucket_name}': {e}")
                raise

    async def acquire(self, timeout_seconds: int = 30) -> bool:
        """
        Tries to acquire the lock.

        Retries every second until the timeout is reached.

        :param timeout_seconds: How long to wait for the lock.
        :return: True if the lock was acquired, False otherwise.
        """
        await self._initialize()
        if not self.kv_store:
            return False

        end_time = asyncio.get_event_loop().time() + timeout_seconds
        while asyncio.get_event_loop().time() < end_time:
            try:
                # The core of the lock: `create` is an atomic operation.
                # It succeeds only if the key does not exist.
                # We store our unique owner ID to verify ownership on release.
                await self.kv_store.create(
                    key=self.lock_key,
                    value=self.lock_owner_id.encode()
                )
                logging.info(f"Lock '{self.lock_key}' acquired by {self.lock_owner_id}")
                return True
            except Exception as e:
                # This exception is expected when the key already exists (lock is held).
                logging.debug(f"Failed to acquire lock '{self.lock_key}', likely held by another worker. Retrying...")
                await asyncio.sleep(1)
        
        logging.warning(f"Timeout while trying to acquire lock '{self.lock_key}'")
        return False

    async def release(self) -> bool:
        """
        Releases the lock.

        It first checks if this instance is the current owner of the lock
        before deleting the key. This prevents a worker from accidentally
        releasing a lock acquired by another worker after a delay.

        :return: True if the lock was released, False otherwise.
        """
        await self._initialize()
        if not self.kv_store:
            return False
            
        try:
            entry = await self.kv_store.get(self.lock_key)
            if entry.value.decode() == self.lock_owner_id:
                # We are the owner, safe to delete.
                # We use `purge` instead of `delete` to completely remove the key and its history.
                await self.kv_store.purge(self.lock_key)
                logging.info(f"Lock '{self.lock_key}' released by owner {self.lock_owner_id}")
                return True
            else:
                logging.warning(f"Attempted to release lock '{self.lock_key}' but not the owner. "
                                f"Current owner: {entry.value.decode()}, this worker: {self.lock_owner_id}")
                return False
        except Exception:
            # Key might not exist if it expired or was already released.
            logging.info(f"Lock '{self.lock_key}' not found on release, assuming it's already gone.")
            return True

A common mistake is releasing the lock without checking ownership. If a worker stalls because of a GC pause or some other delay, the lock it holds may already have expired via the TTL and been acquired by another worker. If the first worker then wakes up and calls release, it could wrongly release the second worker's lock. By storing a unique owner_id in the value and checking it before deleting, we avoid this race condition.
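To make the release path harder to forget, the lock can also be wrapped in an async context manager. The following is a small convenience sketch layered on top of NatsDistLock (the 120-second timeout is just an example value):

from contextlib import asynccontextmanager

from distributed_lock import NatsDistLock


@asynccontextmanager
async def gpu_lock(lock: NatsDistLock, timeout_seconds: int = 120):
    """Acquire the lock on entry and always release it on exit."""
    if not await lock.acquire(timeout_seconds=timeout_seconds):
        raise TimeoutError(f"could not acquire '{lock.lock_key}' within {timeout_seconds}s")
    try:
        yield
    finally:
        # Same idea as the worker's try/finally below: release even if inference raises.
        await lock.release()

# Usage inside a worker:
#   async with gpu_lock(self.lock):
#       result = run_inference(job_data)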

Implementing the TensorFlow Worker

The worker is a long-running Python process that fetches, locks, processes, and releases in a loop.

tf_worker.py:

import asyncio
import json
import logging
import time
import os
import nats
from nats.aio.client import Client as NATS

from distributed_lock import NatsDistLock

# --- Mock TensorFlow Inference ---
# In a real scenario, this would import TensorFlow and load a model.
def run_inference(data: dict) -> dict:
    """A mock function to simulate a GPU-intensive TensorFlow task."""
    logging.info(f"Starting simulated inference for request_id: {data.get('request_id')}")
    # Simulate heavy GPU work
    time.sleep(5) 
    result = {"prediction": [0.98, 0.01, 0.01], "request_id": data.get("request_id")}
    logging.info(f"Finished inference for request_id: {data.get('request_id')}")
    return result
# --- End Mock ---

class TensorFlowWorker:
    def __init__(self, nats_url: str, job_stream: str, result_stream: str, lock_bucket: str, lock_key: str):
        self.nats_url = nats_url
        self.job_stream = job_stream
        self.result_stream = result_stream
        self.lock_bucket = lock_bucket
        self.lock_key = lock_key
        self.nc: NATS = None
        self.lock: NatsDistLock = None

    async def connect(self):
        """Connects to NATS and initializes the lock."""
        try:
            self.nc = await nats.connect(self.nats_url)
            self.lock = NatsDistLock(self.nc, self.lock_bucket, self.lock_key)
            logging.info(f"Worker connected to NATS at {self.nats_url}")
        except Exception as e:
            logging.critical(f"Failed to connect to NATS: {e}")
            raise

    async def start_processing(self):
        """Main loop to fetch and process jobs."""
        if not self.nc:
            raise RuntimeError("NATS connection not established. Call connect() first.")

        js = self.nc.jetstream()
        # Pull-based subscription is better for controlling workload.
        # We process one message at a time.
        sub = await js.pull_subscribe(
            subject="JOBS.*",
            durable="tf-worker-group", # All workers share this durable name
        )

        logging.info("Worker is ready and waiting for jobs...")
        while True:
            try:
                msgs = await sub.fetch(1, timeout=60)
                msg = msgs[0]
                logging.info(f"Fetched job: {msg.subject} with data: {msg.data.decode()}")

                job_data = json.loads(msg.data.decode())

                # The critical section starts here
                logging.info("Attempting to acquire GPU lock...")
                if await self.lock.acquire(timeout_seconds=120):
                    try:
                        # Once lock is acquired, run the inference
                        result = run_inference(job_data)
                        
                        # Publish result
                        await js.publish(f"RESULTS.{job_data.get('request_id')}", json.dumps(result).encode())
                        
                        # Acknowledge the job message only after successful processing and result publishing
                        await msg.ack()
                        logging.info(f"Job for request_id {job_data.get('request_id')} processed and acknowledged.")

                    except Exception as e:
                        logging.error(f"Error during inference for {job_data.get('request_id')}: {e}")
                        # Do not ack the message, let it be redelivered after ack_wait time.
                        # NACKing would be an option too.
                    finally:
                        # CRITICAL: Always release the lock
                        logging.info("Releasing GPU lock...")
                        await self.lock.release()
                else:
                    logging.warning(f"Could not acquire GPU lock for job {job_data.get('request_id')}. Message will be redelivered.")
                    # We don't ACK, so the message will be available again for another worker.
            
            except nats.errors.TimeoutError:
                # No new messages, just continue waiting
                continue
            except Exception as e:
                logging.error(f"An unexpected error occurred in the main loop: {e}")
                await asyncio.sleep(5) # Avoid fast-spinning on persistent errors

    async def close(self):
        if self.nc:
            await self.nc.close()
            logging.info("NATS connection closed.")

async def main():
    worker = TensorFlowWorker(
        nats_url=os.getenv("NATS_URL", "nats://localhost:4222"),
        job_stream="TENSORFLOW_JOBS",
        result_stream="INFERENCE_RESULTS",
        lock_bucket="gpu_locks",
        lock_key="gpu-0-lock"
    )
    try:
        await worker.connect()
        await worker.start_processing()
    finally:
        await worker.close()

if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Worker shutting down.")

The try...finally block is critical. Whether inference succeeds, fails, or raises, the lock release must run; otherwise the whole system stays blocked until the lock's TTL expires.

The Dispatcher

This service can be very simple. Implementing it in Go gives excellent performance with a tiny resource footprint.

dispatcher/main.go:

package main

import (
	"encoding/json"
	"log"
	"net/http"
	"os"
	"time"

	"github.com/google/uuid"
	"github.com/nats-io/nats.go"
)

type PredictRequest struct {
	ImageData string `json:"image_data"` // base64 encoded image
}

type JobPayload struct {
	RequestID string `json:"request_id"`
	Data      string `json:"image_data"`
	Timestamp string `json:"timestamp"`
}

func main() {
	natsURL := os.Getenv("NATS_URL")
	if natsURL == "" {
		natsURL = nats.DefaultURL
	}

	nc, err := nats.Connect(natsURL)
	if err != nil {
		log.Fatalf("Error connecting to NATS: %v", err)
	}
	defer nc.Close()

	js, err := nc.JetStream()
	if err != nil {
		log.Fatalf("Error getting JetStream context: %v", err)
	}

	http.HandleFunc("/predict", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed)
			return
		}

		var req PredictRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}

		payload := JobPayload{
			RequestID: uuid.New().String(),
			Data:      req.ImageData,
			Timestamp: time.Now().UTC().Format(time.RFC3339),
		}

		payloadBytes, err := json.Marshal(payload)
		if err != nil {
			http.Error(w, "Failed to serialize job payload", http.StatusInternalServerError)
			return
		}

		// Publish to the jobs stream. The subject can be used for routing if needed.
		_, err = js.Publish("JOBS.new", payloadBytes)
		if err != nil {
			log.Printf("Error publishing to JetStream: %v", err)
			http.Error(w, "Failed to enqueue job", http.StatusInternalServerError)
			return
		}

		log.Printf("Enqueued job with RequestID: %s", payload.RequestID)
		w.WriteHeader(http.StatusAccepted)
		json.NewEncoder(w).Encode(map[string]string{"request_id": payload.RequestID})
	})

	log.Println("Dispatcher service starting on :8080")
	if err := http.ListenAndServe(":8080", nil); err != nil {
		log.Fatalf("Could not start server: %s\n", err)
	}
}

The Tyk configuration only needs an upstream API pointing at this Dispatcher service. The benefit of this asynchronous architecture is that the client immediately receives a 202 Accepted response with a request_id, instead of waiting for the slow inference to finish.
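For completeness, here is a minimal sketch of a downstream consumer reading from the INFERENCE_RESULTS stream. It only logs each result; a real service would persist them keyed by request_id (for example in Redis or a database) so clients can poll with the request_id they received from the Dispatcher. The durable name here is a placeholder, not part of the original setup.

import asyncio
import json
import logging

import nats


async def consume_results(nats_url: str = "nats://localhost:4222"):
    """Pull inference results and hand them off to downstream storage/notification."""
    nc = await nats.connect(nats_url)
    js = nc.jetstream()

    sub = await js.pull_subscribe(subject="RESULTS.*", durable="results-consumer")
    while True:
        try:
            msgs = await sub.fetch(10, timeout=30)
        except nats.errors.TimeoutError:
            continue  # no results yet, keep waiting
        for msg in msgs:
            result = json.loads(msg.data.decode())
            # Placeholder: store the result keyed by request_id so clients can poll for it.
            logging.info("Result for request_id %s: %s", result.get("request_id"), result)
            await msg.ack()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(consume_results())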

Limitations and Future Work

This NATS-based approach solves our GPU contention problem elegantly, without adding any new infrastructure dependencies. But it is not without limitations.

First, this is an advisory lock. It relies on every worker honoring the locking protocol. A misbehaving or buggy worker can ignore the lock and hit the GPU directly, crashing the system. In a controlled environment this is an acceptable trade-off.

Second, for scenarios that manage several distinct GPU resources (for example, a GPU pool), the single gpu-0-lock key would have to grow into a proper lock manager. That might require a dedicated NATS subject for requesting available resources, with a coordinator handing out locks.
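As a rough idea of the simplest extension, each GPU could get its own KV key and a worker would probe the keys until one acquisition succeeds. The sketch below builds on NatsDistLock; the gpu_count parameter is hypothetical, and a real pool would still need fairness and placement logic on top.

from typing import Optional

from nats.aio.client import Client as NATS

from distributed_lock import NatsDistLock


async def acquire_any_gpu(nc: NATS, bucket: str = "gpu_locks", gpu_count: int = 4) -> Optional[NatsDistLock]:
    """Try gpu-0-lock .. gpu-(n-1)-lock and return the first lock acquired, or None."""
    for i in range(gpu_count):
        lock = NatsDistLock(nc, bucket, f"gpu-{i}-lock")
        # Short per-key timeout so one busy GPU does not block probing the others.
        if await lock.acquire(timeout_seconds=1):
            return lock
    return None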

Finally, the lock TTL is an important safety net, but it also means that if a job's actual runtime exceeds the TTL, the lock can be released mid-flight and another worker can start executing, which reintroduces the resource conflict. The TTL must therefore be set well above the expected maximum job duration, and monitoring should alert on jobs that are approaching it.

One future improvement is to tie lock management to worker health. For example, a worker could periodically "heartbeat" to refresh the lock's TTL, so the lock only expires when the worker process is genuinely dead. NATS's Watch capability can be used to observe lock state changes and build more sophisticated coordination logic.
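A heartbeat could look roughly like the sketch below: while a job is running, a background task re-writes the lock key against the last known revision, which creates a fresh revision and therefore restarts the bucket's TTL clock, and aborts if another worker has taken over in the meantime. This assumes nats-py's KeyValue.update accepts the expected revision via its last argument; treat it as a sketch rather than production code.

import asyncio
import logging

from distributed_lock import NatsDistLock


async def heartbeat(lock: NatsDistLock, interval_seconds: int = 20):
    """Keep refreshing the lock's TTL while we still own it.

    Meant to run as an asyncio task next to run_inference() and be cancelled
    when the job completes; it stops by itself if ownership is lost.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            entry = await lock.kv_store.get(lock.lock_key)
        except Exception:
            logging.warning("Lock key '%s' is gone, stopping heartbeat", lock.lock_key)
            return
        if entry.value.decode() != lock.lock_owner_id:
            logging.warning("Lost ownership of '%s', stopping heartbeat", lock.lock_key)
            return
        # Conditional update on the revision we just read: never overwrite a lock
        # that another worker re-created after ours expired.
        await lock.kv_store.update(lock.lock_key, lock.lock_owner_id.encode(), last=entry.revision)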

