"""WebSocket notification server with an HTTP push endpoint.

Clients open a WebSocket on ``/ws`` and either register (receiving a fresh
secret key) or reconnect with a previously issued key.  Any HTTP client can
then ``POST /push/{key}`` with a JSON body to fan a notification out to every
WebSocket currently connected under that key.
"""
import asyncio
import copy
import json
import logging
import os
import secrets
import ssl
from collections import defaultdict
from datetime import datetime

import aiohttp
import websockets
import yaml
from aiohttp import web

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration used when the config file is missing, and as the fallback for
# any section/option an existing config file does not provide.
DEFAULT_CONFIG = {
    'websocket': {
        'host': '0.0.0.0',
        'port': 8080
    },
    'security': {
        'use_ssl': False,
        'ssl_cert': 'path/to/cert.pem',
        'ssl_key': 'path/to/key.pem'
    },
    'data': {
        'keys_file': '/app/data/registered_keys.json'
    }
}


class _AiohttpWebSocketAdapter:
    """Expose an aiohttp ``WebSocketResponse`` via ``send``/``recv``/``close``.

    ``NotificationServer`` talks to sockets through this small
    websockets-style interface, so the rest of the class does not depend on
    aiohttp message objects.
    """

    # Message types that mean the connection is finished.  All of them must
    # end the read loop; otherwise the caller keeps reading a dead socket.
    _TERMINAL_TYPES = (
        web.WSMsgType.CLOSE,
        web.WSMsgType.CLOSING,
        web.WSMsgType.CLOSED,
        web.WSMsgType.ERROR,
    )

    def __init__(self, ws):
        self._ws = ws

    async def send(self, message):
        """Send *message* as a text frame."""
        await self._ws.send_str(message)

    async def recv(self):
        """Return the next text payload.

        Returns ``None`` for non-text data frames (they are ignored by the
        callers).  Raises ``websockets.exceptions.ConnectionClosed`` once the
        peer has closed the connection or the connection has failed.
        """
        msg = await self._ws.receive()
        if msg.type == web.WSMsgType.TEXT:
            return msg.data
        if msg.type in self._TERMINAL_TYPES:
            raise websockets.exceptions.ConnectionClosed(None, None)
        return None

    async def close(self):
        """Close the underlying WebSocket."""
        await self._ws.close()


class NotificationServer:
    """Registry of client keys plus the WebSocket/HTTP handlers serving them.

    Attributes:
        config: Effective configuration (file values merged over defaults).
        client_connections: Maps a client key to the set of live sockets.
        keys_file: Path of the JSON file persisting ``registered_keys``.
        registered_keys: Maps a client key to its metadata dict
            (``device_name``, ``created_at``, ``last_seen``).
    """

    # Seconds of silence before the server probes the client with "ping".
    HEARTBEAT_INTERVAL = 30
    # Seconds the client has to send anything after a probe.
    HEARTBEAT_TIMEOUT = 10

    def __init__(self, config_path='config.yml'):
        """Load configuration from *config_path* and the persisted keys.

        If the file does not exist the defaults are used and written to
        *config_path* (best effort).  Sections or options missing from an
        existing file fall back to the defaults instead of raising later.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            # No config file yet: run on defaults and save them for editing.
            loaded = None
            self._write_default_config(config_path)
        self.config = self._merge_with_defaults(loaded)

        # Live client connections, keyed by client key
        self.client_connections = defaultdict(set)

        # Storage location for registered keys comes from the configuration
        self.keys_file = self.config['data']['keys_file']
        self.registered_keys = {}
        self.load_registered_keys()

    @staticmethod
    def _merge_with_defaults(loaded):
        """Return ``DEFAULT_CONFIG`` overlaid with the values in *loaded*.

        *loaded* may be ``None`` (missing or empty file) or a mapping of
        sections.  Known sections are merged option by option; ``None``
        values are ignored so an empty section keeps its defaults.
        """
        config = copy.deepcopy(DEFAULT_CONFIG)
        if not isinstance(loaded, dict):
            return config
        for section, values in loaded.items():
            if isinstance(values, dict) and isinstance(config.get(section), dict):
                config[section].update(values)
            elif values is not None:
                config[section] = values
        return config

    @staticmethod
    def _write_default_config(config_path):
        """Write the default configuration to *config_path* (best effort)."""
        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                yaml.safe_dump(copy.deepcopy(DEFAULT_CONFIG), f, allow_unicode=True)
        except OSError as e:
            # A read-only location must not prevent the server from starting.
            logger.warning(f"保存默认配置失败: {e}")

    def load_registered_keys(self):
        """Load previously registered keys from ``self.keys_file``.

        A missing file leaves the registry empty.  Read or parse failures,
        including a file whose top level is not a JSON object, are logged and
        leave the registry unchanged.
        """
        try:
            if os.path.exists(self.keys_file):
                with open(self.keys_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if not isinstance(data, dict):
                    raise ValueError("expected a JSON object of registered keys")
                self.registered_keys = data
                logger.info(f"已加载 {len(self.registered_keys)} 个注册密钥")
        except (OSError, ValueError) as e:
            logger.error(f"加载密钥文件失败: {e}")

    def save_registered_keys(self):
        """Persist ``self.registered_keys`` to ``self.keys_file``.

        Failures are logged, not raised.
        """
        try:
            # Serialize before opening the file so a serialization error
            # cannot truncate the existing key store.
            payload = json.dumps(self.registered_keys, indent=2)
            directory = os.path.dirname(self.keys_file)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(self.keys_file, 'w', encoding='utf-8') as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"保存密钥文件失败: {e}")

    def generate_key(self):
        """Return a new random URL-safe client key."""
        return secrets.token_urlsafe(32)

    async def register_client(self, websocket):
        """Handle the first message of a connection.

        The client must send a JSON object whose ``action`` is either
        ``register`` (optionally with ``device_name``) or ``connect`` (with a
        previously issued ``key``).

        Returns:
            The client key on success, otherwise ``None``.  On an unexpected
            error the socket is closed.
        """
        try:
            # Wait for the client's registration / connection request
            data = json.loads(await websocket.recv())
            action = data.get('action') if isinstance(data, dict) else None
            now = datetime.now().isoformat()

            if action == 'register':
                device_name = data.get('device_name', 'unknown')
                key = self.generate_key()

                # Deliver the key first: if the send fails, no orphaned key
                # or stale socket is left behind in the registry.
                await websocket.send(json.dumps({
                    'status': 'success',
                    'key': key,
                    'message': '注册成功'
                }))

                self.registered_keys[key] = {
                    'device_name': device_name,
                    'created_at': now,
                    'last_seen': now
                }
                self.client_connections[key].add(websocket)
                logger.info(f"新客户端注册: {device_name}, key: {key[:8]}...")
                self.save_registered_keys()  # Persist the newly issued key
                return key

            if action == 'connect':
                key = data.get('key')
                if not isinstance(key, str) or key not in self.registered_keys:
                    await websocket.send(json.dumps({
                        'status': 'error',
                        'message': '无效的密钥'
                    }))
                    return None

                await websocket.send(json.dumps({
                    'status': 'success',
                    'message': '连接成功'
                }))
                self.client_connections[key].add(websocket)
                self.registered_keys[key]['last_seen'] = now
                logger.info(f"客户端重连成功: {key[:8]}...")
                self.save_registered_keys()  # Persist the updated last_seen
                return key

            # Unknown or missing action: tell the client instead of silently
            # dropping the connection.
            await websocket.send(json.dumps({
                'status': 'error',
                'message': '未知的操作'
            }))
            return None

        except Exception as e:
            logger.error(f"注册失败: {e}")
            await websocket.close()
            return None

    async def handle_websocket(self, websocket):
        """Serve one client connection until it closes or stops responding.

        After the registration handshake the loop answers client ``"ping"``
        messages with ``"pong"``.  If nothing arrives for
        ``HEARTBEAT_INTERVAL`` seconds the server sends ``"ping"`` and drops
        the connection unless any message arrives within
        ``HEARTBEAT_TIMEOUT`` seconds.
        """
        key = await self.register_client(websocket)
        if not key:
            return

        try:
            awaiting_reply = False
            while True:
                timeout = (self.HEARTBEAT_TIMEOUT if awaiting_reply
                           else self.HEARTBEAT_INTERVAL)
                try:
                    message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
                except asyncio.TimeoutError:
                    if awaiting_reply:
                        # The probe went unanswered: give up on this client.
                        break
                    try:
                        await websocket.send("ping")
                    except Exception:
                        # Probe could not be sent.  Deliberately not a bare
                        # ``except`` so task cancellation still propagates.
                        break
                    awaiting_reply = True
                    continue

                # Any traffic counts as a sign of life, and a client "ping"
                # is answered even if it arrives while a probe is pending.
                awaiting_reply = False
                if message == "ping":
                    await websocket.send("pong")
        except (websockets.exceptions.ConnectionClosed, ConnectionError):
            # Peer went away; cleanup happens below.
            pass
        finally:
            self.client_connections[key].discard(websocket)
            if not self.client_connections[key]:
                logger.info(f"客户端完全断开: {key[:8]}...")

    async def send_notification(self, key, notification):
        """Send *notification* to every socket connected under *key*.

        Sockets whose send fails are removed from the connection set.

        Returns:
            ``True`` if at least one connection received the notification,
            ``False`` if there were no connections or every send failed.
        """
        if key not in self.client_connections:
            return False

        connections = self.client_connections[key].copy()
        if not connections:
            return False

        payload = json.dumps(notification)  # serialize once for all sockets
        delivered = False
        for websocket in connections:
            try:
                await websocket.send(payload)
                delivered = True
            except Exception as e:
                logger.error(f"发送通知失败: {e}")
                self.client_connections[key].discard(websocket)
        return delivered

    async def handle_http_push(self, request):
        """Handle ``POST /push/{key}``.

        The body must be a JSON object with non-empty ``title`` and
        ``message`` and an optional ``type`` (default ``info``).

        Responses: 200 when delivered, 400 for a malformed body, 404 for an
        unknown key or no active connection, 500 on unexpected errors.
        """
        try:
            # The key is taken from the URL path
            key = request.match_info['key']
            if key not in self.registered_keys:
                return web.Response(status=404, text='Invalid key')

            try:
                data = await request.json()
            except ValueError:
                # json.JSONDecodeError is a ValueError: this is a client
                # error, not a server failure.
                return web.Response(status=400, text='Invalid JSON')
            if not isinstance(data, dict):
                return web.Response(status=400, text='Invalid JSON')

            if not data.get('title') or not data.get('message'):
                return web.Response(status=400, text='Missing title or message')

            notification = {
                'title': data['title'],
                'message': data['message'],
                'type': data.get('type', 'info'),
                'timestamp': datetime.now().isoformat()
            }

            success = await self.send_notification(key, notification)
            if success:
                return web.Response(text='Notification sent')
            return web.Response(status=404, text='No active connections')

        except Exception as e:
            logger.error(f"处理HTTP请求失败: {e}")
            return web.Response(status=500, text='Internal server error')

    def _build_ssl_context(self):
        """Return a server-side SSL context, or ``None`` when SSL is off."""
        security = self.config['security']
        if not security['use_ssl']:
            return None
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_context.load_cert_chain(security['ssl_cert'], security['ssl_key'])
        return ssl_context

    async def _websocket_handler(self, request):
        """aiohttp handler for ``GET /ws``: upgrade and serve the socket."""
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        await self.handle_websocket(_AiohttpWebSocketAdapter(ws))
        return ws

    async def start(self):
        """Start the HTTP/WebSocket server and run until cancelled."""
        # Set up SSL
        ssl_context = self._build_ssl_context()

        # Build the aiohttp application and register routes
        app = web.Application()
        app.router.add_get('/ws', self._websocket_handler)  # WebSocket endpoint
        app.router.add_post('/push/{key}', self.handle_http_push)  # HTTP push endpoint

        runner = web.AppRunner(app)
        await runner.setup()
        try:
            host = self.config['websocket']['host']
            port = self.config['websocket']['port']
            # aiohttp has no ``SSLSite``; TLS is enabled by handing the
            # context to ``TCPSite`` (``None`` means plain TCP).
            site = web.TCPSite(runner, host, port, ssl_context=ssl_context)
            await site.start()

            logger.info(f"服务器运行在 {self.config['websocket']['host']}:{self.config['websocket']['port']}")
            logger.info("WebSocket 路径: /ws")
            logger.info("HTTP 推送路径: /push/{key}")

            await asyncio.Future()  # Run until cancelled
        finally:
            # Release the listening socket and open connections on shutdown.
            await runner.cleanup()


if __name__ == "__main__":
    server = NotificationServer()
    asyncio.run(server.start())