# notify_win/server/server.py
# Captured from a repository web view ("Raw | Normal View | History"):
# 274 lines, 9.7 KiB, Python, timestamp 2025-02-28 15:11:02 +08:00.
import asyncio
import websockets
import json
import ssl
import secrets
import aiohttp
from aiohttp import web
import yaml
import logging
from datetime import datetime
from collections import defaultdict
import os
# Process-wide logging setup: INFO and above, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the server code in this file.
logger = logging.getLogger(__name__)
class NotificationServer:
    """Push-notification hub: clients hold a WebSocket open, senders POST over HTTP.

    A client's first WebSocket message is a JSON handshake, either
    ``{"action": "register", "device_name": ...}`` (the server issues a new key)
    or ``{"action": "connect", "key": ...}`` (re-attach to a known key).
    ``POST /push/{key}`` then relays a JSON notification to every live
    connection attached to that key.

    Registered keys are persisted to the JSON file named by ``data.keys_file``
    in the config; live connections are held in memory only.
    """

    # Seconds of client silence before the server probes it with "ping".
    HEARTBEAT_INTERVAL = 30
    # Seconds the client has to send anything back after that probe.
    HEARTBEAT_TIMEOUT = 10

    class _AiohttpWebSocketAdapter:
        """Expose an aiohttp ``WebSocketResponse`` through the ``send``/``recv``/``close``
        subset of the ``websockets`` connection API used by the handlers below."""

        def __init__(self, ws):
            self._ws = ws

        async def send(self, message):
            await self._ws.send_str(message)

        async def recv(self):
            """Return the next text frame's data.

            aiohttp signals the end of a connection with CLOSE, CLOSING, CLOSED or
            ERROR messages rather than an exception; all of them are turned into
            ``ConnectionClosed`` so callers never see ``None``.  Other frame types
            (e.g. BINARY) are skipped.
            """
            closed_types = (
                web.WSMsgType.CLOSE,
                web.WSMsgType.CLOSING,
                web.WSMsgType.CLOSED,
                web.WSMsgType.ERROR,
            )
            while True:
                msg = await self._ws.receive()
                if msg.type == web.WSMsgType.TEXT:
                    return msg.data
                if msg.type in closed_types:
                    # NOTE(review): the (None, None) arguments match the
                    # ConnectionClosed(rcvd, sent) signature of websockets>=10;
                    # confirm against the pinned websockets version.
                    raise websockets.exceptions.ConnectionClosed(None, None)

        async def close(self):
            await self._ws.close()

    def __init__(self, config_path='config.yml'):
        """Load configuration and the persisted key registry.

        :param config_path: YAML config file.  If it does not exist, the default
            configuration is written there and used.
        """
        self.config = self._load_config(config_path)
        # key -> set of live connection objects attached to that key
        self.client_connections = defaultdict(set)
        self.keys_file = self.config['data']['keys_file']
        # key -> {'device_name', 'created_at', 'last_seen'} (see register_client)
        self.registered_keys = {}
        self.load_registered_keys()

    @staticmethod
    def _default_config():
        """Return a fresh copy of the built-in default configuration."""
        return {
            'websocket': {
                'host': '0.0.0.0',
                'port': 8080,
            },
            'security': {
                'use_ssl': False,
                'ssl_cert': 'path/to/cert.pem',
                'ssl_key': 'path/to/key.pem',
            },
            'data': {
                'keys_file': '/app/data/registered_keys.json',
            },
        }

    @classmethod
    def _with_defaults(cls, loaded):
        """Fill every section/key missing from a parsed config with its default.

        :param loaded: result of ``yaml.safe_load``; ``None`` (an empty file) is
            treated as an empty mapping.  Unknown top-level keys are kept.
        :raises ValueError: if the document root is not a mapping.
        """
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(f"config root must be a mapping, got {type(loaded).__name__}")
        config = dict(loaded)
        for section, section_defaults in cls._default_config().items():
            # An empty section (e.g. "data:" with nothing under it) parses as None.
            config[section] = {**section_defaults, **(loaded.get(section) or {})}
        return config

    @classmethod
    def _load_config(cls, config_path):
        """Read the YAML config, creating it from the defaults when it is missing."""
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            config = cls._default_config()
            # Write the defaults out so the operator has a file to edit.
            with open(config_path, 'w', encoding='utf-8') as f:
                yaml.safe_dump(config, f, allow_unicode=True)
            return config
        return cls._with_defaults(loaded)

    def load_registered_keys(self):
        """Load the persisted key registry from ``self.keys_file``.

        A missing file leaves ``self.registered_keys`` unchanged.  If the file
        cannot be read or does not hold a JSON object, the error is logged and
        ``self.registered_keys`` is also left unchanged.
        """
        # NOTE(review): after a failed load, the next save_registered_keys() call
        # overwrites the file with only the keys registered since startup.
        try:
            if os.path.exists(self.keys_file):
                with open(self.keys_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if not isinstance(data, dict):
                    raise ValueError(f"expected a JSON object, got {type(data).__name__}")
                self.registered_keys = data
            logger.info(f"已加载 {len(self.registered_keys)} 个注册密钥")
        except Exception as e:
            logger.error(f"加载密钥文件失败: {e}")

    def save_registered_keys(self):
        """Persist ``self.registered_keys`` to ``self.keys_file`` as JSON.

        The parent directory is created if needed.  Data is written to a sibling
        ``.tmp`` file and then moved over the target with ``os.replace``, so a
        crash mid-write cannot leave a truncated registry.  Failures are logged,
        not raised: persistence is best-effort.
        """
        try:
            directory = os.path.dirname(os.path.abspath(self.keys_file))
            os.makedirs(directory, exist_ok=True)
            tmp_path = f"{self.keys_file}.tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.registered_keys, f, indent=2)
            os.replace(tmp_path, self.keys_file)
        except Exception as e:
            logger.error(f"保存密钥文件失败: {e}")

    def generate_key(self):
        """Return a new random, URL-safe client key (32 random bytes)."""
        return secrets.token_urlsafe(32)

    def _discard_connection(self, key, websocket):
        """Detach ``websocket`` from ``key``; drop the key's entry once it is empty.

        :returns: True when no live connection remains for ``key``.
        """
        connections = self.client_connections.get(key)
        if connections is None:
            return True
        connections.discard(websocket)
        if not connections:
            # Don't let empty sets pile up in the defaultdict.
            del self.client_connections[key]
            return True
        return False

    async def _reject(self, websocket, message):
        """Send a ``{"status": "error"}`` reply and close the connection."""
        await websocket.send(json.dumps({
            'status': 'error',
            'message': message
        }))
        await websocket.close()

    async def register_client(self, websocket):
        """Run the handshake on a freshly opened connection.

        The first message must be a JSON object with ``action`` equal to
        ``'register'`` (optional ``device_name``) or ``'connect'`` (``key``).
        On success the connection is added to ``client_connections`` and the key
        is returned.  Otherwise an error reply is attempted, the connection is
        closed, and ``None`` is returned.
        """
        try:
            data = json.loads(await websocket.recv())
            if not isinstance(data, dict):
                raise ValueError(f"handshake must be a JSON object, got {type(data).__name__}")
            action = data.get('action')

            if action == 'register':
                device_name = data.get('device_name', 'unknown')
                key = self.generate_key()
                now = datetime.now().isoformat()
                self.registered_keys[key] = {
                    'device_name': device_name,
                    'created_at': now,
                    'last_seen': now
                }
                try:
                    await websocket.send(json.dumps({
                        'status': 'success',
                        'key': key,
                        'message': '注册成功'
                    }))
                except Exception:
                    # The client never received this key, so don't keep it.
                    del self.registered_keys[key]
                    raise
                # Attach only once the handshake reply went out, so a failed
                # handshake never leaves a dead socket in client_connections.
                self.client_connections[key].add(websocket)
                logger.info(f"新客户端注册: {device_name}, key: {key[:8]}...")
                self.save_registered_keys()
                return key

            if action == 'connect':
                key = data.get('key')
                if isinstance(key, str) and key in self.registered_keys:
                    self.registered_keys[key]['last_seen'] = datetime.now().isoformat()
                    await websocket.send(json.dumps({
                        'status': 'success',
                        'message': '连接成功'
                    }))
                    self.client_connections[key].add(websocket)
                    logger.info(f"客户端重连成功: {key[:8]}...")
                    # Persist last_seen so it survives a restart.
                    self.save_registered_keys()
                    return key
                await self._reject(websocket, '无效的密钥')
                return None

            await self._reject(websocket, '未知的操作')
            return None
        except Exception as e:
            logger.error(f"注册失败: {e}")
            await websocket.close()
            return None

    async def handle_websocket(self, websocket):
        """Serve one client connection for its whole lifetime.

        After a successful handshake, every "ping" from the client is answered
        with "pong".  If the client is silent for HEARTBEAT_INTERVAL seconds,
        the server sends "ping" itself.  If nothing arrives within
        HEARTBEAT_TIMEOUT seconds after that, the server stops serving the
        connection.
        """
        key = await self.register_client(websocket)
        if not key:
            return
        try:
            while True:
                try:
                    message = await asyncio.wait_for(
                        websocket.recv(), timeout=self.HEARTBEAT_INTERVAL
                    )
                except asyncio.TimeoutError:
                    # Idle too long: probe the client and give it a short window.
                    await websocket.send("ping")
                    try:
                        message = await asyncio.wait_for(
                            websocket.recv(), timeout=self.HEARTBEAT_TIMEOUT
                        )
                    except asyncio.TimeoutError:
                        break
                # Also covers a client "ping" that arrives as the reply to our probe.
                if message == "ping":
                    await websocket.send("pong")
        except websockets.exceptions.ConnectionClosed:
            pass
        except Exception as e:
            # e.g. writing to a transport that is already closing.  Cancellation
            # (a BaseException) is deliberately not caught here.
            logger.warning(f"客户端连接异常: {key[:8]}..., {e}")
        finally:
            if self._discard_connection(key, websocket):
                logger.info(f"客户端完全断开: {key[:8]}...")

    async def send_notification(self, key, notification):
        """Send ``notification`` (JSON-serialised) to every live connection of ``key``.

        Connections whose send fails are detached.

        :returns: True if at least one connection received the notification.
        """
        connections = list(self.client_connections.get(key, ()))
        if not connections:
            return False
        # Serialise once.  This also keeps a serialisation error from being
        # mistaken for per-connection failures that would detach every client.
        payload = json.dumps(notification)
        delivered = 0
        for websocket in connections:
            try:
                await websocket.send(payload)
            except Exception as e:
                logger.error(f"发送通知失败: {e}")
                self._discard_connection(key, websocket)
            else:
                delivered += 1
        return delivered > 0

    async def handle_http_push(self, request):
        """``POST /push/{key}``: relay a JSON notification to the key's live clients.

        The body must be a JSON object with non-empty ``title`` and ``message``
        and an optional ``type`` (default ``'info'``).

        Responses:
          * 200 when at least one client received the notification.
          * 404 for an unknown key, or when no client received it.
          * 400 for a malformed body.
          * 500 on unexpected errors.
        """
        # NOTE(review): the key travels in the URL path, so it may appear in
        # access logs or proxy logs; consider moving it to a header.
        try:
            key = request.match_info['key']
            if key not in self.registered_keys:
                return web.Response(status=404, text='Invalid key')
            try:
                data = await request.json()
            except ValueError:
                # json.JSONDecodeError / UnicodeDecodeError: a client error, not a 500.
                return web.Response(status=400, text='Invalid JSON body')
            if not isinstance(data, dict):
                return web.Response(status=400, text='Body must be a JSON object')
            if not data.get('title') or not data.get('message'):
                return web.Response(status=400, text='Missing title or message')
            notification = {
                'title': data['title'],
                'message': data['message'],
                'type': data.get('type', 'info'),
                'timestamp': datetime.now().isoformat()
            }
            if await self.send_notification(key, notification):
                return web.Response(text='Notification sent')
            return web.Response(status=404, text='No active connections')
        except Exception as e:
            logger.error(f"处理HTTP请求失败: {e}")
            return web.Response(status=500, text='Internal server error')

    def _create_ssl_context(self):
        """Build the server TLS context from ``config['security']``.

        Returns None when ``use_ssl`` is false.
        """
        security = self.config['security']
        if not security['use_ssl']:
            return None
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_context.load_cert_chain(security['ssl_cert'], security['ssl_key'])
        return ssl_context

    async def _websocket_handler(self, request):
        """aiohttp handler for ``GET /ws``: upgrade, then hand off to handle_websocket."""
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        await self.handle_websocket(self._AiohttpWebSocketAdapter(ws))
        return ws

    async def start(self):
        """Serve ``GET /ws`` and ``POST /push/{key}`` until this task is cancelled."""
        ssl_context = self._create_ssl_context()
        host = self.config['websocket']['host']
        port = self.config['websocket']['port']

        app = web.Application()
        app.router.add_get('/ws', self._websocket_handler)  # WebSocket endpoint
        app.router.add_post('/push/{key}', self.handle_http_push)  # HTTP push endpoint

        runner = web.AppRunner(app)
        await runner.setup()
        try:
            # aiohttp has no separate SSL site class.  TLS is enabled by passing
            # ssl_context to TCPSite, and None means plain TCP.
            site = web.TCPSite(runner, host, port, ssl_context=ssl_context)
            await site.start()
            logger.info(f"服务器运行在 {host}:{port}")
            logger.info("WebSocket 路径: /ws")
            logger.info("HTTP 推送路径: /push/{key}")
            await asyncio.Future()  # run until cancelled
        finally:
            await runner.cleanup()
if __name__ == "__main__":
    server = NotificationServer()
    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the server; exit without a traceback.
        pass