FastAPI + WebSocket 实时通信实战指南:从基础到流式对话

FastAPI + WebSocket 实时通信实战指南:从基础到流式对话

简介

WebSocket 是构建实时应用的核心协议,而 FastAPI 对 WebSocket 的原生支持让它成为 AI 对话、实时通知、协作编辑等场景的首选后端框架。本文从零开始,覆盖 FastAPI WebSocket 的基础连接、消息协议设计、流式响应、断线重连、多客户端管理、生产部署等完整环节,最后以构建一个 AI 对话引擎为例串联所有知识点。

前置要求

  • Python 3.10+
  • 已安装依赖:pip install fastapi uvicorn websockets httpx redis(httpx 用于调用 LLM 流式接口,redis 用于会话存储)
  • 了解基本的异步编程(async/await)
  • 可选:OpenAI / DeepSeek API Key(用于流式对话示例)

一、FastAPI WebSocket 基础

1.1 最简单的 WebSocket 端点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from fastapi import FastAPI, WebSocket
import uvicorn

app = FastAPI()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo server: answer every text frame with "你说了: <text>".

    There is no disconnect handling here: when the client goes away,
    ``receive_text`` raises ``WebSocketDisconnect`` and the handler ends
    with that exception.
    """
    await websocket.accept()
    while True:
        incoming = await websocket.receive_text()
        reply = f"你说了: {incoming}"
        await websocket.send_text(reply)


if __name__ == "__main__":
    # Listen on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")

启动后,用浏览器或 wscat 测试:

1
2
3
4
5
6
7
# Install wscat (a command-line WebSocket client distributed via npm)
npm install -g wscat

# Connect to the echo endpoint. The two lines after the command are a sample
# session: "> " is what you type, "< " is the server's reply.
wscat -c ws://localhost:8000/ws
> 你好
< 你说了: 你好

1.2 连接生命周期管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from fastapi import WebSocket, WebSocketDisconnect

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo text frames back as "收到: <text>" and log how the session ended.

    A normal client disconnect and any other exception are both caught and
    printed; the ``finally`` clause runs in every case, giving per-connection
    cleanup a single place to live.
    """
    await websocket.accept()
    try:
        while True:
            received = await websocket.receive_text()
            answer = f"收到: {received}"
            await websocket.send_text(answer)
    except WebSocketDisconnect:
        # Raised by receive_text() once the client has closed the socket.
        print("客户端断开连接")
    except Exception as err:
        # Any other failure ends the session; it is logged, not re-raised.
        print(f"连接异常: {err}")
    finally:
        print("清理连接资源")

1.3 接收多种数据类型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def _build_reply(data) -> dict:
    """Return the reply for one decoded client message.

    ``{"type": "ping"}`` -> ``{"type": "pong"}``;
    ``{"type": "message", "content": ...}`` -> ``{"type": "echo", "content": ...}``.
    Anything else (a non-object payload, a missing or unknown ``type``) gets an
    ``error`` reply instead of raising and tearing down the connection.
    """
    if not isinstance(data, dict):
        return {"type": "error", "content": "消息必须是 JSON 对象"}
    msg_type = data.get("type")
    if msg_type == "ping":
        return {"type": "pong"}
    if msg_type == "message":
        return {"type": "echo", "content": data.get("content", "")}
    return {"type": "error", "content": f"未知的消息类型: {msg_type}"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Answer JSON messages by their ``type`` field until the client disconnects.

    ``receive_json`` only parses a text frame with ``json.loads``; it does not
    detect message kinds - dispatch happens on the ``type`` field. A frame that
    is not valid JSON is answered with an ``error`` message rather than letting
    the decode error escape and close the connection.
    """
    await websocket.accept()
    while True:
        try:
            try:
                data = await websocket.receive_json()
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                reply = {"type": "error", "content": "消息不是合法的 JSON"}
            else:
                reply = _build_reply(data)
            await websocket.send_json(reply)
        except WebSocketDisconnect:
            break

二、消息协议设计

2.1 标准消息格式

为 WebSocket 通信定义统一的消息协议,这是构建复杂应用的基础:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from pydantic import BaseModel
from typing import Optional
from enum import Enum
import json

class MessageType(str, Enum):
    """Closed set of ``type`` values carried by every WebSocket JSON message.

    Mixing in ``str`` makes each member compare equal to its raw string value,
    so ``MessageType.PING == "ping"`` holds.
    """

    # --- System messages ---
    PING = "ping"
    PONG = "pong"
    ERROR = "error"

    # --- Conversation messages ---
    USER_MESSAGE = "user_message"
    AI_MESSAGE = "ai_message"
    AI_STREAM_CHUNK = "ai_stream_chunk"
    AI_STREAM_END = "ai_stream_end"

    # --- Control messages ---
    CONVERSATION_START = "conversation_start"
    CONVERSATION_END = "conversation_end"
    TYPING = "typing"

class WSMessage(BaseModel):
    """Envelope for every WebSocket message: a required ``type`` plus optional fields."""

    type: MessageType
    content: Optional[str] = None
    conversation_id: Optional[str] = None
    metadata: Optional[dict] = None

    def to_json(self) -> str:
        """Serialize to JSON, omitting ``None`` fields to keep frames small."""
        return self.model_dump_json(exclude_none=True)

    @classmethod
    def from_json(cls, data: str) -> "WSMessage":
        """Parse and validate one raw JSON frame.

        Raises:
            json.JSONDecodeError: ``data`` is not valid JSON.
            pydantic.ValidationError: the JSON is not an object, or a field is
                invalid. ``model_validate`` is used instead of ``cls(**payload)``
                so a non-object payload such as ``[1, 2]`` yields a validation
                error rather than a ``TypeError`` from ``**`` unpacking.
        """
        return cls.model_validate(json.loads(data))

2.2 心跳保活机制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import asyncio
import time


class ConnectionManager:
    """Track active connections and run one heartbeat task per client.

    The heartbeat task only *sends* pings. It never calls ``receive_*`` itself:
    the endpoint's receive loop already reads from the socket, and a second
    concurrent reader would steal that loop's messages. Instead the endpoint
    reports incoming frames (at least every ``pong``) through :meth:`mark_alive`,
    and the heartbeat checks that timestamp.
    """

    def __init__(self):
        self.active_connections: dict[str, WebSocket] = {}
        self.heartbeat_tasks: dict[str, asyncio.Task] = {}
        # client_id -> time.monotonic() of the last frame reported via mark_alive
        self.last_seen: dict[str, float] = {}

    async def connect(self, client_id: str, websocket: WebSocket):
        """Accept the handshake, register the socket and start its heartbeat."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.last_seen[client_id] = time.monotonic()
        self.heartbeat_tasks[client_id] = asyncio.create_task(
            self._heartbeat_check(client_id, websocket)
        )

    def mark_alive(self, client_id: str):
        """Record that a frame (e.g. ``{"type": "pong"}``) just arrived from ``client_id``.

        Call this from the endpoint's receive loop.
        """
        self.last_seen[client_id] = time.monotonic()

    async def disconnect(self, client_id: str):
        """Forget ``client_id`` and stop its heartbeat task (safe to call twice)."""
        task = self.heartbeat_tasks.pop(client_id, None)
        # The heartbeat task itself calls disconnect() right before returning;
        # cancelling the current task there would be pointless.
        if task is not None and task is not asyncio.current_task():
            task.cancel()
        self.active_connections.pop(client_id, None)
        self.last_seen.pop(client_id, None)

    @staticmethod
    async def _close_quietly(websocket: WebSocket):
        """Close ``websocket``, ignoring errors if it is already closed."""
        try:
            await websocket.close()
        except Exception:
            pass  # best effort: the peer may already be gone

    async def _heartbeat_check(
        self,
        client_id: str,
        websocket: WebSocket,
        interval: int = 30,
        timeout: int = 10,
    ):
        """Send a ping every ``interval`` s; drop the client if nothing arrives within ``timeout`` s."""
        try:
            while True:
                await asyncio.sleep(interval)
                ping_sent_at = time.monotonic()
                try:
                    await asyncio.wait_for(
                        websocket.send_json({"type": "ping"}),
                        timeout=5
                    )
                except Exception as exc:
                    # Sending failed or timed out: the socket is already broken.
                    print(f"客户端 {client_id} 心跳发送失败: {exc}")
                    await self.disconnect(client_id)
                    break

                await asyncio.sleep(timeout)
                if self.last_seen.get(client_id, 0.0) < ping_sent_at:
                    print(f"客户端 {client_id} 心跳超时,断开连接")
                    await self._close_quietly(websocket)
                    await self.disconnect(client_id)
                    break
        except asyncio.CancelledError:
            # Cancelled by disconnect(): the connection ended normally.
            pass

三、流式响应:AI 对话的核心

3.1 模拟流式输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import asyncio
import random

async def simulate_stream_response(
    websocket: "WebSocket",
    text: str,
    delay_range: tuple[float, float] = (0.05, 0.2),
):
    """Fake an LLM stream by sending ``text`` one space-separated word at a time.

    Each word is sent as an ``ai_stream_chunk`` (with a trailing space and its
    position in ``metadata.index``), followed by a single ``ai_stream_end``
    carrying ``metadata.total_words``. Text without spaces (typical for
    Chinese) is therefore sent as one chunk.

    Args:
        websocket: connected socket to send on.
        text: the full reply to stream.
        delay_range: ``(min, max)`` seconds of random delay after each chunk,
            simulating generation latency; ``(0, 0)`` disables it.
    """
    words = text.split(" ")
    low, high = delay_range
    # enumerate() yields each word's real position; words.index(word) returned
    # the first occurrence for repeated words and rescanned the list every time.
    for index, word in enumerate(words):
        await websocket.send_json({
            "type": "ai_stream_chunk",
            "content": word + " ",
            "metadata": {"index": index}
        })
        await asyncio.sleep(random.uniform(low, high))

    await websocket.send_json({
        "type": "ai_stream_end",
        "content": "",
        "metadata": {"total_words": len(words)}
    })

@app.websocket("/chat")
async def chat_endpoint(websocket: WebSocket):
    """Demo chat: answer each ``user_message`` with a typing hint and a fake stream.

    Frames that are not JSON objects, or whose ``type`` is missing or not
    ``user_message``, are ignored instead of raising ``KeyError`` and closing
    the connection with an error.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict) or data.get("type") != "user_message":
                continue

            # Show a "typing" indicator while the fake model "thinks".
            await websocket.send_json({"type": "typing"})
            await asyncio.sleep(0.5)

            # Stream the reply back word by word.
            response = f"你说的是: {data.get('content', '')}。让我想想..."
            await simulate_stream_response(websocket, response)
    except WebSocketDisconnect:
        pass

3.2 集成真实 LLM 流式 API

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import httpx
import json

async def stream_llm_response(
    websocket: WebSocket,
    messages: list[dict],
    api_key: str,
    model: str = "deepseek-chat"
):
    """Relay a DeepSeek/OpenAI-compatible streaming chat completion to the client.

    Each ``data: {...}`` line of the upstream stream is a JSON chunk whose
    ``choices[0].delta.content`` holds the next piece of text, and
    ``data: [DONE]`` marks the end. Every non-empty piece is forwarded as an
    ``ai_stream_chunk`` message.

    Exactly one ``ai_stream_end`` is always sent last - also when the upstream
    request fails (reported first as an ``error`` message) or the stream closes
    without ``[DONE]`` - so clients waiting for the end marker never hang.
    """
    async with httpx.AsyncClient(timeout=60) as client:
        async with client.stream(
            "POST",
            "https://api.deepseek.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            },
            json={
                "model": model,
                "messages": messages,
                "stream": True
            }
        ) as response:
            if not response.is_success:
                # Don't try to parse a failed response as a completion stream.
                await websocket.send_json({
                    "type": "error",
                    "content": f"LLM 接口请求失败: HTTP {response.status_code}"
                })
            else:
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue  # blank separator lines and other non-data lines
                    data_str = line[6:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        content = chunk["choices"][0]["delta"].get("content")
                    except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
                        # Skip malformed chunks and chunks without a usable delta
                        # (e.g. an empty ``choices`` list).
                        continue
                    if content:
                        await websocket.send_json({
                            "type": "ai_stream_chunk",
                            "content": content
                        })

    await websocket.send_json({"type": "ai_stream_end"})

@app.websocket("/ai-chat")
async def ai_chat_endpoint(websocket: WebSocket):
    """Multi-turn AI chat that keeps the conversation in ``messages``.

    Frames that are not JSON objects, or whose ``type`` is not
    ``user_message``, are ignored rather than raising ``KeyError`` and dropping
    the connection.
    """
    await websocket.accept()
    messages = [{"role": "system", "content": "你是一个友好的 AI 助手。"}]

    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict) or data.get("type") != "user_message":
                continue

            messages.append({
                "role": "user",
                "content": data.get("content", "")
            })

            # Typing indicator until the first chunk arrives.
            await websocket.send_json({"type": "typing"})

            # Stream the AI reply chunk by chunk.
            await stream_llm_response(
                websocket, messages,
                api_key="your-api-key"  # placeholder - load from config/env in real code
            )

            # NOTE: the assistant reply is not appended to `messages`, so the
            # model never sees its own earlier answers. Collect the chunks into
            # the full reply and append {"role": "assistant", ...} to fix that.
    except WebSocketDisconnect:
        pass

3.3 收集流式 chunk 拼接完整消息

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
async def stream_and_collect(
    websocket: WebSocket,
    messages: list[dict],
    api_key: str
) -> str:
    """Stream the LLM reply to the client and return the complete reply text.

    Every non-empty delta is forwarded as an ``ai_stream_chunk`` and collected.
    Exactly one ``ai_stream_end`` is always sent last - also after an upstream
    HTTP error (reported first as an ``error`` message) or a stream that ends
    without ``[DONE]``.

    Returns:
        The concatenated reply; ``""`` if the request failed or produced no
        text (callers should not store an empty assistant message).
    """
    parts: list[str] = []

    async with httpx.AsyncClient(timeout=60) as client:
        async with client.stream(
            "POST",
            "https://api.deepseek.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "model": "deepseek-chat",
                "messages": messages,
                "stream": True
            }
        ) as response:
            if not response.is_success:
                await websocket.send_json({
                    "type": "error",
                    "content": f"LLM 接口请求失败: HTTP {response.status_code}"
                })
            else:
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data_str = line[6:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        content = chunk["choices"][0]["delta"].get("content")
                    except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
                        # Skip malformed chunks and chunks without a usable delta.
                        continue
                    if content:
                        parts.append(content)
                        await websocket.send_json({
                            "type": "ai_stream_chunk",
                            "content": content
                        })

    await websocket.send_json({"type": "ai_stream_end"})
    # join() once instead of repeated string += (quadratic on long replies).
    return "".join(parts)

# 在对话循环中使用
# full_reply = await stream_and_collect(websocket, messages, api_key)
# messages.append({"role": "assistant", "content": full_reply})

四、多客户端管理与房间系统

4.1 连接管理器(完整版)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Optional
import uuid

class Room:
    """A named group of connected clients that can receive broadcasts."""

    def __init__(self, room_id: str):
        self.room_id = room_id
        # client_id -> that client's WebSocket
        self.clients: dict[str, WebSocket] = {}

    async def broadcast(self, message: dict, exclude: Optional[str] = None):
        """Send ``message`` to every client in the room except ``exclude``.

        Iterates over a snapshot of ``clients``: each ``send_json`` is an
        ``await`` during which other coroutines may join or leave the room, and
        iterating the live dict would then raise ``RuntimeError: dictionary
        changed size during iteration``. Clients that left mid-broadcast are
        skipped; clients whose send fails are removed so later broadcasts stop
        retrying a dead socket.
        """
        dead: list[str] = []
        for client_id, ws in list(self.clients.items()):
            if client_id == exclude or client_id not in self.clients:
                continue
            try:
                await ws.send_json(message)
            except Exception:
                # Best effort: one broken connection must not stop the broadcast.
                dead.append(client_id)
        for client_id in dead:
            self.clients.pop(client_id, None)

class ChatManager:
    """Assign clients to rooms and keep both directions of that mapping in sync."""

    def __init__(self):
        # room_id -> Room
        self.rooms: dict[str, Room] = {}
        # client_id -> room_id (reverse index used by leave_room)
        self.client_rooms: dict[str, str] = {}

    async def join_room(self, room_id: str, client_id: str, websocket: WebSocket):
        """Create the room on first use, accept the socket, register it and announce the join."""
        self.rooms.setdefault(room_id, Room(room_id))

        await websocket.accept()
        # Re-read the room after the await: it is looked up by id, not held.
        joined = self.rooms[room_id]
        joined.clients[client_id] = websocket
        self.client_rooms[client_id] = room_id

        # Tell everyone else in the room.
        notice = {"type": "user_joined", "content": f"用户 {client_id} 加入了房间"}
        await joined.broadcast(notice, exclude=client_id)

    async def leave_room(self, client_id: str):
        """Remove the client from its room, announce it, and drop the room once empty."""
        room_id = self.client_rooms.get(client_id)
        current = self.rooms.get(room_id) if room_id else None

        if current is not None:
            current.clients.pop(client_id, None)
            notice = {"type": "user_left", "content": f"用户 {client_id} 离开了房间"}
            await current.broadcast(notice)

            # An empty room is discarded.
            if not current.clients:
                del self.rooms[room_id]

        self.client_rooms.pop(client_id, None)

chat_manager = ChatManager()


@app.websocket("/room/{room_id}")
async def room_endpoint(websocket: WebSocket, room_id: str):
    """Join ``room_id`` under a random 8-character client id and relay chat messages.

    ``leave_room`` runs in ``finally`` so the client is removed (and an empty
    room cleaned up) however the loop ends: a normal disconnect or any other
    error, such as a frame that is not valid JSON. Without it, such an error
    would leave a dead socket registered in the room.
    """
    client_id = str(uuid.uuid4())[:8]
    await chat_manager.join_room(room_id, client_id, websocket)

    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict) or data.get("type") != "message":
                continue
            room = chat_manager.rooms.get(room_id)
            if room:
                await room.broadcast({
                    "type": "message",
                    "client_id": client_id,
                    "content": data.get("content", "")
                })
    except WebSocketDisconnect:
        pass
    finally:
        await chat_manager.leave_room(client_id)

4.2 连接数限制与限流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
import time

class RateLimiter:
    """Global cap on open connections plus a per-client message rate limit.

    * Connection cap: the endpoint calls :meth:`register` right after a
      successful :meth:`can_connect` and :meth:`unregister` when the socket
      ends, so the cap counts connections that are actually open.
    * Rate limit: at most ``rate_per_second`` messages per client in any
      sliding 1-second window.
    """

    def __init__(self, max_connections: int = 100, rate_per_second: int = 10):
        self.max_connections = max_connections
        self.rate_per_second = rate_per_second
        # client_id -> monotonic timestamps of messages accepted in the last second
        self.connections: dict[str, list[float]] = {}
        # client_id -> number of open connections using that id
        self.active: dict[str, int] = {}
        self.total_active = 0

    def can_connect(self, client_id: str) -> bool:
        """True while fewer than ``max_connections`` connections are open."""
        return self.total_active < self.max_connections

    def register(self, client_id: str) -> None:
        """Count one newly opened connection for ``client_id``."""
        self.active[client_id] = self.active.get(client_id, 0) + 1
        self.total_active += 1

    def unregister(self, client_id: str) -> None:
        """Release one connection; drop all state for ``client_id`` when it was the last."""
        remaining = self.active.get(client_id, 0) - 1
        if remaining > 0:
            self.active[client_id] = remaining
        else:
            # Without this cleanup both dicts grow by one entry per distinct id forever.
            self.active.pop(client_id, None)
            self.connections.pop(client_id, None)
        self.total_active = max(0, self.total_active - 1)

    def check_rate_limit(self, client_id: str) -> bool:
        """Record one message; False if the client already hit its per-second quota."""
        now = time.monotonic()  # immune to wall-clock jumps, unlike time.time()
        recent = [t for t in self.connections.get(client_id, []) if now - t < 1]
        allowed = len(recent) < self.rate_per_second
        if allowed:
            recent.append(now)
        self.connections[client_id] = recent
        return allowed


rate_limiter = RateLimiter()


@app.websocket("/chat")
async def chat_with_limit(websocket: WebSocket, client_id: str = "anonymous"):
    """Echo endpoint guarded by ``rate_limiter``.

    ``client_id`` is a query parameter (``/chat?client_id=...``); clients that
    omit it all share the ``"anonymous"`` rate-limit bucket.
    """
    if not rate_limiter.can_connect(client_id):
        # Closing before accept() rejects the handshake.
        await websocket.close(code=1008, reason="连接数已达上限")
        return

    # Reserve the slot before the first await so concurrent handshakes cannot
    # all pass can_connect() at once; `finally` releases it however we exit.
    rate_limiter.register(client_id)
    try:
        await websocket.accept()
        while True:
            data = await websocket.receive_json()

            if not rate_limiter.check_rate_limit(client_id):
                await websocket.send_json({
                    "type": "error",
                    "content": "请求过于频繁,请稍后再试"
                })
                continue

            # Handle the message...
            await websocket.send_json({
                "type": "echo",
                "content": data.get("content", "")
            })
    except WebSocketDisconnect:
        pass
    finally:
        rate_limiter.unregister(client_id)

五、断线重连与会话恢复

5.1 服务端会话存储

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import redis.asyncio as aioredis
import json

class SessionStore:
    """Persist each conversation's message list in Redis as one JSON string.

    Key layout: ``session:<session_id>`` -> JSON array of message dicts.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis_url = redis_url
        self.redis = None  # created by init() at application startup

    async def init(self):
        """Open the Redis client; must be awaited before any other method."""
        self.redis = await aioredis.from_url(self.redis_url)

    @staticmethod
    def _key(session_id: str) -> str:
        return f"session:{session_id}"

    async def save_session(self, session_id: str, messages: list[dict], ttl: int = 3600):
        """Overwrite the stored history; the key expires ``ttl`` seconds (default 1 h) later."""
        payload = json.dumps(messages)
        await self.redis.setex(self._key(session_id), ttl, payload)

    async def load_session(self, session_id: str) -> list[dict]:
        """Return the stored history, or ``[]`` when the key is missing or expired."""
        raw = await self.redis.get(self._key(session_id))
        if not raw:
            return []
        return json.loads(raw)

    async def append_message(self, session_id: str, message: dict):
        """Append one message and re-save it, which resets the TTL to the default.

        This is a read-modify-write without a transaction, so two concurrent
        appends to the same session can lose one of the messages.
        """
        history = await self.load_session(session_id)
        history.append(message)
        await self.save_session(session_id, history)

session_store = SessionStore()


@app.on_event("startup")
async def startup():
    """Connect the session store to Redis before serving requests."""
    await session_store.init()


@app.websocket("/chat/{session_id}")
async def chat_with_reconnect(websocket: WebSocket, session_id: str):
    """Chat endpoint whose history survives reconnects with the same ``session_id``.

    On connect, any stored history is replayed as one ``history`` message.
    Both the user message and the reply are appended to the store, so a
    reconnecting client gets both sides of the conversation back, not just
    its own messages.
    """
    await websocket.accept()

    # Replay earlier history, if any.
    history = await session_store.load_session(session_id)
    if history:
        await websocket.send_json({
            "type": "history",
            "content": history
        })

    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict) or data.get("type") != "user_message":
                continue

            user_text = data.get("content", "")
            await session_store.append_message(session_id, {
                "role": "user",
                "content": user_text
            })

            # Produce the reply (a real app would call the LLM here).
            reply = f"收到: {user_text}"
            # Persist before sending so the stored history stays complete even
            # if the client disconnects mid-send.
            await session_store.append_message(session_id, {
                "role": "assistant",
                "content": reply
            })
            await websocket.send_json({
                "type": "ai_message",
                "content": reply
            })
    except WebSocketDisconnect:
        print(f"会话 {session_id} 断开,上下文已保存")

5.2 客户端重连逻辑(JavaScript 示例)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class ChatClient {
constructor(sessionId) {
this.sessionId = sessionId;
this.ws = null;
this.reconnectAttempts = 0;
this.maxReconnectAttempts = 10;
this.reconnectDelay = 1000; // 初始 1 秒
}

connect() {
const url = `ws://localhost:8000/chat/${this.sessionId}`;
this.ws = new WebSocket(url);

this.ws.onopen = () => {
console.log('连接成功');
this.reconnectAttempts = 0;
this.reconnectDelay = 1000;
};

this.ws.onmessage = (event) => {
const msg = JSON.parse(event.data);
if (msg.type === 'history') {
console.log('恢复历史会话:', msg.content);
}
};

this.ws.onclose = (event) => {
if (!event.wasClean) {
this.reconnect();
}
};

this.ws.onerror = (error) => {
console.error('WebSocket 错误:', error);
};
}

reconnect() {
if (this.reconnectAttempts >= this.maxReconnectAttempts) {
console.error('重连次数已达上限');
return;
}

const delay = this.reconnectDelay * Math.pow(2, this.reconnectAttempts);
console.log(`将在 ${delay}ms 后重连...`);

setTimeout(() => {
this.reconnectAttempts++;
this.connect();
}, delay);

// 指数退避,最大 30 秒
this.reconnectDelay = Math.min(this.reconnectDelay * 2, 30000);
}

send(message) {
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify(message));
}
}
}

六、生产部署

6.1 使用 Gunicorn + Uvicorn Worker

1
2
3
4
5
6
7
8
# Install the process manager (gunicorn) and the ASGI worker class (uvicorn)
pip install gunicorn uvicorn

# Run 4 worker processes behind one port. Each WebSocket connection stays on
# the worker process that accepted it, but in-memory state (rooms, connection
# managers, rate limiters defined above) lives per process and is NOT shared:
# use sticky sessions or a shared store / pub-sub (e.g. Redis) across workers.
gunicorn -k uvicorn.workers.UvicornWorker \
--workers 4 \
--bind 0.0.0.0:8000 \
--timeout 120 \
main:app

6.2 Nginx 反向代理配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# $connection_upgrade is "upgrade" for WebSocket handshakes and "close"
# otherwise. (map belongs in the http{} context, like upstream and server.)
map $http_upgrade $connection_upgrade {
    default upgrade;
    ''      close;
}

upstream fastapi_ws {
    # ip_hash pins each client IP to one backend so per-process in-memory
    # state (rooms, connection managers) stays consistent for that client.
    # This assumes four independent app instances on ports 8001-8004, not the
    # single gunicorn process group on :8000 from section 6.1.
    ip_hash;
    server 127.0.0.1:8001;
    server 127.0.0.1:8002;
    server 127.0.0.1:8003;
    server 127.0.0.1:8004;
}

server {
    listen 443 ssl;
    server_name api.example.com;

    ssl_certificate /etc/nginx/ssl/cert.pem;
    ssl_certificate_key /etc/nginx/ssl/key.pem;

    # WebSocket routes used in this article: /ws, /chat/..., /ai-chat/...,
    # /room/... (a plain "location /ws/" matches none of them).
    location ~ ^/(ws|chat|ai-chat|room)(/|$) {
        proxy_pass http://fastapi_ws;
        proxy_http_version 1.1;

        # Required for the WebSocket upgrade handshake
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection $connection_upgrade;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;

        # Long-lived connections: the defaults (60 s) would cut idle sockets
        proxy_read_timeout 86400s;
        proxy_send_timeout 86400s;
    }

    location /api/ {
        proxy_pass http://fastapi_ws;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
    }
}

6.3 Docker Compose 部署

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
version: '3.8'

services:
  app:
    build: .
    command: gunicorn -k uvicorn.workers.UvicornWorker -w 4 -b 0.0.0.0:8000 main:app --timeout 120
    ports:
      - "8000:8000"
    environment:
      # Inside the compose network Redis is reachable as "redis", not
      # localhost - the app must read REDIS_URL rather than hard-code it.
      - REDIS_URL=redis://redis:6379
      - DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY}
    depends_on:
      - redis
    restart: unless-stopped

  redis:
    image: redis:7-alpine
    # No password is configured, so publish Redis on the host's loopback
    # interface only; the app reaches it over the compose network anyway.
    ports:
      - "127.0.0.1:6379:6379"
    volumes:
      - redis_data:/data
    restart: unless-stopped

volumes:
  redis_data:

七、完整示例:AI 对话引擎

将以上所有知识点整合为一个完整的 AI 对话 WebSocket 服务:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import asyncio
import json
import os
import uuid
from typing import Optional

import httpx
import redis.asyncio as aioredis
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

# ========== Configuration ==========
# The API key comes from the environment (docker-compose passes
# DEEPSEEK_API_KEY through); the literal is only a local-dev placeholder.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "your-api-key")
MODEL = "deepseek-chat"
MAX_CONNECTIONS = 50
RATE_LIMIT = 10  # max messages per client per second

# ========== 连接管理 ==========
class ConnectionManager:
    """Cap concurrent connections at MAX_CONNECTIONS and rate-limit each client."""

    def __init__(self):
        # client_id -> WebSocket of every open connection
        self.connections: dict[str, WebSocket] = {}
        # client_id -> monotonic timestamps of messages accepted in the last second
        self.rate_limits: dict[str, list[float]] = {}

    async def connect(self, client_id: str, websocket: WebSocket) -> bool:
        """Accept ``websocket`` unless the server is full.

        Returns:
            True if the connection was accepted and registered; False if the
            handshake was rejected (close code 1008) because MAX_CONNECTIONS
            connections are already open.
        """
        if len(self.connections) >= MAX_CONNECTIONS:
            await websocket.close(code=1008, reason="服务器繁忙")
            return False
        # Register before awaiting accept(); otherwise concurrent handshakes
        # could all pass the size check before any of them is counted.
        self.connections[client_id] = websocket
        try:
            await websocket.accept()
        except BaseException:
            self.connections.pop(client_id, None)
            raise
        return True

    def disconnect(self, client_id: str):
        """Forget the client's socket and rate-limit history (safe to call twice)."""
        self.connections.pop(client_id, None)
        self.rate_limits.pop(client_id, None)

    def check_rate_limit(self, client_id: str) -> bool:
        """Record one message; False once the client exceeds RATE_LIMIT per second."""
        import time
        now = time.monotonic()  # immune to wall-clock jumps, unlike time.time()
        timestamps = [t for t in self.rate_limits.get(client_id, []) if now - t < 1]
        if len(timestamps) >= RATE_LIMIT:
            return False
        timestamps.append(now)
        self.rate_limits[client_id] = timestamps
        return True


manager = ConnectionManager()

# ========== 会话存储 ==========
class SessionStore:
    """Redis-backed chat history stored under ``session:<session_id>``.

    Each session is one JSON array holding at most the last 50 messages; the
    key expires two hours after its most recent write.
    """

    def __init__(self, redis_url: Optional[str] = None):
        """``redis_url`` defaults to $REDIS_URL, falling back to a local Redis.

        Reading the environment makes the docker-compose setup (which sets
        REDIS_URL=redis://redis:6379) work; a hard-coded localhost URL cannot
        reach the redis container from the app container.
        """
        import os
        self.redis_url = redis_url or os.environ.get("REDIS_URL", "redis://localhost:6379")
        self.redis = None  # created by init()

    async def init(self):
        """Create the Redis client; await once at startup before other calls."""
        self.redis = await aioredis.from_url(self.redis_url)

    async def get_messages(self, session_id: str) -> list[dict]:
        """Return the stored history, or ``[]`` if none."""
        data = await self.redis.get(f"session:{session_id}")
        return json.loads(data) if data else []

    async def add_message(self, session_id: str, message: dict):
        """Append ``message``, keep only the newest 50 and reset the 2 h expiry."""
        messages = await self.get_messages(session_id)
        messages.append(message)
        messages = messages[-50:]
        await self.redis.setex(f"session:{session_id}", 7200, json.dumps(messages))


store = SessionStore()

# ========== LLM 流式调用 ==========
async def stream_llm(websocket: WebSocket, messages: list[dict], session_id: str = "") -> str:
    """Stream a chat completion to the client and return the full reply text.

    Forwards every non-empty delta as ``ai_stream_chunk`` and always finishes
    with exactly one ``ai_stream_end`` carrying ``session_id``. This also
    happens after an upstream HTTP error (reported first as an ``error``
    message) or a stream that ends without ``[DONE]``, so clients waiting for
    the end marker never hang.

    Args:
        websocket: client connection to forward chunks to.
        messages: chat history, system prompt first.
        session_id: echoed back in ``ai_stream_end``. It is passed explicitly
            because ``messages[0]`` is the system prompt and has no such key.

    Returns:
        The concatenated reply; ``""`` if the request failed or produced no text.
    """
    parts: list[str] = []

    async with httpx.AsyncClient(timeout=120) as client:
        async with client.stream(
            "POST",
            "https://api.deepseek.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
            json={"model": MODEL, "messages": messages, "stream": True}
        ) as response:
            if not response.is_success:
                await websocket.send_json({
                    "type": "error",
                    "content": f"LLM 接口请求失败: HTTP {response.status_code}"
                })
            else:
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    data_str = line[6:].strip()
                    if data_str == "[DONE]":
                        break

                    try:
                        chunk = json.loads(data_str)
                        content = chunk["choices"][0]["delta"].get("content")
                    except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
                        # Skip malformed chunks and chunks without a usable delta.
                        continue
                    if content:
                        parts.append(content)
                        await websocket.send_json({
                            "type": "ai_stream_chunk",
                            "content": content
                        })

    await websocket.send_json({
        "type": "ai_stream_end",
        "session_id": session_id
    })
    return "".join(parts)


# ========== WebSocket endpoint ==========
# Upper bound on history messages sent to the model per request (the session
# store also keeps only the newest 50).
MAX_CONTEXT_MESSAGES = 50


@app.websocket("/ai-chat/{session_id}")
async def ai_chat(websocket: WebSocket, session_id: str):
    """AI chat endpoint: restores history, rate-limits, and streams replies.

    Client -> server: ``{"type": "user_message", "content": ...}`` or ``{"type": "ping"}``.
    Server -> client: ``history``, ``typing``, ``ai_stream_chunk``, ``ai_stream_end``,
    ``pong``, ``error``.
    """
    client_id = str(uuid.uuid4())[:8]

    if not await manager.connect(client_id, websocket):
        return

    try:
        # Restore earlier history for this session.
        history = await store.get_messages(session_id)
        if history:
            await websocket.send_json({
                "type": "history",
                "messages": history
            })

        system_prompt = {
            "role": "system",
            "content": "你是一个友好的 AI 助手。请用中文回复,保持对话自然流畅。"
        }
        messages = [system_prompt] + history

        while True:
            data = await websocket.receive_json()

            if not manager.check_rate_limit(client_id):
                await websocket.send_json({
                    "type": "error",
                    "content": "消息发送太频繁,请稍后再试"
                })
                continue

            if not isinstance(data, dict):
                continue
            msg_type = data.get("type")

            if msg_type == "user_message":
                user_msg = {"role": "user", "content": data.get("content", "")}
                messages.append(user_msg)
                await store.add_message(session_id, user_msg)

                # Typing indicator until the first chunk arrives.
                await websocket.send_json({"type": "typing"})

                # Stream the AI reply.
                full_reply = await stream_llm(websocket, messages, session_id=session_id)

                # Keep the reply as context; skip empty replies from failed calls.
                if full_reply:
                    ai_msg = {"role": "assistant", "content": full_reply}
                    messages.append(ai_msg)
                    await store.add_message(session_id, ai_msg)

                # Bound the prompt: system prompt + most recent history only.
                messages = [system_prompt] + messages[1:][-MAX_CONTEXT_MESSAGES:]

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        print(f"客户端 {client_id} 断开,会话 {session_id} 已保存")
    finally:
        # Release the connection slot however the session ended.
        manager.disconnect(client_id)


@app.on_event("startup")
async def startup():
    """Open the Redis connection before the first request is served."""
    await store.init()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

八、测试与调试

8.1 使用 Python 测试客户端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import asyncio
import websockets
import json

async def test_chat():
    """Send one message to the AI chat endpoint and print the streamed reply.

    Stops at ``ai_stream_end``, or at an ``error`` message (e.g. rate limiting),
    so the script never waits forever for an end marker that will not come.

    Returns:
        The reply text received before the stream ended.
    """
    async with websockets.connect("ws://localhost:8000/ai-chat/test-session") as ws:
        # Send one user message.
        await ws.send(json.dumps({
            "type": "user_message",
            "content": "你好,请介绍一下你自己"
        }))

        # Receive the streamed reply (other types such as "history" are ignored).
        full_response = ""
        while True:
            msg = json.loads(await ws.recv())
            msg_type = msg.get("type")
            if msg_type == "ai_stream_chunk":
                full_response += msg["content"]
                print(msg["content"], end="", flush=True)
            elif msg_type == "ai_stream_end":
                print("\n--- 回复完成 ---")
                break
            elif msg_type == "typing":
                print("(AI 正在思考...)", flush=True)
            elif msg_type == "error":
                print(f"\n服务端错误: {msg.get('content')}")
                break
        return full_response


if __name__ == "__main__":
    asyncio.run(test_chat())

8.2 压力测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Install the asyncio WebSocket client the script imports. The package is
# "websockets"; "websocket-client" is a different (synchronous) library.
pip install websockets

# Open 10 concurrent chat sessions against the server
python -c "
import asyncio
import websockets
import json

async def stress_test():
    tasks = [single_client(i) for i in range(10)]
    await asyncio.gather(*tasks)

async def single_client(client_id):
    try:
        # open_timeout bounds the handshake; 'timeout' is not a connect() option
        async with websockets.connect(f'ws://localhost:8000/ai-chat/test-{client_id}', open_timeout=5) as ws:
            await ws.send(json.dumps({'type': 'user_message', 'content': '你好'}))
            async for msg in ws:
                data = json.loads(msg)
                if data['type'] == 'error':
                    raise RuntimeError(data.get('content'))
                if data['type'] == 'ai_stream_end':
                    break
        print(f'客户端 {client_id} 完成')
    except Exception as e:
        print(f'客户端 {client_id} 失败: {e}')

asyncio.run(stress_test())
"

九、常见问题

Q:WebSocket 连接频繁断开怎么办?

检查以下几点:

  1. Nginx 代理超时设置:将 proxy_read_timeout 和 proxy_send_timeout 设置得足够大(建议 86400s),或确保心跳间隔小于这两个超时值(默认均为 60s)
  2. 客户端实现心跳机制,每 30 秒发送 ping
  3. 检查防火墙是否拦截了长连接

Q:多 worker 下 WebSocket 连接不稳定?

WebSocket 是有状态连接:单个连接建立后始终由同一个 worker 处理,但房间、连接管理器、限流计数等内存状态只存在于各自的 worker 进程中。多 worker(或多实例)部署时,需要让相关客户端落到同一 worker,或把状态移到共享存储:

  • 在 Nginx 中使用 ip_hash 或 sticky 指令(sticky 需要 NGINX Plus 或第三方模块)
  • 或使用 Redis Pub/Sub 跨 worker 广播消息

Q:流式响应中如何控制并发?

1
2
3
4
5
6
7
8
import asyncio

# 使用 asyncio.Semaphore 控制并发 LLM 调用
llm_semaphore = asyncio.Semaphore(5)  # at most 5 LLM streams in flight at once


async def safe_stream_llm(websocket, messages):
    """Call ``stream_llm`` while holding one of ``llm_semaphore``'s 5 slots.

    Extra callers wait, with no notice sent to their client, until a running
    stream finishes. The slot is released even if ``stream_llm`` raises.
    """
    await llm_semaphore.acquire()
    try:
        return await stream_llm(websocket, messages)
    finally:
        llm_semaphore.release()

Q:如何监控 WebSocket 连接状态?

1
2
3
4
5
6
7
8
from fastapi import Request
from prometheus_client import Counter, Gauge

# Number of currently open WebSocket connections (moves up and down).
ws_connections = Gauge(
    name="ws_active_connections",
    documentation="当前 WebSocket 连接数",
)
# Total messages handled, partitioned by the "type" label.
ws_messages = Counter(
    name="ws_messages_total",
    documentation="消息总数",
    labelnames=["type"],
)

# Update the gauge when a connection opens / closes:
# ws_connections.inc() / ws_connections.dec()

Q:WebSocket 和 SSE 怎么选?

特性 WebSocket SSE (Server-Sent Events)
通信方向 双向 仅服务端→客户端
协议 ws:// HTTP
浏览器支持 全支持 全支持(除 IE)
自动重连 需手动实现 内置
适用场景 对话、游戏、协作 通知、进度推送

对于 AI 对话,推荐 WebSocket,因为需要用户发送消息和接收流式回复的双向通信。


FastAPI + WebSocket 是构建 AI 实时对话应用的黄金组合。掌握了本文的内容,你就能搭建一个生产级的 AI 对话后端,支持流式响应、断线重连、多客户端管理和限流保护。