notify_win/server/server.py
2025-02-28 15:11:02 +08:00

274 lines
9.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import websockets
import json
import ssl
import secrets
import aiohttp
from aiohttp import web
import yaml
import logging
from datetime import datetime
from collections import defaultdict
import os
# Configure logging: root logger at INFO level with timestamped records.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
class NotificationServer:
    """Relay notifications from HTTP POST requests to WebSocket clients.

    Clients open a WebSocket on ``/ws`` and send one JSON handshake message:
    ``{"action": "register", "device_name": ...}`` to obtain a new key, or
    ``{"action": "connect", "key": ...}`` to attach to an existing key.
    A ``POST /push/{key}`` with a JSON body containing ``title`` and
    ``message`` is then forwarded to every live connection for that key.
    """

    # Seconds of silence after which the server sends its own "ping".
    HEARTBEAT_INTERVAL = 30
    # Seconds the client gets to answer that "ping" before being dropped.
    HEARTBEAT_TIMEOUT = 10

    def __init__(self, config_path='config.yml'):
        """Load (or create) the YAML configuration and restore saved keys.

        Args:
            config_path: Path of the YAML configuration file. If the file
                does not exist, the default configuration is written to it.
        """
        defaults = self._default_config()
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            # First run: persist the defaults so there is a file to edit.
            loaded = self._default_config()
            with open(config_path, 'w', encoding='utf-8') as f:
                yaml.safe_dump(loaded, f, allow_unicode=True)
        # An empty file or a file missing a section would otherwise crash
        # on the first self.config[...] lookup; fill the gaps with defaults.
        self.config = self._with_defaults(loaded, defaults)

        # key -> set of live connection objects for that key
        self.client_connections = defaultdict(set)
        # Location of the persisted key table, taken from the configuration.
        self.keys_file = self.config['data']['keys_file']
        # key -> {'device_name', 'created_at', 'last_seen'}
        self.registered_keys = {}
        self.load_registered_keys()

    @staticmethod
    def _default_config():
        """Return a fresh copy of the built-in default configuration."""
        return {
            'websocket': {
                'host': '0.0.0.0',
                'port': 8080
            },
            'security': {
                'use_ssl': False,
                'ssl_cert': 'path/to/cert.pem',
                'ssl_key': 'path/to/key.pem'
            },
            'data': {
                'keys_file': '/app/data/registered_keys.json'
            }
        }

    @staticmethod
    def _with_defaults(loaded, defaults):
        """Return ``loaded`` with any missing section/option taken from ``defaults``."""
        config = loaded if isinstance(loaded, dict) else {}
        for section, section_defaults in defaults.items():
            section_values = config.get(section)
            if isinstance(section_values, dict):
                for option, value in section_defaults.items():
                    section_values.setdefault(option, value)
            else:
                config[section] = dict(section_defaults)
        return config

    def load_registered_keys(self):
        """Load the registered keys from ``self.keys_file`` if it exists.

        Failures are logged and leave ``self.registered_keys`` unchanged.
        """
        try:
            if not os.path.exists(self.keys_file):
                return
            with open(self.keys_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            logger.error(f"加载密钥文件失败: {e}")
            return
        if not isinstance(data, dict):
            # Every other method indexes registered_keys as a mapping.
            logger.error("加载密钥文件失败: 文件内容不是 JSON 对象")
            return
        self.registered_keys = data
        logger.info(f"已加载 {len(self.registered_keys)} 个注册密钥")

    def save_registered_keys(self):
        """Write the registered keys to ``self.keys_file`` as JSON.

        Failures are logged, not raised.
        """
        try:
            directory = os.path.dirname(self.keys_file)
            if directory:
                # The data directory may not exist yet on a fresh install.
                os.makedirs(directory, exist_ok=True)
            with open(self.keys_file, 'w', encoding='utf-8') as f:
                json.dump(self.registered_keys, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"保存密钥文件失败: {e}")

    def generate_key(self):
        """Return a new random URL-safe client key."""
        return secrets.token_urlsafe(32)

    def _drop_connection(self, key, websocket):
        """Forget ``websocket`` for ``key``.

        Returns:
            True when no connection remains for ``key`` afterwards.
        """
        remaining = self.client_connections.get(key)
        if remaining is not None:
            remaining.discard(websocket)
        if not remaining:
            # Do not let empty sets pile up for keys that are offline.
            self.client_connections.pop(key, None)
            return True
        return False

    async def _close_quietly(self, websocket):
        """Close ``websocket``, logging instead of raising if that fails."""
        try:
            await websocket.close()
        except Exception as e:
            logger.debug(f"关闭连接失败: {e}")

    async def register_client(self, websocket):
        """Process the handshake message of a new connection.

        Args:
            websocket: Object exposing async ``send``, ``recv`` and ``close``.

        Returns:
            The client key on success, otherwise ``None``.
        """
        key = None
        try:
            # Wait for the client's handshake request.
            data = json.loads(await websocket.recv())
            if not isinstance(data, dict):
                raise ValueError("handshake must be a JSON object")
            action = data.get('action')

            if action == 'register':
                device_name = data.get('device_name', 'unknown')
                new_key = self.generate_key()
                now = datetime.now().isoformat()
                # Hand the key to the client first: if this fails, the key
                # was never delivered and must not stay registered.
                await websocket.send(json.dumps({
                    'status': 'success',
                    'key': new_key,
                    'message': '注册成功'
                }))
                key = new_key
                self.registered_keys[key] = {
                    'device_name': device_name,
                    'created_at': now,
                    'last_seen': now
                }
                self.client_connections[key].add(websocket)
                logger.info(f"新客户端注册: {device_name}, key: {key[:8]}...")
                self.save_registered_keys()  # persist the new key
                return key

            if action == 'connect':
                candidate = data.get('key')
                if not isinstance(candidate, str) or candidate not in self.registered_keys:
                    await websocket.send(json.dumps({
                        'status': 'error',
                        'message': '无效的密钥'
                    }))
                    return None
                key = candidate
                self.client_connections[key].add(websocket)
                self.registered_keys[key]['last_seen'] = datetime.now().isoformat()
                await websocket.send(json.dumps({
                    'status': 'success',
                    'message': '连接成功'
                }))
                logger.info(f"客户端重连成功: {key[:8]}...")
                return key

            # Unknown action: tell the client instead of silently dropping it.
            await websocket.send(json.dumps({
                'status': 'error',
                'message': '无效的操作'
            }))
            return None
        except Exception as e:
            logger.error(f"注册失败: {e}")
            if key is not None:
                # The socket was already tracked; do not leave it behind.
                self._drop_connection(key, websocket)
            await self._close_quietly(websocket)
            return None

    async def _client_is_alive(self, websocket):
        """Send "ping" and report whether any reply arrives in time."""
        try:
            await websocket.send("ping")
            await asyncio.wait_for(
                websocket.recv(), timeout=self.HEARTBEAT_TIMEOUT
            )
        except Exception:
            # Timeout, closed connection or send failure: treat as dead.
            # Task cancellation is not an Exception and still propagates.
            return False
        return True

    async def handle_websocket(self, websocket):
        """Serve one client connection until it closes or stops responding.

        Args:
            websocket: Object exposing async ``send``, ``recv`` and ``close``.
        """
        key = await self.register_client(websocket)
        if not key:
            return
        try:
            # Heartbeat loop.
            while True:
                try:
                    message = await asyncio.wait_for(
                        websocket.recv(), timeout=self.HEARTBEAT_INTERVAL
                    )
                except asyncio.TimeoutError:
                    # Nothing received for a while: probe the client.
                    if not await self._client_is_alive(websocket):
                        await self._close_quietly(websocket)
                        break
                    continue
                if message == "ping":
                    await websocket.send("pong")
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            if self._drop_connection(key, websocket):
                logger.info(f"客户端完全断开: {key[:8]}...")

    async def send_notification(self, key, notification):
        """Send ``notification`` to every live connection registered for ``key``.

        Connections whose send fails are dropped.

        Args:
            key: Client key whose connections should receive the message.
            notification: JSON-serialisable payload.

        Returns:
            True if at least one connection accepted the message, else False.
        """
        connections = self.client_connections.get(key)
        if not connections:
            return False
        payload = json.dumps(notification)
        delivered = 0
        # Iterate over a snapshot: the set may change while we await.
        for websocket in list(connections):
            try:
                await websocket.send(payload)
                delivered += 1
            except Exception as e:
                logger.error(f"发送通知失败: {e}")
                connections.discard(websocket)
        return delivered > 0

    async def handle_http_push(self, request):
        """Handle ``POST /push/{key}`` and forward the notification.

        Responds 404 for an unknown key or when no connection received the
        message, 400 for a malformed body, 500 for unexpected failures.
        """
        try:
            # The key comes from the URL path.
            key = request.match_info['key']
            if key not in self.registered_keys:
                return web.Response(status=404, text='Invalid key')
            try:
                data = await request.json()
            except ValueError:
                # Malformed JSON is the caller's mistake, not a server error.
                return web.Response(status=400, text='Invalid JSON body')
            if not isinstance(data, dict):
                return web.Response(status=400, text='Invalid JSON body')
            if not data.get('title') or not data.get('message'):
                return web.Response(status=400, text='Missing title or message')
            notification = {
                'title': data['title'],
                'message': data['message'],
                'type': data.get('type', 'info'),
                'timestamp': datetime.now().isoformat()
            }
            if await self.send_notification(key, notification):
                return web.Response(text='Notification sent')
            return web.Response(status=404, text='No active connections')
        except Exception as e:
            logger.error(f"处理HTTP请求失败: {e}")
            return web.Response(status=500, text='Internal server error')

    async def start(self):
        """Start the HTTP/WebSocket server and run until cancelled."""
        # Optional TLS.
        ssl_context = None
        if self.config['security']['use_ssl']:
            ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            ssl_context.load_cert_chain(
                self.config['security']['ssl_cert'],
                self.config['security']['ssl_key']
            )

        app = web.Application()

        async def websocket_handler(request):
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            await self.handle_websocket(_AiohttpWebSocketAdapter(ws))
            return ws

        # Routes.
        app.router.add_get('/ws', websocket_handler)  # WebSocket endpoint
        app.router.add_post('/push/{key}', self.handle_http_push)  # HTTP push endpoint

        runner = web.AppRunner(app)
        await runner.setup()
        try:
            host = self.config['websocket']['host']
            port = self.config['websocket']['port']
            # Pick the site type depending on whether TLS is enabled.
            if ssl_context:
                site = web.SSLSite(runner, host, port, ssl_context=ssl_context)
            else:
                site = web.TCPSite(runner, host, port)
            await site.start()
            logger.info(f"服务器运行在 {self.config['websocket']['host']}:{self.config['websocket']['port']}")
            logger.info("WebSocket 路径: /ws")
            logger.info("HTTP 推送路径: /push/{key}")
            await asyncio.Future()  # run until cancelled
        finally:
            # Release the listening socket and open connections on shutdown.
            await runner.cleanup()


class _AiohttpWebSocketAdapter:
    """Expose an aiohttp ``WebSocketResponse`` via ``send``/``recv``/``close``.

    This is the interface ``NotificationServer.handle_websocket`` expects.
    """

    def __init__(self, ws):
        self._ws = ws

    async def send(self, message):
        """Send ``message`` as a text frame."""
        await self._ws.send_str(message)

    async def recv(self):
        """Return the next text frame's data.

        Raises:
            websockets.exceptions.ConnectionClosed: for any message type that
                signals the connection is closing, closed or failed.
        """
        msg = await self._ws.receive()
        if msg.type == web.WSMsgType.TEXT:
            return msg.data
        if msg.type in (
            web.WSMsgType.CLOSE,
            web.WSMsgType.CLOSING,
            web.WSMsgType.CLOSED,
            web.WSMsgType.ERROR,
        ):
            raise websockets.exceptions.ConnectionClosed(None, None)
        # Other frame types (e.g. binary) carry nothing this protocol uses.
        return None

    async def close(self):
        """Close the underlying WebSocket."""
        await self._ws.close()
def main():
    """Create the notification server and run it until interrupted."""
    server = NotificationServer()
    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the server; exit without a traceback.
        logger.info("服务器已停止")


if __name__ == "__main__":
    main()