While working with a large TensorFlow model, we ran into a thorny resource-contention problem. The model has to be loaded into GPU memory for inference, and a single inference run consumes close to 90% of the available VRAM. In our architecture, API requests are routed through a Tyk gateway to the backend inference service. The original design was a stateless, horizontally scalable Python service, but when two concurrent requests landed on the same GPU node at almost the same time, the second request would hit CUDA_ERROR_OUT_OF_MEMORY the moment it tried to allocate memory, crashing the entire service instance.
The most straightforward fix would be to run the inference service as a single instance, but that sacrifices high availability and creates a severe single point of failure. What we needed was a mechanism that lets multiple inference worker instances exist for availability while guaranteeing that, at any given moment, only one worker can use the shared GPU. This is essentially a cross-process, cross-node mutual exclusion problem: a classic distributed-lock scenario.
The initial idea was to bring in an external coordination service such as Redis or etcd, but the team wanted to keep the technology stack as small as possible. We were already running NATS JetStream as our internal event and message bus, which raised a question: could we implement the distributed lock with NATS alone, and use it to distribute the inference jobs at the same time?
The answer is yes. Doing so not only avoids a new dependency, it also puts both the job queue and the lock on the same highly available messaging system.
Architecture and Trade-offs
The final architecture flow looks like this:
- Entry point: Client requests come in through the Tyk API gateway. Tyk handles authentication and rate limiting, then forwards the request to a lightweight "Dispatcher" service.
- Job dispatch: The Dispatcher is a stateless Go service. It accepts the HTTP request, converts it into a structured JSON message, and publishes it to a NATS JetStream stream named TENSORFLOW_JOBS.
- Job queue: The TENSORFLOW_JOBS stream is a durable job queue. Multiple TensorFlow worker instances act as a consumer group and pull jobs from it.
- Resource locking and execution:
  - After a worker successfully pulls a job, it first tries to acquire a distributed lock representing the "GPU-0" resource.
  - We implement this lock with the NATS Key-Value (KV) Store, whose revision-based atomic updates are exactly what a lock needs.
  - Once the lock is acquired, the worker runs the TensorFlow inference.
  - When inference finishes, successfully or not, the worker must release the lock.
- Result delivery: The inference result is published to another stream named INFERENCE_RESULTS for downstream services to consume.
sequenceDiagram
    participant Client
    participant Tyk as Tyk API Gateway
    participant Dispatcher
    participant NATS as NATS JetStream
    participant TF_Worker_1 as TensorFlow Worker 1
    participant TF_Worker_2 as TensorFlow Worker 2
    Client->>+Tyk: POST /v1/predict
    Tyk->>+Dispatcher: Forward Request
    Dispatcher->>+NATS: Publish(subject: JOBS.new, data: {...})
    NATS-->>-Dispatcher: Ack
    Dispatcher-->>-Tyk: 202 Accepted
    Tyk-->>-Client: 202 Accepted
    loop Job Processing
        TF_Worker_1->>+NATS: Fetch(subject: JOBS.new)
        NATS-->>-TF_Worker_1: Job #1
        TF_Worker_1->>+NATS: AcquireLock(key: 'gpu-lock')
        NATS-->>-TF_Worker_1: Lock Acquired
        TF_Worker_1->>TF_Worker_1: Run TensorFlow Inference...
        TF_Worker_1->>+NATS: ReleaseLock(key: 'gpu-lock')
        NATS-->>-TF_Worker_1: Lock Released
        TF_Worker_1->>+NATS: Publish(subject: RESULTS.new, data: {...})
        NATS-->>-TF_Worker_1: Ack
        Note right of TF_Worker_2: Tries to get lock but fails,<br/>waits or retries.
        TF_Worker_2->>+NATS: AcquireLock(key: 'gpu-lock')
        NATS-->>-TF_Worker_2: Lock Held by Another Worker
    end
The core of this design is an advisory lock with a TTL, built on the NATS KV Store. Workers follow a gentlemen's agreement: whoever holds the lock gets the GPU.
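The primitive behind the lock is the KV store's create operation, which only succeeds when the key does not yet exist. A minimal sketch of that semantic with the nats-py client, assuming the gpu_locks bucket and gpu-0-lock key created later in this post:

import asyncio
import nats


async def main():
    nc = await nats.connect("nats://localhost:4222")
    js = nc.jetstream()
    kv = await js.key_value("gpu_locks")

    # First create succeeds because the key is absent: the lock is ours.
    rev = await kv.create("gpu-0-lock", b"owner-A")
    print("lock acquired at revision", rev)

    # A second create on the same key fails while the lock is held.
    try:
        await kv.create("gpu-0-lock", b"owner-B")
    except Exception as e:
        print("second acquire rejected:", e)

    await kv.purge("gpu-0-lock")  # release
    await nc.close()


asyncio.run(main())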
Infrastructure and NATS Configuration
We use Docker Compose to simulate the environment.
docker-compose.yml:
version: '3.8'
services:
  nats:
    image: nats:2.9-alpine
    ports:
      - "4222:4222"
      - "8222:8222" # for monitoring
    command: "-js -sd /data" # Enable JetStream and specify storage directory
    volumes:
      - nats_data:/data
volumes:
  nats_data:
Once the server is up, we create the JetStream streams and the KV bucket by hand. In a real project this should be done with Terraform or a startup script (a Python sketch of such a script follows the CLI commands below).
# Install nats-cli
# go install github.com/nats-io/natscli/nats@latest
# Create the job stream
nats stream add TENSORFLOW_JOBS --subjects "JOBS.*" --storage file --retention workq --max-msgs-per-subject 1 --ack
# Create the results stream
nats stream add INFERENCE_RESULTS --subjects "RESULTS.*" --storage file --retention limits --max-age 1h
# Create a KV bucket for our distributed lock
# We set a TTL of 60 seconds. If a worker crashes while holding the lock,
# it will automatically expire.
nats kv add gpu_locks --ttl 60s
The workq retention policy is essential for the job queue: it guarantees that each message is processed by only one consumer. The ttl 60s on the KV bucket is our fuse, preventing a crashed worker from leaving the system deadlocked.
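As noted above, in a real deployment the streams and bucket should be provisioned by a startup script rather than by hand. A minimal sketch of such a script with nats-py, mirroring the CLI commands above (the exact limits are assumptions, not part of the original setup):

import asyncio
import nats
from nats.js.api import StreamConfig, RetentionPolicy, StorageType


async def provision():
    nc = await nats.connect("nats://localhost:4222")
    js = nc.jetstream()

    # Work-queue stream for inference jobs: each message goes to exactly one consumer.
    await js.add_stream(StreamConfig(
        name="TENSORFLOW_JOBS",
        subjects=["JOBS.*"],
        retention=RetentionPolicy.WORK_QUEUE,
        storage=StorageType.FILE,
    ))

    # Results stream, kept for one hour.
    await js.add_stream(StreamConfig(
        name="INFERENCE_RESULTS",
        subjects=["RESULTS.*"],
        retention=RetentionPolicy.LIMITS,
        storage=StorageType.FILE,
        max_age=3600,  # seconds
    ))

    # KV bucket backing the GPU lock; entries expire after 60 s if a holder dies.
    await js.create_key_value(bucket="gpu_locks", ttl=60)

    await nc.close()


asyncio.run(provision())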
Python Implementation of the Distributed Lock
This is the technical heart of the whole solution. Building on the nats-py library, we wrap the logic in a NatsDistLock class.
distributed_lock.py:
import asyncio
import logging
import uuid
from typing import Optional
from nats.aio.client import Client as NATS
from nats.js.client import JetStreamContext
from nats.js.kv import KeyValue
from nats.errors import TimeoutError as NatsTimeoutError
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class NatsDistLock:
    """
    An advisory distributed lock implementation using NATS Key-Value Store.
    This lock is atomic by leveraging the `create` operation of the KV store,
    which fails if the key already exists.
    """

    def __init__(self, nc: NATS, bucket_name: str, lock_key: str, ttl_seconds: int = 60):
        """
        Initializes the distributed lock handler.
        :param nc: An active NATS client connection.
        :param bucket_name: The name of the KV bucket used for locking.
        :param lock_key: The specific key representing the resource to be locked.
        :param ttl_seconds: Time-to-live for the lock key in seconds.
        """
        self.js: JetStreamContext = nc.jetstream()
        self.bucket_name = bucket_name
        self.lock_key = lock_key
        self.ttl_seconds = ttl_seconds
        self.kv_store: Optional[KeyValue] = None
        self.lock_owner_id = f"owner-{uuid.uuid4()}"
        self._is_initialized = False

    async def _initialize(self):
        """Lazy initialization of the KV store."""
        if not self._is_initialized:
            try:
                self.kv_store = await self.js.key_value(self.bucket_name)
                self._is_initialized = True
                logging.info(f"Successfully connected to KV bucket '{self.bucket_name}'")
            except Exception as e:
                logging.error(f"Failed to connect to KV bucket '{self.bucket_name}': {e}")
                raise

    async def acquire(self, timeout_seconds: int = 30) -> bool:
        """
        Tries to acquire the lock.
        Retries every second until the timeout is reached.
        :param timeout_seconds: How long to wait for the lock.
        :return: True if the lock was acquired, False otherwise.
        """
        await self._initialize()
        if not self.kv_store:
            return False
        end_time = asyncio.get_event_loop().time() + timeout_seconds
        while asyncio.get_event_loop().time() < end_time:
            try:
                # The core of the lock: `create` is an atomic operation.
                # It succeeds only if the key does not exist.
                # We store our unique owner ID to verify ownership on release.
                await self.kv_store.create(
                    key=self.lock_key,
                    value=self.lock_owner_id.encode()
                )
                logging.info(f"Lock '{self.lock_key}' acquired by {self.lock_owner_id}")
                return True
            except Exception:
                # This exception is expected when the key already exists (lock is held).
                logging.debug(f"Failed to acquire lock '{self.lock_key}', likely held by another worker. Retrying...")
                await asyncio.sleep(1)
        logging.warning(f"Timeout while trying to acquire lock '{self.lock_key}'")
        return False

    async def release(self) -> bool:
        """
        Releases the lock.
        It first checks if this instance is the current owner of the lock
        before deleting the key. This prevents a worker from accidentally
        releasing a lock acquired by another worker after a delay.
        :return: True if the lock was released, False otherwise.
        """
        await self._initialize()
        if not self.kv_store:
            return False
        try:
            entry = await self.kv_store.get(self.lock_key)
            if entry.value.decode() == self.lock_owner_id:
                # We are the owner, safe to delete.
                # We use `purge` instead of `delete` to completely remove the key and its history.
                await self.kv_store.purge(self.lock_key)
                logging.info(f"Lock '{self.lock_key}' released by owner {self.lock_owner_id}")
                return True
            else:
                logging.warning(f"Attempted to release lock '{self.lock_key}' but not the owner. "
                                f"Current owner: {entry.value.decode()}, this worker: {self.lock_owner_id}")
                return False
        except Exception:
            # Key might not exist if it expired or was already released.
            logging.info(f"Lock '{self.lock_key}' not found on release, assuming it's already gone.")
            return True
A common mistake is releasing without checking ownership. If a worker stalls, for example during a long GC pause, the lock it holds may expire via TTL and be acquired by another worker. If the first worker then wakes up and calls release, it could wrongly remove the second worker's lock. Storing a unique owner_id in the value and verifying it before releasing avoids this race condition.
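A quick way to see the ownership check in action is to create two NatsDistLock instances against the same key and let the second one attempt a release. This is a test sketch, not part of the production code, and it assumes the gpu_locks bucket from earlier exists and the key is free:

import asyncio
import nats

from distributed_lock import NatsDistLock


async def demo():
    nc = await nats.connect("nats://localhost:4222")
    lock_a = NatsDistLock(nc, "gpu_locks", "gpu-0-lock")
    lock_b = NatsDistLock(nc, "gpu_locks", "gpu-0-lock")

    assert await lock_a.acquire(timeout_seconds=5)      # A takes the lock
    assert not await lock_b.acquire(timeout_seconds=2)  # B times out while A holds it
    assert not await lock_b.release()                   # B is refused: it is not the owner
    assert await lock_a.release()                       # A releases its own lock

    await nc.close()


asyncio.run(demo())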
Implementing the TensorFlow Worker
The worker is a long-running Python process that continuously fetches, locks, processes, and releases in a loop.
tf_worker.py:
import asyncio
import json
import logging
import time
import os
import nats
from nats.aio.client import Client as NATS
from distributed_lock import NatsDistLock
# --- Mock TensorFlow Inference ---
# In a real scenario, this would import TensorFlow and load a model.
def run_inference(data: dict) -> dict:
    """A mock function to simulate a GPU-intensive TensorFlow task."""
    logging.info(f"Starting simulated inference for request_id: {data.get('request_id')}")
    # Simulate heavy GPU work
    time.sleep(5)
    result = {"prediction": [0.98, 0.01, 0.01], "request_id": data.get("request_id")}
    logging.info(f"Finished inference for request_id: {data.get('request_id')}")
    return result
# --- End Mock ---
class TensorFlowWorker:
    def __init__(self, nats_url: str, job_stream: str, result_stream: str, lock_bucket: str, lock_key: str):
        self.nats_url = nats_url
        self.job_stream = job_stream
        self.result_stream = result_stream
        self.lock_bucket = lock_bucket
        self.lock_key = lock_key
        self.nc: NATS = None
        self.lock: NatsDistLock = None

    async def connect(self):
        """Connects to NATS and initializes the lock."""
        try:
            self.nc = await nats.connect(self.nats_url)
            self.lock = NatsDistLock(self.nc, self.lock_bucket, self.lock_key)
            logging.info(f"Worker connected to NATS at {self.nats_url}")
        except Exception as e:
            logging.critical(f"Failed to connect to NATS: {e}")
            raise

    async def start_processing(self):
        """Main loop to fetch and process jobs."""
        if not self.nc:
            raise RuntimeError("NATS connection not established. Call connect() first.")
        js = self.nc.jetstream()
        # Pull-based subscription is better for controlling workload.
        # We process one message at a time.
        sub = await js.pull_subscribe(
            subject="JOBS.*",
            durable="tf-worker-group",  # All workers share this durable name
        )
        logging.info("Worker is ready and waiting for jobs...")
        while True:
            try:
                msgs = await sub.fetch(1, timeout=60)
                msg = msgs[0]
                logging.info(f"Fetched job: {msg.subject} with data: {msg.data.decode()}")
                job_data = json.loads(msg.data.decode())
                # The critical section starts here
                logging.info("Attempting to acquire GPU lock...")
                if await self.lock.acquire(timeout_seconds=120):
                    try:
                        # Once lock is acquired, run the inference
                        result = run_inference(job_data)
                        # Publish result
                        await js.publish(f"RESULTS.{job_data.get('request_id')}", json.dumps(result).encode())
                        # Acknowledge the job message only after successful processing and result publishing
                        await msg.ack()
                        logging.info(f"Job for request_id {job_data.get('request_id')} processed and acknowledged.")
                    except Exception as e:
                        logging.error(f"Error during inference for {job_data.get('request_id')}: {e}")
                        # Do not ack the message, let it be redelivered after ack_wait time.
                        # NACKing would be an option too.
                    finally:
                        # CRITICAL: Always release the lock
                        logging.info("Releasing GPU lock...")
                        await self.lock.release()
                else:
                    logging.warning(f"Could not acquire GPU lock for job {job_data.get('request_id')}. Message will be redelivered.")
                    # We don't ACK, so the message will be available again for another worker.
            except nats.errors.TimeoutError:
                # No new messages, just continue waiting
                continue
            except Exception as e:
                logging.error(f"An unexpected error occurred in the main loop: {e}")
                await asyncio.sleep(5)  # Avoid fast-spinning on persistent errors

    async def close(self):
        if self.nc:
            await self.nc.close()
            logging.info("NATS connection closed.")


async def main():
    worker = TensorFlowWorker(
        nats_url=os.getenv("NATS_URL", "nats://localhost:4222"),
        job_stream="TENSORFLOW_JOBS",
        result_stream="INFERENCE_RESULTS",
        lock_bucket="gpu_locks",
        lock_key="gpu-0-lock"
    )
    try:
        await worker.connect()
        await worker.start_processing()
    finally:
        await worker.close()


if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Worker shutting down.")
The try...finally block is essential here. Whether inference succeeds, fails, or raises, the lock must be released; otherwise the whole system stays blocked until the lock's TTL expires.
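To make that release harder to forget, the lock can be wrapped in an async context manager so that an async with block guarantees release on every exit path. This is a small extension sketch, not part of the code above:

from contextlib import asynccontextmanager

from distributed_lock import NatsDistLock


@asynccontextmanager
async def gpu_lock(lock: NatsDistLock, timeout_seconds: int = 120):
    """Hold the advisory GPU lock for the duration of an `async with` block."""
    acquired = await lock.acquire(timeout_seconds=timeout_seconds)
    if not acquired:
        raise TimeoutError("Could not acquire GPU lock")
    try:
        yield
    finally:
        # Runs on success, on exception, and on cancellation.
        await lock.release()


# Usage inside the worker loop (replacing the explicit acquire/try/finally):
#
#     async with gpu_lock(self.lock):
#         result = run_inference(job_data)
#         await js.publish(...)
#         await msg.ack()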
The Task Dispatcher
This service can be very simple; writing it in Go gives excellent performance with a tiny resource footprint.
dispatcher/main.go:
package main

import (
    "encoding/json"
    "log"
    "net/http"
    "os"
    "time"

    "github.com/google/uuid"
    "github.com/nats-io/nats.go"
)

type PredictRequest struct {
    ImageData string `json:"image_data"` // base64 encoded image
}

type JobPayload struct {
    RequestID string `json:"request_id"`
    Data      string `json:"image_data"`
    Timestamp string `json:"timestamp"`
}

func main() {
    natsURL := os.Getenv("NATS_URL")
    if natsURL == "" {
        natsURL = nats.DefaultURL
    }
    nc, err := nats.Connect(natsURL)
    if err != nil {
        log.Fatalf("Error connecting to NATS: %v", err)
    }
    defer nc.Close()

    js, err := nc.JetStream()
    if err != nil {
        log.Fatalf("Error getting JetStream context: %v", err)
    }

    http.HandleFunc("/predict", func(w http.ResponseWriter, r *http.Request) {
        if r.Method != http.MethodPost {
            http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed)
            return
        }
        var req PredictRequest
        if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
            http.Error(w, "Invalid request body", http.StatusBadRequest)
            return
        }
        payload := JobPayload{
            RequestID: uuid.New().String(),
            Data:      req.ImageData,
            Timestamp: time.Now().UTC().Format(time.RFC3339),
        }
        payloadBytes, err := json.Marshal(payload)
        if err != nil {
            http.Error(w, "Failed to serialize job payload", http.StatusInternalServerError)
            return
        }
        // Publish to the jobs stream. The subject can be used for routing if needed.
        _, err = js.Publish("JOBS.new", payloadBytes)
        if err != nil {
            log.Printf("Error publishing to JetStream: %v", err)
            http.Error(w, "Failed to enqueue job", http.StatusInternalServerError)
            return
        }
        log.Printf("Enqueued job with RequestID: %s", payload.RequestID)
        w.WriteHeader(http.StatusAccepted)
        json.NewEncoder(w).Encode(map[string]string{"request_id": payload.RequestID})
    })

    log.Println("Dispatcher service starting on :8080")
    if err := http.ListenAndServe(":8080", nil); err != nil {
        log.Fatalf("Could not start server: %s\n", err)
    }
}
Configuring Tyk only requires defining an upstream API that points at this Dispatcher service. The benefit of this asynchronous architecture is that the client immediately gets a 202 Accepted response and a request_id instead of waiting for the slow inference to finish.
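Any downstream service (or a result-delivery component facing the client) then picks results up from the INFERENCE_RESULTS stream. A minimal consumer sketch with nats-py; the durable name result-reader is an assumption for illustration:

import asyncio
import json
import nats


async def consume_results():
    nc = await nats.connect("nats://localhost:4222")
    js = nc.jetstream()

    # Durable pull consumer over the results stream.
    sub = await js.pull_subscribe(subject="RESULTS.*", durable="result-reader")
    while True:
        try:
            msgs = await sub.fetch(1, timeout=30)
        except nats.errors.TimeoutError:
            continue  # no results yet
        for msg in msgs:
            result = json.loads(msg.data.decode())
            print(f"request {result['request_id']} -> {result['prediction']}")
            await msg.ack()


asyncio.run(consume_results())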
Limitations and Future Work
This NATS-based approach solved our GPU contention problem cleanly without adding any new infrastructure dependency, but it is not without limitations.
First, this is an advisory lock. It relies on every worker honoring the locking protocol; a misbehaving or buggy worker can ignore the lock and use the GPU directly, crashing the system. In a controlled environment this is an acceptable trade-off.
Second, for scenarios that manage several distinct GPU resources (a GPU pool, for example), the single gpu-0-lock key would have to grow into a more elaborate lock manager, possibly with a dedicated NATS subject for requesting available resources and a coordinator handing out locks.
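One lightweight way to move in that direction without a separate coordinator is to keep one KV key per GPU and let a worker try each in turn. A rough sketch; the pool size and key naming scheme are assumptions:

from typing import Optional

from nats.aio.client import Client as NATS

from distributed_lock import NatsDistLock


async def acquire_any_gpu(nc: NATS, pool_size: int = 4) -> Optional[NatsDistLock]:
    """Try gpu-0-lock .. gpu-N-lock and return the first lock successfully held."""
    for gpu_index in range(pool_size):
        lock = NatsDistLock(nc, "gpu_locks", f"gpu-{gpu_index}-lock")
        # Short timeout: if this GPU is busy, move on to the next one.
        if await lock.acquire(timeout_seconds=1):
            return lock
    return None  # every GPU in the pool is currently locked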
Finally, the lock's TTL is an important safety net, but it also means that if a job runs longer than the TTL, its lock can expire mid-run and another worker may start executing, recreating the very resource conflict we set out to avoid. The TTL must therefore be set well above the longest expected job duration, and monitoring and alerts are needed to catch jobs approaching the limit.
One future improvement is to tie lock management to the workers' health checks: a worker could periodically "heartbeat" the lock to refresh its TTL, so the lock only expires when the worker process has actually died. NATS's Watch feature can be used to observe changes to the lock's state and build more sophisticated coordination logic.
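As a sketch of that direction: in the NATS KV store, rewriting a key stores a new revision, which in practice pushes out the bucket-TTL expiry for that key, so a background task that periodically rewrites the lock value (compare-and-swap on the last revision so it cannot clobber someone else's lock) keeps it alive while the worker is healthy. The exact expiry behavior should be verified against the NATS server version in use, and the interval below is an assumption:

import asyncio
import logging

from nats.js.kv import KeyValue


async def heartbeat_lock(kv: KeyValue, lock_key: str, owner_id: str, interval_seconds: int = 20):
    """
    Periodically rewrite the lock entry so its TTL keeps being extended.
    Cancel this task once the job is done and the lock has been released.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            entry = await kv.get(lock_key)
            if entry.value.decode() != owner_id:
                logging.warning("Lost ownership of %s, stopping heartbeat", lock_key)
                return
            # Compare-and-swap on the last revision: fails if someone else touched the key.
            await kv.update(lock_key, owner_id.encode(), last=entry.revision)
        except Exception as e:
            logging.warning("Heartbeat for %s failed: %s", lock_key, e)
            return

The worker would start this as an asyncio task right after acquire() succeeds and cancel it in the same finally block that releases the lock.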