# Listing header from the original paste: 274 lines, 9.7 KiB, Python.
import asyncio
import json
import logging
import os
import secrets
import ssl
import tempfile
from collections import defaultdict
from datetime import datetime

import aiohttp
import websockets
import yaml
from aiohttp import web

# Root logging configuration: INFO and above, "timestamp - level - message".
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the server code in this file.
logger = logging.getLogger(__name__)


class NotificationServer:
    """WebSocket push-notification server with an HTTP push endpoint.

    Clients open a WebSocket on ``/ws`` and either register (receiving a newly
    generated key) or reconnect with a key they already hold. Other services
    call ``POST /push/{key}`` with a JSON body; the notification is forwarded
    to every live WebSocket connection authenticated with that key.
    Registered keys are persisted as JSON in ``config['data']['keys_file']``.
    """

    # Seconds of client silence before the server sends a "ping" probe.
    HEARTBEAT_INTERVAL = 30
    # Seconds to wait for any message after a probe before dropping the client.
    HEARTBEAT_REPLY_TIMEOUT = 10
    # Seconds a freshly opened socket has to send its register/connect request.
    REGISTER_TIMEOUT = 30

    def __init__(self, config_path='config.yml'):
        """Load the YAML configuration and any previously registered keys.

        Args:
            config_path: Path of the YAML config file. If it does not exist,
                the built-in defaults are used and written to this path (best
                effort). Sections or keys missing from an existing file are
                filled in from the defaults, so an empty or partial file works.

        Raises:
            ValueError: If the file's top level, or one of the known sections
                (``websocket``, ``security``, ``data``), is not a mapping.
        """
        defaults = self._default_config()
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            # No config file yet: run with the defaults and write them out so
            # the operator has a template to edit.
            self.config = defaults
            self._write_default_config(config_path)
        else:
            self.config = self._merge_config(defaults, loaded)

        # key -> set of live connections that authenticated with that key.
        self.client_connections = defaultdict(set)
        # Location of the persisted key store, taken from the config.
        self.keys_file = self.config['data']['keys_file']
        # key -> client info dict (register_client writes
        # 'device_name', 'created_at', 'last_seen').
        self.registered_keys = {}
        self.load_registered_keys()

    @staticmethod
    def _default_config():
        """Return a fresh copy of the built-in default configuration."""
        return {
            'websocket': {
                'host': '0.0.0.0',
                'port': 8080,
            },
            'security': {
                'use_ssl': False,
                'ssl_cert': 'path/to/cert.pem',
                'ssl_key': 'path/to/key.pem',
            },
            'data': {
                'keys_file': '/app/data/registered_keys.json',
            },
        }

    @staticmethod
    def _merge_config(defaults, loaded):
        """Overlay a parsed YAML document onto *defaults*, section by section.

        ``None`` (what ``yaml.safe_load`` returns for an empty file) yields
        *defaults* unchanged. Keys present in *loaded* override the default
        for that section; unknown top-level keys in *loaded* are kept as-is.

        Raises:
            ValueError: If *loaded* or one of its known sections is not a dict.
        """
        if loaded is None:
            return defaults
        if not isinstance(loaded, dict):
            raise ValueError("配置文件格式错误: 顶层必须是映射")
        merged = dict(loaded)
        for section, section_defaults in defaults.items():
            overrides = loaded.get(section)
            if overrides is None:
                merged[section] = section_defaults
            elif isinstance(overrides, dict):
                merged[section] = {**section_defaults, **overrides}
            else:
                raise ValueError(f"配置文件格式错误: {section} 必须是映射")
        return merged

    def _write_default_config(self, config_path):
        """Best-effort write of ``self.config`` to *config_path* as YAML."""
        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                yaml.safe_dump(self.config, f, allow_unicode=True)
        except OSError as e:
            # An unwritable config location should not stop the server from
            # starting with the defaults it already has in memory.
            logger.warning(f"无法保存默认配置文件 {config_path}: {e}")

    def load_registered_keys(self):
        """Load registered keys from ``self.keys_file`` if that file exists.

        A missing file leaves ``self.registered_keys`` unchanged. A file that
        cannot be read, is not valid JSON, or whose top level is not a JSON
        object is logged and ignored.
        """
        if not os.path.exists(self.keys_file):
            return
        try:
            with open(self.keys_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            logger.error(f"加载密钥文件失败: {e}")
            return
        if not isinstance(data, dict):
            logger.error(f"加载密钥文件失败: 顶层不是 JSON 对象 ({type(data).__name__})")
            return
        self.registered_keys = data
        logger.info(f"已加载 {len(self.registered_keys)} 个注册密钥")

    def save_registered_keys(self):
        """Persist ``self.registered_keys`` to ``self.keys_file`` atomically.

        The JSON is written to a temporary file in the target directory and
        then moved over the target with ``os.replace``, so a crash mid-write
        cannot leave a truncated key store. Missing parent directories are
        created. ``tempfile.mkstemp`` creates the file readable/writable by the
        owner only, and that mode carries over to the replaced keys file.
        Failures are logged, not raised.
        """
        directory = os.path.dirname(os.path.abspath(self.keys_file))
        tmp_path = None
        try:
            os.makedirs(directory, exist_ok=True)
            fd, tmp_path = tempfile.mkstemp(dir=directory, prefix='.keys-', suffix='.tmp')
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                json.dump(self.registered_keys, f, indent=2)
            os.replace(tmp_path, self.keys_file)
            tmp_path = None  # now owned by keys_file; nothing to clean up
        except Exception as e:
            logger.error(f"保存密钥文件失败: {e}")
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Leftover temp file is harmless; the error was already logged.
                    pass

    def generate_key(self):
        """Return a new random, URL-safe client key (32 bytes of entropy)."""
        return secrets.token_urlsafe(32)

    async def register_client(self, websocket):
        """Run the handshake on a new connection: register or reconnect.

        The client must send, within ``REGISTER_TIMEOUT`` seconds, a JSON
        object whose ``action`` is:

        * ``"register"`` (optional ``device_name``): a new key is generated,
          stored, sent back as ``{"status": "success", "key": ...}`` and the
          key store is saved.
        * ``"connect"`` with a string ``key``: a known key is re-attached and
          its ``last_seen`` timestamp updated.

        Any other action, or an unknown key, gets an ``{"status": "error"}``
        reply.

        Args:
            websocket: Object with async ``recv()``, ``send(str)`` and
                ``close()``.

        Returns:
            The client's key on success, otherwise ``None``. If the handshake
            itself fails (timeout, malformed message, closed socket), the
            socket is closed before returning ``None``.
        """
        try:
            msg = await asyncio.wait_for(websocket.recv(), timeout=self.REGISTER_TIMEOUT)
            data = json.loads(msg)
            if not isinstance(data, dict):
                raise ValueError("注册请求必须是 JSON 对象")
            action = data.get('action')

            if action == 'register':
                device_name = data.get('device_name', 'unknown')
                key = self.generate_key()
                now = datetime.now().isoformat()
                self.registered_keys[key] = {
                    'device_name': device_name,
                    'created_at': now,
                    'last_seen': now,
                }
                self.client_connections[key].add(websocket)

                # Hand the new key to the client.
                await websocket.send(json.dumps({
                    'status': 'success',
                    'key': key,
                    'message': '注册成功'
                }))

                logger.info(f"新客户端注册: {device_name}, key: {key[:8]}...")
                self.save_registered_keys()
                return key

            if action == 'connect':
                key = data.get('key')
                # Non-string keys (lists, dicts) can never be registered; the
                # isinstance check also avoids a TypeError on unhashable values.
                if isinstance(key, str) and key in self.registered_keys:
                    self.client_connections[key].add(websocket)
                    self.registered_keys[key]['last_seen'] = datetime.now().isoformat()
                    await websocket.send(json.dumps({
                        'status': 'success',
                        'message': '连接成功'
                    }))
                    logger.info(f"客户端重连成功: {key[:8]}...")
                    return key
                await websocket.send(json.dumps({
                    'status': 'error',
                    'message': '无效的密钥'
                }))
                return None

            await websocket.send(json.dumps({
                'status': 'error',
                'message': '未知的操作'
            }))
            return None

        except asyncio.TimeoutError:
            logger.error(f"注册失败: {self.REGISTER_TIMEOUT} 秒内未收到注册请求")
        except Exception as e:
            logger.error(f"注册失败: {e}")
        # Only reached when the handshake raised.
        await websocket.close()
        return None

    async def handle_websocket(self, websocket):
        """Serve one client connection until it disconnects or goes silent.

        After a successful ``register_client`` handshake, text ``"ping"``
        messages are answered with ``"pong"``. If the client sends nothing for
        ``HEARTBEAT_INTERVAL`` seconds, the server sends ``"ping"`` itself and
        drops the connection unless some message arrives within
        ``HEARTBEAT_REPLY_TIMEOUT`` seconds. On exit the socket is removed from
        ``client_connections``, and the key's entry is removed once it has no
        connections left.
        """
        key = await self.register_client(websocket)
        if not key:
            return

        try:
            while True:
                try:
                    message = await asyncio.wait_for(
                        websocket.recv(), timeout=self.HEARTBEAT_INTERVAL)
                except asyncio.TimeoutError:
                    # Client idle: probe it and give it a short window to answer.
                    try:
                        await websocket.send("ping")
                        message = await asyncio.wait_for(
                            websocket.recv(), timeout=self.HEARTBEAT_REPLY_TIMEOUT)
                    except Exception:
                        # Send failed, socket closed, or no reply in time: drop
                        # the client. Unlike a bare `except:`, this lets task
                        # cancellation propagate.
                        break
                # The reply to a probe may itself be a client "ping".
                if message == "ping":
                    await websocket.send("pong")
        except (websockets.exceptions.ConnectionClosed, ConnectionError):
            # Peer went away (close frame, or a write to a closing transport).
            pass
        finally:
            connections = self.client_connections.get(key)
            if connections is not None:
                connections.discard(websocket)
                if not connections:
                    del self.client_connections[key]
                    logger.info(f"客户端完全断开: {key[:8]}...")

    async def send_notification(self, key, notification):
        """Send *notification* to every live connection of *key*.

        Connections whose send fails are removed from ``client_connections``.

        Args:
            key: Client key.
            notification: JSON-serialisable object sent as a text message.

        Returns:
            True if at least one connection received the notification; False if
            the key has no live connections or every send failed.

        Raises:
            TypeError: If *notification* is not JSON-serialisable.
        """
        connections = self.client_connections.get(key)
        if not connections:
            return False

        payload = json.dumps(notification)
        delivered = 0
        # Iterate over a snapshot: each send yields to the event loop, during
        # which other handlers may add or remove connections for this key.
        for websocket in list(connections):
            try:
                await websocket.send(payload)
            except Exception as e:
                logger.error(f"发送通知失败: {e}")
                # .get(): don't resurrect the key via the defaultdict if its
                # handler already removed it while we were awaiting.
                live = self.client_connections.get(key)
                if live is not None:
                    live.discard(websocket)
            else:
                delivered += 1
        return delivered > 0

    async def handle_http_push(self, request):
        """aiohttp handler for ``POST /push/{key}``.

        Expects a JSON object body with non-empty ``title`` and ``message`` and
        an optional ``type`` (default ``"info"``); a server-side ``timestamp``
        is added before forwarding.

        Responses:
            200 ``Notification sent``: delivered to at least one connection.
            400: body is not a JSON object, or ``title``/``message`` missing.
            404: unknown key, or no connection accepted the notification.
            500: unexpected error (logged).
        """
        try:
            # The key comes from the URL path.
            key = request.match_info['key']
            if key not in self.registered_keys:
                return web.Response(status=404, text='Invalid key')

            try:
                data = await request.json()
            except ValueError:
                # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
                return web.Response(status=400, text='Body must be a JSON object')
            if not isinstance(data, dict):
                return web.Response(status=400, text='Body must be a JSON object')
            if not data.get('title') or not data.get('message'):
                return web.Response(status=400, text='Missing title or message')

            notification = {
                'title': data['title'],
                'message': data['message'],
                'type': data.get('type', 'info'),
                'timestamp': datetime.now().isoformat()
            }

            if await self.send_notification(key, notification):
                return web.Response(text='Notification sent')
            return web.Response(status=404, text='No active connections')

        except Exception as e:
            logger.error(f"处理HTTP请求失败: {e}")
            return web.Response(status=500, text='Internal server error')

    class _AiohttpWebSocket:
        """Adapt an aiohttp ``WebSocketResponse`` to the minimal interface
        (``send`` / ``recv`` / ``close``) that ``handle_websocket`` expects."""

        def __init__(self, ws):
            self._ws = ws

        async def send(self, message):
            """Send *message* as a text frame."""
            await self._ws.send_str(message)

        async def recv(self):
            """Return the next text message.

            Raises ``websockets.exceptions.ConnectionClosed`` for every message
            type that means the socket is finished (CLOSE, CLOSING, CLOSED,
            ERROR), not just CLOSE. Other frames (e.g. binary) yield ``None``,
            which callers ignore.
            """
            msg = await self._ws.receive()
            if msg.type == web.WSMsgType.TEXT:
                return msg.data
            if msg.type in (web.WSMsgType.CLOSE, web.WSMsgType.CLOSING,
                            web.WSMsgType.CLOSED, web.WSMsgType.ERROR):
                raise websockets.exceptions.ConnectionClosed(None, None)
            return None

        async def close(self):
            """Close the underlying aiohttp WebSocket."""
            await self._ws.close()

    async def _websocket_handler(self, request):
        """aiohttp handler for ``GET /ws``: upgrade, then serve the client."""
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        await self.handle_websocket(self._AiohttpWebSocket(ws))
        return ws

    async def start(self):
        """Run the HTTP/WebSocket server until this task is cancelled.

        Serves ``GET /ws`` (WebSocket) and ``POST /push/{key}`` on
        ``config['websocket']['host']:['port']``. When
        ``config['security']['use_ssl']`` is true, the same listener serves TLS
        with ``ssl_cert`` / ``ssl_key``. The aiohttp runner is cleaned up on
        exit.
        """
        security = self.config['security']
        ssl_context = None
        if security['use_ssl']:
            ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            ssl_context.load_cert_chain(security['ssl_cert'], security['ssl_key'])

        app = web.Application()
        app.router.add_get('/ws', self._websocket_handler)  # WebSocket endpoint
        app.router.add_post('/push/{key}', self.handle_http_push)  # HTTP push endpoint

        runner = web.AppRunner(app)
        await runner.setup()
        try:
            host = self.config['websocket']['host']
            port = self.config['websocket']['port']
            # aiohttp has no SSLSite: TLS is enabled by handing TCPSite an SSL
            # context (None means plain TCP).
            site = web.TCPSite(runner, host, port, ssl_context=ssl_context)
            await site.start()

            logger.info(f"服务器运行在 {host}:{port}")
            logger.info("WebSocket 路径: /ws")
            logger.info("HTTP 推送路径: /push/{key}")

            await asyncio.Future()  # run until cancelled
        finally:
            await runner.cleanup()


if __name__ == "__main__":
    # Script entry point: build the server (config path defaults to
    # config.yml) and run it until the process is interrupted.
    asyncio.run(NotificationServer().start())