" name="sm-site-verification"/>
侧边栏壁纸
博主头像
PySuper 博主等级

千里之行,始于足下

  • 累计撰写 247 篇文章
  • 累计创建 15 个标签
  • 累计收到 2 条评论

目 录CONTENT

文章目录
Web

Tornado 实现本地代理(HTTP + WebSocket)

PySuper
2025-05-01 / 0 评论 / 0 点赞 / 15 阅读 / 0 字
温馨提示:
所有牛逼的人都有一段苦逼的岁月。 但是你只要像SB一样去坚持,终将牛逼!!! ✊✊✊

1、业务流程

这里的 Local Chrome 表示客户端

该客户端也需要实现 HTTP、WebSocket

  • 客户端向 Tornado Server 发送 HTTP 请求,不管是提交请求还是上传文件,让服务端感知到用户操作

  • 服务端接收到用户操作的数据后,调用处理函数(LocalHost时使用函数调用)

  • 后处理完成再向服务端发送 HTTP 请求,不管是结果数据还是文件,让服务端感知到 后处理已结束

  • 服务端接收到结果,再通过WebSocket将数据返回给客户端

2、HTTP Server

通过日志级别,调整控制台输出的内容(debug、info);

这只是一个基类,封装了常用操作,具体业务需要根据自身项目,通过继承和重写实现

如果是单一功能,可以只用一个服务类,但如果请求较多,需要在Tornado启动文件中定义更多路由,也就需要更多的服务类,这样可以根据路由,实现RestFull API

"""
@Project :Titan
@File    :http.py
@Author  :PySuper
@Date    :2025/5/1 12:12
@Desc    :Titan http.py
"""

import datetime
import json
import traceback
import uuid
from contextvars import ContextVar
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Optional

import tornado.web
import tornado.websocket

# 修改导入
from config.files import UPLOAD_DIR
from logic.config import get_logger

# 创建请求ID上下文变量
request_id_var = ContextVar("request_id", default=None)


# 为loguru创建过滤器函数
def request_id_filter(record):
    """为日志记录添加请求ID"""
    request_id = request_id_var.get()
    record["extra"].update({"request_id": request_id or "no-request-id"})
    return True


# 获取配置了过滤器的logger
logger = get_logger("proxy")

# 使用loguru的方式添加过滤器
logger = logger.bind(request_id="no-request-id")

from enum import Enum


class ResponseStatus(str, Enum):
    """响应状态枚举"""

    SUCCESS = "success"  # 成功
    ERROR = "error"  # 错误
    WARNING = "warning"  # 警告
    INFO = "info"  # 信息


@dataclass
class SavedFile:
    """保存的文件信息"""

    original_name: str  # 原始文件名
    saved_name: str  # 保存后的文件名
    file_path: str  # 文件路径
    file_size: int  # 文件大小
    content_type: str  # 文件类型
    upload_time: str  # 上传时间
    unique_id: str  # 唯一ID


class JSONEncoder(json.JSONEncoder):
    """扩展的JSON编码器,支持更多Python类型"""

    def default(self, obj):
        """
        扩展JSON编码器的默认行为,支持更多Python类型
        :param obj: 要编码的对象
        :return: 编码后的对象
        """
        if isinstance(obj, datetime):
            return obj.isoformat()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        return super().default(obj)


class CustomHttp(tornado.web.RequestHandler):
    """
    自定义HTTP请求处理类,提供通用的请求处理功能

    该类扩展了Tornado的RequestHandler,提供了以下增强功能:
    - 自动请求日志记录
    - 统一的错误和成功响应格式
    - 跨域支持
    - 文件上传处理
    - WebSocket通知
    - 参数提取和验证
    - 异常处理
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # 初始化请求相关属性
        self.client_ip = self.request.remote_ip
        self.start_time = datetime.datetime.now()
        self._request_id = str(uuid.uuid4())

        # 设置请求ID上下文
        request_id_var.set(self._request_id)

        # 为当前请求绑定请求ID到logger
        global logger
        logger = logger.bind(request_id=self._request_id)

        # 设置默认HTTP响应头
        self.set_default_headers()

        # 记录请求日志
        logger.debug(f"收到请求==> 来自 {self.client_ip} - {self.request.method} {self.request.uri}")

    def http_err(self, msg: str, status_code: int = 400) -> None:
        """
        返回HTTP错误响应

        :param msg: 错误信息
        :param status_code: HTTP状态码,默认为400
        :return: None
        """
        logger.error(f"HTTP错误 [{self._request_id}]: {msg}")
        self.set_status(status_code)
        self.write({"status": ResponseStatus.ERROR, "message": msg, "request_id": self._request_id})
        return None

    def http_success(self, data: Any = None, message: str = "操作成功") -> None:
        """
        返回HTTP成功响应
        :param data: 响应数据
        :param message: 成功消息
        :return: None
        """
        response = {"status": ResponseStatus.SUCCESS, "message": message, "request_id": self._request_id}

        if data is not None:
            response["data"] = data

        self.write(response)
        return None

    def ws_err(self, msg: str) -> None:
        """
        返回WebSocket错误响应

        :param msg: 错误信息
        :return: None
        """
        logger.error(f"WebSocket错误 [{self._request_id}]: {msg}")
        self.write({"status": "error", "message": msg, "request_id": self._request_id})
        return None

    def set_default_headers(self) -> None:
        """
        设置跨域和安全相关的HTTP响应头
        :return: None
        """
        super().set_default_headers()

        # 跨域相关设置
        origin = self.request.headers.get("Origin", "*")
        self.set_header("Access-Control-Allow-Origin", origin)
        self.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
        self.set_header("Access-Control-Allow-Credentials", "true")
        self.set_header("Access-Control-Max-Age", "3600")  # 预检请求缓存时间

        # 合并请求头
        default_headers = "Content-Type, Content-Length, Authorization, Accept, X-Requested-With, X-File-Name, Cache-Control, devicetype"
        request_headers = self.request.headers.get("Access-Control-Request-Headers", "")
        allowed_headers = default_headers + (f", {request_headers}" if request_headers else "")
        self.set_header("Access-Control-Allow-Headers", allowed_headers)

        # 安全相关设置
        self.set_header("X-XSS-Protection", "1; mode=block")
        self.set_header("X-Content-Type-Options", "nosniff")
        self.set_header("X-Frame-Options", "SAMEORIGIN")
        self.set_header("Content-Security-Policy", "default-src * 'self' 'unsafe-inline' 'unsafe-eval' data: blob:")

        # 内容类型设置
        content_type = "text/plain" if self.request.method == "OPTIONS" else "application/json; charset=UTF-8"
        self.set_header("Content-Type", content_type)

    def options(self, *args, **kwargs) -> None:
        """
        处理OPTIONS请求,设置跨域和安全相关的HTTP响应头
        """
        # 设置响应状态码为204,表示请求已成功处理,但没有响应体
        self.set_status(204)
        self.finish()

    def get_ws(self) -> Optional[tornado.websocket.WebSocketHandler]:
        """
        获取WebSocket处理器

        :return: WebSocket处理器实例或None
        """
        ws_handler_map = self.application.settings.setdefault("ws_handler_map", {})
        return ws_handler_map.get("websocket_id")

    def save(self, files: List[Dict[str, Any]]) -> List[SavedFile]:
        """
        保存上传的文件
        :param files: 上传的文件列表
        :return: 保存的文件信息列表
        """
        saved_files = []

        # 确保上传目录存在
        upload_path = Path(UPLOAD_DIR)
        upload_path.mkdir(parents=True, exist_ok=True)

        for file_data in files:
            original_filename = file_data["filename"]
            file_content = file_data["body"]
            content_type = file_data["content_type"]

            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            unique_id = str(uuid.uuid4())[:8]
            filename = f"{timestamp}_{unique_id}_{original_filename}"
            file_path = upload_path / filename

            try:
                file_path.write_bytes(file_content)

                file_info = SavedFile(
                    original_name=original_filename,
                    saved_name=filename,
                    file_path=str(file_path),
                    file_size=len(file_content),
                    content_type=content_type,
                    upload_time=timestamp,
                    unique_id=unique_id,
                )
                saved_files.append(file_info)
                logger.info(f"文件已保存: {file_path}")
            except IOError as e:
                error_msg = f"保存文件 {original_filename} 时出错: {str(e)}"
                logger.error(error_msg)
                logger.debug(traceback.format_exc())

        return saved_files

    def notify_ws(self, data: Dict[str, Any]) -> bool:
        """
        向WebSocket客户端发送通知
        :param data: 通知数据
        :return: 发送是否成功
        """
        try:
            ws = self.get_ws()
            if ws:
                # 添加时间戳和请求ID
                data.update({"timestamp": datetime.datetime.now().isoformat(), "request_id": self._request_id})

                # 使用自定义编码器
                json_data = json.dumps(data, ensure_ascii=False, cls=JSONEncoder)

                if ws.send_message(json_data):
                    logger.debug(f"通知WebSocket客户端成功: {self._request_id}")
                    return True
                else:
                    logger.warning(f"通知WebSocket客户端失败: {self._request_id}")
                    return False
            else:
                logger.warning("未找到WebSocket连接")
                return False
        except Exception as e:
            logger.error(f"通知WebSocket客户端时出错: {str(e)}")
            logger.debug(traceback.format_exc())
            return False

    def extract_params(self) -> Dict[str, Any]:
        """
        提取请求参数,支持JSON和表单数据

        :return: 参数字典
        """
        params = {}

        # 尝试解析JSON请求体
        if self.request.body:
            try:
                request_data = json.loads(self.request.body)
                if isinstance(request_data, dict):
                    params.update(request_data)
            except json.JSONDecodeError:
                logger.warning(f"无法解析请求体为JSON [{self._request_id}],尝试从form数据获取参数")

        # 获取URL查询参数
        for name, values in self.request.arguments.items():
            if values and len(values) > 0:
                # 解码字节字符串为普通字符串
                decoded_value = values[0].decode("utf-8") if isinstance(values[0], bytes) else values[0]
                params[name] = decoded_value

        # 尝试从表单数据获取参数
        for field_name, field_value in self.request.arguments.items():
            if field_name not in params or not params[field_name]:
                value = self.get_argument(field_name, None)
                if value:
                    params[field_name] = value

        logger.debug(f"提取的参数: {params}")
        return params

    def validate_params(self, required_params: List[str]) -> Optional[Dict[str, Any]]:
        """
        验证请求参数是否包含所有必需的参数
        :param required_params: 必需参数列表
        :return 参数字典或None(如果验证失败)
        """
        params = self.extract_params()
        # 更Pythonic的列表推导式
        missing_params = [param for param in required_params if param not in params or not params[param]]

        if missing_params:
            self.http_err(f"缺少必需参数: {', '.join(missing_params)}", 400)
            return None

        return params

    def handle_exception(self, func):
        """
        异常处理装饰器

        :param func: 要装饰的函数
        :return: 装饰后的函数
        """

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                error_msg = f"处理请求时发生错误: {str(e)}"
                logger.error(error_msg)
                logger.debug(traceback.format_exc())
                self.http_err(error_msg, 500)
                return None

        return wrapper

    @property
    def request_id(self) -> str:
        """获取当前请求ID"""
        return self._request_id

    @property
    def request_duration(self) -> float:
        """获取请求持续时间(秒)"""
        return (datetime.datetime.now() - self.start_time).total_seconds()

3、WebSocket Server

同样这也是一个基类,封装了常见的功能,继承后可以直接使用;

在WebSocketHandler的基础上增加了以下功能:

  • 模式选择:多客户端、单客户端

  • 健康检查:心跳检测、连接维护

  • 权限校验:Token、IP白名单、

  • 消息模式:广播、私聊

这里代码有些臃肿,后面有时间的时候,再回来调整、优化、整理

"""
@Project :Titan
@File    :wss.py
@Author  :PySuper
@Date    :2025/5/1 12:17
@Desc    :Titan WebSocket服务器实现
"""

import asyncio
import datetime
import json
import uuid
from typing import Dict, Any

import tornado.web
import tornado.websocket

from logic.config import get_logger

logger = get_logger("proxy-server")


class CustomWebSocket(tornado.websocket.WebSocketHandler):
    """
    自定义WebSocket处理类,提供高性能的WebSocket通信功能

    特性:
    - 支持多连接模式和单连接模式
    - 自动心跳检测和连接维护
    - 消息队列异步处理
    - 身份验证机制
    - 消息广播与私聊
    - 连接状态监控
    """

    # 消息类型处理映射
    MESSAGE_TYPES = {
        "auth": "_handle_auth",
        "message": "_handle_message",
        "event": "_handle_event",
        "command": "_handle_command",
        "ping": "_handle_ping",
    }

    # 状态响应类型
    STATUS_TYPES = ["success", "error", "warning", "info", "received"]

    def __init__(self, application, request, *args, **kwargs):
        """初始化WebSocket连接处理器"""
        super().__init__(application, request)

        # 连接标识和会话信息
        self.client_id: str = ""
        self._request_id: str = str(uuid.uuid4())
        self.client_ip: str = ""
        self.user_data: Dict[str, Any] = {}
        self.is_authenticated: bool = False

        # 连接统计数据
        self.connected_at: datetime.datetime = ""
        self.start_time: datetime.datetime = ""
        self.bytes_received: int = 0
        self.bytes_sent: int = 0
        self.message_count: int = 0

        # 消息处理队列
        self.message_queue: asyncio.Queue = asyncio.Queue()
        self.is_processing_queue: bool = False

        # 心跳检测
        self._ping_interval: int = 30
        self._ping_timeout: int = 60
        self.last_pong: datetime.datetime = datetime.datetime.now()
        self.heartbeat_future: asyncio.Task = ""

        # 任务控制
        self.frame_task: asyncio.Task = ""
        self.stop_tasks: bool = False
        self.is_paused: bool = True

        # 连接模式 (True: 允许多连接, False: 单连接模式)
        self.holdon: bool = True

    def initialize(self, *args, **kwargs):
        """初始化连接配置和状态"""
        # 注册到WebSocket处理器集合
        if "ws_handlers" not in self.application.settings:
            self.application.settings["ws_handlers"] = []
        self.application.settings["ws_handlers"].append(self)

        # 设置连接模式
        self.holdon = kwargs.get("holdon", True)  # 默认为允许多连接

        # 初始化连接状态
        self.is_authenticated = False  # 默认已认证
        self.user_data = {}  # 默认用户数据为空

        # 配置心跳检测参数
        self.heartbeat_enabled = kwargs.get("heartbeat_enabled", False)  # 是否启用心跳检测
        self._ping_interval = kwargs.get("ping_interval", 30)  # 心跳检测间隔
        self._ping_timeout = kwargs.get("ping_timeout", 60)  # 心跳检测超时
        self.last_pong = datetime.datetime.now()  # 最后一次收到心跳回复的时间

        # 初始化任务控制状态
        self.stop_tasks = False  # 任务停止标志
        self.is_paused = True  # 任务暂停标志

        # 初始化统计数据
        self.message_count = 0  # 消息计数
        self.bytes_received = 0  # 接收字节数
        self.bytes_sent = 0  # 发送字节数
        self.connected_at = datetime.datetime.now()  # 连接时间

        # 设置连接基本信息
        self.client_ip = self.request.remote_ip  # 客户端IP地址
        self.start_time = datetime.datetime.now()  # 连接开始时间
        self._request_id = str(uuid.uuid4())  # 请求ID

        # 记录连接日志
        # logger.info(f"WebSocket Client: 🔌 [ID: {self._request_id}] 来自 {self.client_ip}")

    # 取绑定了当前请求ID的logger
    def _log(self):
        """获取绑定了当前请求ID的logger"""
        return logger.bind(request_id=self._request_id)

    # 连接建立事件
    def open(self):
        """处理WebSocket连接建立事件"""
        # 根据连接模式处理连接
        self._handle_multi_connection() if self.holdon else self._handle_single_connection()

        # 启动心跳检测
        if self.heartbeat_enabled:
            self._start_heartbeat()

        # 启动消息队列处理
        asyncio.create_task(self._process_message_queue())
        # logger.debug(f"ws连接初始化: [ID: {self.client_id}]")

    # 处理多连接模式 - 保留所有连接
    def _handle_multi_connection(self):
        """处理多连接模式 - 保留所有连接"""
        # 生成唯一客户端ID
        self.client_id = str(uuid.uuid4())

        # 注册连接到连接映射表
        ws_map = self.application.settings.setdefault("ws_handler_map", {})
        ws_map[self.client_id] = self

        # 保持向后兼容
        if "websocket_id" not in ws_map:
            ws_map["websocket_id"] = self

        # 活跃连接数 (减1是因为包含websocket_id键)
        active_connections = len(ws_map) - 1
        logger.info(f"ws新客户端: ID={self.client_id}, 当前连接数={active_connections}")

    # 处理单连接模式 - 保留最新连接
    def _handle_single_connection(self):
        """处理单连接模式 - 保留最新连接"""
        try:
            # 获取连接映射表
            ws_map = self.application.settings.setdefault("ws_handler_map", {})

            # 关闭已存在的连接
            self._close_existing_connection(ws_map)

            # 建立新连接
            self.client_id = str(uuid.uuid4())
            ws_map["websocket_id"] = self

            # 记录连接信息
            client_info = {
                "id": self.client_id,
                "ip": self.request.remote_ip,
                "time": datetime.datetime.now().isoformat(),
                "user_agent": self.request.headers.get("User-Agent", "未知"),
            }
            logger.debug(f"WebSocket已连接: {json.dumps(client_info, ensure_ascii=False)}")

            # 发送欢迎消息
            welcome = {
                "type": "connection_established",
                "client_id": self.client_id,
                "server_time": datetime.datetime.now().isoformat(),
                "message": "WebSocket已连接",
            }
            self._async_write_message(welcome)

        except Exception as e:
            logger.error(f"建立连接时出错: {str(e)}", exc_info=True)
            self.close(code=1011, reason="连接初始化失败")

    # 关闭已存在的连接
    def _close_existing_connection(self, ws_map):
        """关闭已存在的连接"""
        old_handler = ws_map.get("websocket_id")
        if not old_handler:
            return

        logger.debug(f"关闭旧连接: ID={getattr(old_handler, 'client_id', 'unknown')}")
        try:
            # 停止任务
            old_handler.stop_tasks = True

            # 取消帧任务
            if hasattr(old_handler, "frame_task") and old_handler.frame_task:
                old_handler.frame_task.cancel()

            # 取消心跳任务
            if hasattr(old_handler, "heartbeat_future") and old_handler.heartbeat_future:
                old_handler.heartbeat_future.cancel()

            # 发送关闭通知
            close_msg = {
                "type": "connection_closed",
                "reason": "新的客户端连接已建立,旧连接已关闭",
                "server_time": datetime.datetime.now().isoformat(),
            }
            old_handler.write_message(json.dumps(close_msg, ensure_ascii=False))

            # 关闭连接
            old_handler.close(code=1000, reason="新的客户端连接已建立,旧连接已关闭")

        except Exception as e:
            logger.warning(f"关闭旧连接时出错: {str(e)}")

    # 接收处理WebSocket消息
    def on_message(self, message: str):
        """处理接收到的WebSocket消息"""
        # 消息为空检查
        if not message:
            return self._send_error("收到空消息")

        try:
            # 更新统计信息
            self.message_count += 1
            self.bytes_received += len(message) if hasattr(message, "__len__") else 0
            self.last_pong = datetime.datetime.now()

            # 解析消息
            data = tornado.escape.json_decode(message)
            if not isinstance(data, dict):
                return self._send_error("无效的消息格式")

            # 记录消息
            logger.debug(f"收到消息: {json.dumps(data, ensure_ascii=False)}")

            # 处理心跳消息
            if data.get("type") == "ping" or data.get("action") == "heartbeat":
                return self._handle_ping(data)

            # 将消息加入队列异步处理
            asyncio.create_task(self.message_queue.put(data))
            return None

        except json.JSONDecodeError:
            logger.error("接收到无效的JSON格式消息")
            return self._send_error("无效的JSON格式")
        except Exception as e:
            logger.error(f"处理消息时出错: {str(e)}", exc_info=True)
            return self._send_error(f"处理消息时出错: {str(e)}")

    # TODO:处理消息队列
    async def _process_message_queue(self):
        """处理消息队列"""
        if self.is_processing_queue:
            return

        self.is_processing_queue = True
        log = logger.bind(request_id=self._request_id)  # 使用绑定了request_id的logger
        try:
            while not self.stop_tasks:
                try:
                    # 等待新消息,超时1秒
                    data = await asyncio.wait_for(self.message_queue.get(), timeout=1.0)

                    # 检查连接状态
                    if not self.ws_connection:
                        log.warning("WebSocket连接已关闭,停止处理消息队列")
                        break

                    # 处理消息
                    if data is not None:
                        await self._process_message(data)

                    # 标记任务完成
                    self.message_queue.task_done()

                except asyncio.TimeoutError:
                    # 超时,继续下一次循环
                    continue
                except Exception as e:
                    log.error(f"处理队列消息时出错: {str(e)}", exc_info=True)
                    # 尝试标记任务完成
                    try:
                        self.message_queue.task_done()
                    except Exception:
                        pass
        finally:
            self.is_processing_queue = False

    # 处理单条消息
    async def _process_message(self, data: Dict[str, Any]):
        """处理单条消息"""
        try:
            # 数据类型检查
            if not isinstance(data, dict):
                await self._send_response("error", "无效的消息格式")
                return

            # 处理action字段(兼容不同格式的消息)
            if "action" in data and "type" not in data:
                action_type_map = {
                    "heartbeat": "ping",
                    "auth": "auth",
                    "message": "message",
                    "command": "command",
                    "event": "event",
                }
                data["type"] = action_type_map.get(data["action"], data["action"])

            # 确保消息类型存在
            msg_type = data.get("type")
            if not msg_type:
                await self._send_response("received", "消息已收到")
                return

            # 身份验证特殊处理
            if msg_type == "auth" and not self.is_authenticated:
                await self._handle_auth(data)
                return

            # 心跳消息特殊处理
            if msg_type == "ping" or data.get("action") == "heartbeat":
                await self._handle_ping(data)
                return

            # 其他消息类型需要身份验证
            if not self.is_authenticated and msg_type not in ["ping"]:
                await self._send_response("error", "请先进行身份验证", type="auth_required")
                return

            # 根据类型分发处理
            handler_name = self.MESSAGE_TYPES.get(msg_type)
            if handler_name and hasattr(self, handler_name):
                await getattr(self, handler_name)(data)
            else:
                logger.warning(f"未知的消息类型: {msg_type}")
                await self._send_response("warning", f"未知的消息类型: {msg_type}")

        except Exception as e:
            logger.error(f"处理消息时出错: {str(e)}", exc_info=True)
            await self._send_response("error", f"处理消息时出错: {str(e)}")

    # 处理身份验证请求
    async def _handle_auth(self, data: Dict[str, Any]):
        """处理身份验证请求"""
        try:
            token = data.get("token")
            if not token:
                await self._send_response("error", "缺少身份验证令牌", type="auth_failed")
                return

            # 简单的身份验证逻辑,实际应用中应实现更安全的机制
            if len(token) > 10:
                self.is_authenticated = True
                self.user_data = {
                    "auth_time": datetime.datetime.now().isoformat(),
                    "client_id": self.client_id,
                }
                await self._send_response("success", "身份验证成功", type="auth_success", user_data=self.user_data)
            else:
                await self._send_response("error", "无效的身份验证令牌", type="auth_failed")
        except Exception as e:
            logger.error(f"身份验证处理出错: {str(e)}", exc_info=True)
            await self._send_response("error", f"身份验证处理出错", type="auth_error")

    # 处理普通消息
    async def _handle_message(self, data: Dict[str, Any]):
        """处理普通消息"""
        try:
            content = data.get("content", "")
            target = data.get("target", "all")

            # 广播消息
            if target == "all":
                await self._broadcast_message(content, exclude_self=data.get("exclude_self", False))
                return

            # 私聊消息
            if target != "all" and target:
                await self._send_private_message(target, content)
                return

            await self._send_response("success", "消息已接收", type="message_received")
        except Exception as e:
            logger.error(f"消息处理出错: {str(e)}", exc_info=True)
            await self._send_response("error", "消息处理出错", type="message_error")

    # 处理事件通知
    async def _handle_event(self, data: Dict[str, Any]):
        """处理事件通知"""
        try:
            event_name = data.get("event_name")
            event_data = data.get("event_data", {})

            if not event_name:
                await self._send_response("error", "缺少事件名称", type="event_error")
                return

            # 记录事件
            logger.info(f"收到事件: {event_name}, 数据: {json.dumps(event_data, ensure_ascii=False)}")

            # 实际项目中可以在这里添加事件处理逻辑
            await self._send_response(
                "success",
                f"事件 {event_name} 已处理",
                type="event_processed",
                event_name=event_name,
            )
        except Exception as e:
            logger.error(f"事件处理出错: {str(e)}", exc_info=True)
            await self._send_response("error", "事件处理出错", type="event_error")

    # 处理命令执行
    async def _handle_command(self, data: Dict[str, Any]):
        """处理命令执行"""
        try:
            command = data.get("command")
            params = data.get("params", {})

            if not command:
                await self._send_response("error", "缺少命令", type="command_error")
                return

            # 记录命令
            logger.info(f"执行命令: {command}, 参数: {json.dumps(params, ensure_ascii=False)}")

            # 执行命令
            command_result = self._execute_command(command, params)

            if command_result:
                await self._send_response(
                    "success",
                    "命令已执行",
                    type="command_executed",
                    command=command,
                    result=command_result,
                )
            else:
                await self._send_response("error", f"未知命令: {command}", type="command_error")
        except Exception as e:
            logger.error(f"命令处理出错: {str(e)}", exc_info=True)
            await self._send_response("error", "命令处理出错", type="command_error")

    # 执行命令并返回结果
    def _execute_command(self, command: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """执行命令并返回结果"""
        command_handlers = {
            "stats": self.get_connection_stats,
            "clear_queue": self._clear_message_queue,
            "ping": lambda: {"ping_time": datetime.datetime.now().isoformat()},
            "status": self._get_status,
        }

        handler = command_handlers.get(command)
        if handler:
            return handler()
        return None

    # 获取服务器状态信息
    def _get_status(self) -> Dict[str, Any]:
        """获取服务器状态信息"""
        return {
            "status": "running",
            "uptime": (datetime.datetime.now() - self.connected_at).total_seconds(),
            "authenticated": self.is_authenticated,
            "message_count": self.message_count,
        }

    # 处理心跳消息
    async def _handle_ping(self, data: Dict[str, Any]):
        """处理心跳消息"""
        pong_data = {
            "type": "pong",
            "time": datetime.datetime.now().isoformat(),
            "server_id": self._request_id,
        }
        await self._safe_write_message(json.dumps(pong_data, ensure_ascii=False))
        return None

    # 启动心跳检测任务
    def _start_heartbeat(self):
        """启动心跳检测任务"""

        async def heartbeat_check():
            while not self.stop_tasks:
                try:
                    # 发送ping消息
                    ping_data = {"type": "ping", "time": datetime.datetime.now().isoformat()}
                    await self._safe_write_message(json.dumps(ping_data, ensure_ascii=False))

                    # 检查上次响应时间
                    elapsed = (datetime.datetime.now() - self.last_pong).total_seconds()

                    # 超时检测
                    if elapsed > self._ping_timeout:
                        logger.warning(f"WebSocket连接 {self.client_id} 心跳超时 ({elapsed}秒)")
                        self.close(code=1001, reason="心跳超时")
                        break

                    # 等待下一次心跳
                    await asyncio.sleep(self._ping_interval)
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"心跳检测出错: {str(e)}")
                    await asyncio.sleep(5)  # 出错后等待5秒重试

        self.heartbeat_future = asyncio.create_task(heartbeat_check())

    # 广播消息到所有连接
    async def _broadcast_message(self, content: str, exclude_self: bool = False):
        """广播消息到所有连接"""
        try:
            handlers = self.application.settings.get("ws_handlers", [])
            if not handlers:
                await self._send_response("warning", "没有可用的WebSocket连接")
                return

            broadcast_data = {
                "type": "broadcast",
                "content": content,
                "sender": self.client_id,
                "timestamp": datetime.datetime.now().isoformat(),
            }

            message_json = json.dumps(broadcast_data, ensure_ascii=False)
            sent_count = 0

            for handler in handlers:
                if exclude_self and handler == self:
                    continue

                if getattr(handler, "is_authenticated", False) and handler.ws_connection:
                    await handler._safe_write_message(message_json)
                    sent_count += 1

            await self._send_response(
                "success", f"消息已广播给 {sent_count} 个连接", type="broadcast_sent", recipients_count=sent_count
            )
        except Exception as e:
            logger.error(f"广播消息时出错: {str(e)}", exc_info=True)
            await self._send_response("error", "广播消息时出错")

    # 发送私聊消息
    async def _send_private_message(self, target_id: str, content: str):
        """发送私聊消息"""
        try:
            ws_map = self.application.settings.get("ws_handler_map", {})
            target_handler = ws_map.get(target_id)

            if not target_handler or not target_handler.ws_connection:
                await self._send_response("error", f"目标用户 {target_id} 不在线", type="message_error")
                return

            private_data = {
                "type": "private_message",
                "content": content,
                "sender": self.client_id,
                "timestamp": datetime.datetime.now().isoformat(),
            }

            await target_handler._safe_write_message(json.dumps(private_data, ensure_ascii=False))
            await self._send_response("success", "私聊消息已发送", type="message_sent", recipient=target_id)
        except Exception as e:
            logger.error(f"发送私聊消息时出错: {str(e)}", exc_info=True)
            await self._send_response("error", "发送私聊消息时出错")

    # 发送统一格式的响应(异步版本)
    async def _send_response(self, status: str, message: str, **kwargs):
        """发送统一格式的响应(异步版本)"""
        try:
            res = {"status": status, "message": message}
            res.update(kwargs)
            await self._safe_write_message(json.dumps(res, ensure_ascii=False))
        except Exception as e:
            logger.error(f"发送响应时出错: {str(e)}")

    # 发送错误响应(同步版本)
    def _send_error(self, message: str, **kwargs) -> None:
        """发送错误响应(同步版本)"""
        self._send_message({"status": "error", "message": message, **kwargs})
        return None

    # 发送消息到WebSocket客户端
    def _send_message(self, data: Dict[str, Any]) -> bool:
        """发送消息到WebSocket客户端"""
        try:
            if not self.ws_connection:
                logger.warning("尝试发送消息,但WebSocket连接已关闭")
                return False

            message = json.dumps(data, ensure_ascii=False)

            # 更新统计信息
            self.bytes_sent += len(message)

            self.write_message(message)
            return True
        except Exception as e:
            logger.error(f"发送WebSocket消息时出错: {str(e)}")
            return False

    # 异步发送JSON消息
    def _async_write_message(self, data: Dict[str, Any]):
        """异步发送JSON消息"""
        asyncio.create_task(self._safe_write_message(json.dumps(data, ensure_ascii=False)))

    # 安全地发送WebSocket消息
    async def _safe_write_message(self, message: str) -> bool:
        """安全地发送WebSocket消息"""
        try:
            if not self.ws_connection:
                logger.warning("尝试发送消息,但WebSocket连接已关闭")
                return False

            # 更新统计信息
            self.bytes_sent += len(message)

            self.write_message(message)
            return True
        except tornado.websocket.WebSocketClosedError:
            logger.warning("WebSocket连接已关闭,无法发送消息")
            return False
        except Exception as e:
            logger.error(f"发送WebSocket消息时出错: {str(e)}")
            return False

    # 获取连接统计信息
    def get_connection_stats(self) -> Dict[str, Any]:
        """获取连接统计信息"""
        try:
            if not self.connected_at:
                return {
                    "client_id": self.client_id,
                    "ip": self.client_ip,
                    "connected_at": "未连接",
                    "uptime_seconds": 0,
                    "messages_received": self.message_count,
                    "bytes_received": self.bytes_received,
                    "bytes_sent": self.bytes_sent,
                    "authenticated": self.is_authenticated,
                    "queue_size": self.message_queue.qsize() if hasattr(self, "message_queue") else 0,
                }

            uptime = (datetime.datetime.now() - self.connected_at).total_seconds()
            connected_at_iso = self.connected_at.isoformat() if self.connected_at else "未知"

            return {
                "client_id": self.client_id or "未知",
                "ip": self.client_ip or "未知",
                "connected_at": connected_at_iso,
                "uptime_seconds": uptime,
                "messages_received": self.message_count,
                "bytes_received": self.bytes_received,
                "bytes_sent": self.bytes_sent,
                "authenticated": self.is_authenticated,
                "queue_size": self.message_queue.qsize() if hasattr(self, "message_queue") else 0,
            }
        except Exception as e:
            # 出现异常时返回基本信息
            logger.error(f"获取连接统计信息出错: {str(e)}")
            return {
                "error": "获取连接统计信息失败",
                "client_id": getattr(self, "client_id", "未知"),
                "exception": str(e),
            }

    # 清空消息队列
    def _clear_message_queue(self) -> Dict[str, Any]:
        """清空消息队列"""
        try:
            # 安全地清空队列
            cleared_count = 0
            if hasattr(self, "message_queue") and self.message_queue:
                while not self.message_queue.empty():
                    try:
                        self.message_queue.get_nowait()
                        self.message_queue.task_done()
                        cleared_count += 1
                    except Exception as e:
                        logger.warning(f"清空单个消息时出错: {str(e)}")
                        break

            return {"status": "success", "message": "消息队列已清空", "cleared_messages": cleared_count}
        except Exception as e:
            logger.error(f"清空消息队列时出错: {str(e)}")
            return {"status": "error", "message": f"清空消息队列时出错: {str(e)}"}

    # 处理WebSocket连接关闭事件
    def on_close(self):
        """处理WebSocket连接关闭事件"""
        log = logger.bind(request_id=getattr(self, "_request_id", "unknown"))

        try:
            # 停止所有任务
            self.stop_tasks = True

            # 取消心跳任务
            if hasattr(self, "heartbeat_future") and self.heartbeat_future and not self.heartbeat_future.done():
                self.heartbeat_future.cancel()

            # 取消帧任务
            if hasattr(self, "frame_task") and self.frame_task and not self.frame_task.done():
                self.frame_task.cancel()

            # 清理消息队列
            self._clear_message_queue()

            # 从处理器映射中移除
            if (
                hasattr(self, "application")
                and hasattr(self.application, "settings")
                and "ws_handler_map" in self.application.settings
            ):
                ws_map = self.application.settings["ws_handler_map"]

                # 移除全局映射
                if ws_map.get("websocket_id") == self:
                    ws_map.pop("websocket_id", None)

                # 移除客户端ID映射
                if hasattr(self, "client_id") and self.client_id in ws_map:
                    ws_map.pop(self.client_id, None)

            # 从处理器列表中移除
            if (
                hasattr(self, "application")
                and hasattr(self.application, "settings")
                and "ws_handlers" in self.application.settings
            ):
                handlers = self.application.settings["ws_handlers"]
                if self in handlers:
                    handlers.remove(self)

            # 记录连接统计
            try:
                stats = self.get_connection_stats()
                log.info(f"WebSocket连接已断开 ❌ : {json.dumps(stats, ensure_ascii=False)}")
            except Exception as e:
                log.error(f"获取连接统计时出错: {str(e)}")

        except Exception as e:
            log.error(f"处理WebSocket关闭时出错: {str(e)}", exc_info=True)

    # 跨域验证逻辑
    def check_origin(self, origin: str) -> bool:
        """
        跨域验证逻辑,默认允许所有来源

        可以在此方法中实现域名白名单或其他安全验证
        """
        # TODO: 实现更安全的跨域验证
        return True

4、Run Server

这里定义了丰富的服务端功能,可以根据项目情况进行增减

后面如果还有其他的HTPP、WebSocket服务,可以直接继承上面两个类,实现后注册到路由规则(handlers)中

"""
@Project :Titan
@File    :main.py
@Author  :PySuper
@Date    :2025/5/1 12:23
@Desc    :Titan main.py
"""

import os
import signal

import tornado
import tornado.httpserver
import tornado.ioloop
import tornado.netutil
from tornado.web import StaticFileHandler

from logic.config import get_logger
from proxy.server import HttpProxy, WsProxy
from utils.system import close_port

# 创建一个系统级别的logger
logger = get_logger("proxy-server")


def make_app():
    """
    配置tornado应用程序

    功能:
    1. 设置路由规则,将URL路径映射到对应的处理类
    2. 配置静态文件目录,用于提供前端资源
    3. 设置应用程序级别的配置项
    4. 配置安全相关选项
    5. 设置调试模式和日志级别

    :return: tornado应用程序实例
    """
    # 生成启动标识ID
    # startup_id = str(uuid.uuid4())[:8]
    # log = logger.bind(request_id=f"startup_{startup_id}")

    # 定义静态文件目录
    static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
    template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
    upload_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "uploads")

    # 确保目录存在
    for path in [static_path, template_path, upload_path]:
        if not os.path.exists(path):
            os.makedirs(path)
            logger.debug(f"创建目录: {path}")

    # 定义路由规则
    handlers = [
        # API 路由
        # (r"/api/upload", FaceProxy),  # 文件上传接口
        (r"/api/websocket", WsProxy),  # WebSocket连接
        (r"/api/control", HttpProxy),  # 控制接口
        # (r"/api/video_control", FaceProxy),  # 视频控制接口
        # 静态文件路由
        # (r"/files/(.*)", FileProxy),  # 文件访问接口
        (r"/static/(.*)", StaticFileHandler, {"path": static_path}),  # 静态资源
        (r"/uploads/(.*)", StaticFileHandler, {"path": upload_path}),  # 上传文件访问
        # 默认路由 - 可以重定向到首页或返回404
        (r"/(.*)", StaticFileHandler, {"path": template_path, "default_filename": "index.html"}),
    ]

    # 应用程序设置
    settings = {
        "debug": True,  # 开发模式下启用调试
        "autoreload": True,  # 代码变更时自动重载
        "compress_response": True,  # 压缩HTTP响应
        "serve_traceback": True,  # 在调试模式下显示错误堆栈
        "static_path": static_path,  # 静态文件目录
        "template_path": template_path,  # 模板文件目录
        "cookie_secret": os.environ.get("COOKIE_SECRET", "Titan_Secret_Key_2025"),  # Cookie加密密钥
        "xsrf_cookies": False,  # 暂时禁用XSRF保护,根据需要启用
        "websocket_ping_interval": 30,  # WebSocket心跳间隔(秒)
        "websocket_ping_timeout": 120,  # WebSocket心跳超时(秒)
        "ws_handler_map": {},  # WebSocket处理器映射
        "upload_path": upload_path,  # 上传文件保存目录
        "max_buffer_size": 1024 * 1024 * 100,  # 最大缓冲区大小(100MB)
        "max_body_size": 1024 * 1024 * 200,  # 最大请求体大小(200MB)
    }

    # 创建并返回应用程序实例
    app = tornado.web.Application(handlers, **settings)

    # 记录应用程序配置信息
    logger.debug(f"Tornado调试模式: {settings['debug']}")
    logger.debug(f"静态文件目录: {static_path}")
    logger.debug(f"上传文件目录: {upload_path}")
    logger.debug(f"注册的路由数量: {len(handlers)}")

    return app


def handle_signal(sig, frame):
    """处理系统信号,优雅关闭服务器"""
    # log = logger.bind(request_id="shutdown")
    logger.info(f"接收到信号 {sig},服务器正在优雅关闭...")

    # 获取当前的IOLoop实例
    ioloop = tornado.ioloop.IOLoop.current()

    # 使用ioloop.add_callback来安全地停止
    ioloop.add_callback_from_signal(ioloop.stop)


def main():
    """主程序入口"""
    # 创建一个带有请求ID的日志记录器
    # log = logger.bind(request_id="server_startup")

    # 检查端口占用情况并释放端口
    logger.debug("检查端口占用情况...")
    close_port(9000, logger)
    close_port(9001, logger)

    try:
        # 创建应用
        logger.debug("正在初始化Titan服务器...")
        app = make_app()

        # 设置HTTP服务器
        logger.debug("正在启动HTTP服务器...")
        http_server = tornado.httpserver.HTTPServer(app)
        http_server.listen(9000, address="0.0.0.0")
        logger.debug("HTTP      服务器:已启动,监听端口: 9000")

        # 设置WebSocket服务器
        logger.debug("正在启动WebSocket服务器...")
        ws_server = tornado.httpserver.HTTPServer(app)
        ws_server.add_sockets(tornado.netutil.bind_sockets(9001))
        logger.debug("WebSocket 服务器:已启动,监听端口: 9001")

        # 注册信号处理器
        signal.signal(signal.SIGINT, handle_signal)  # 处理Ctrl+C
        signal.signal(signal.SIGTERM, handle_signal)  # 处理终止信号

        # 启动服务
        logger.info("✅ Titan服务器启动完成,正在运行...")
        logger.debug("按Ctrl+C停止服务器")

        # 启动事件循环
        tornado.ioloop.IOLoop.current().start()

        # 如果到达这里,说明事件循环已停止
        logger.info("Titan服务器已关闭")

    except Exception as e:
        logger.error(f"服务器启动过程中出现错误: {e}", exc_info=True)
        return 1

    return 0


if __name__ == "__main__":
    # print("🔌 WebSocket 已连接")  # 插头符号
    # print("📡 WebSocket 通信中")  # 天线符号
    # print("❌ WebSocket 已断开")  # 叉号符号
    exit_code = main()
    exit(exit_code)

0
  1. 支付宝打赏

    qrcode alipay
  2. 微信打赏

    qrcode weixin

评论区