274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
|
import asyncio
|
|||
|
import websockets
|
|||
|
import json
|
|||
|
import ssl
|
|||
|
import secrets
|
|||
|
import aiohttp
|
|||
|
from aiohttp import web
|
|||
|
import yaml
|
|||
|
import logging
|
|||
|
from datetime import datetime
|
|||
|
from collections import defaultdict
|
|||
|
import os
|
|||
|
|
|||
|
# Configure root logging at import time: INFO and above, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used throughout the server.
logger = logging.getLogger(__name__)
|
|||
|
|
|||
|
class NotificationServer:
    """WebSocket + HTTP push notification server.

    Clients connect over WebSocket at ``/ws`` and either register (receiving a
    freshly generated key) or reconnect with an existing key. Producers POST a
    JSON body with ``title``/``message`` (and optional ``type``) to
    ``/push/{key}``; the notification is forwarded to every live WebSocket
    connection for that key. Registered keys are persisted as JSON at
    ``config['data']['keys_file']``.
    """

    class _AiohttpWebSocketAdapter:
        """Expose an aiohttp ``WebSocketResponse`` through ``send``/``recv``/``close``.

        ``recv`` returns text frames as ``str``, returns ``None`` for frames
        without text (e.g. binary), and raises
        ``websockets.exceptions.ConnectionClosed`` once the socket is closing,
        closed or errored.
        """

        def __init__(self, ws):
            self._ws = ws

        async def send(self, message):
            await self._ws.send_str(message)

        async def recv(self):
            msg = await self._ws.receive()
            if msg.type == aiohttp.WSMsgType.TEXT:
                return msg.data
            if msg.type in (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSING,
                aiohttp.WSMsgType.CLOSED,
                aiohttp.WSMsgType.ERROR,
            ):
                # Every terminal state must end the read loop: after a drop,
                # aiohttp keeps returning CLOSED/ERROR messages immediately
                # (and eventually raises RuntimeError) instead of blocking.
                raise websockets.exceptions.ConnectionClosed(None, None)
            return None

        async def close(self):
            await self._ws.close()

    @staticmethod
    def _default_config():
        """Return a fresh copy of the built-in default configuration."""
        return {
            'websocket': {
                'host': '0.0.0.0',
                'port': 8080,
            },
            'security': {
                'use_ssl': False,
                'ssl_cert': 'path/to/cert.pem',
                'ssl_key': 'path/to/key.pem',
            },
            'data': {
                'keys_file': '/app/data/registered_keys.json',
            },
        }

    @classmethod
    def _merge_config(cls, defaults, overrides):
        """Recursively overlay *overrides* onto *defaults* and return a new dict.

        Nested dicts are merged key by key; any other value in *overrides*
        replaces the default. Neither argument is mutated.
        """
        merged = dict(defaults)
        for name, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(name), dict):
                merged[name] = cls._merge_config(merged[name], value)
            else:
                merged[name] = value
        return merged

    def __init__(self, config_path='config.yml'):
        """Load configuration and previously registered keys.

        Args:
            config_path: YAML config file. If it does not exist, the built-in
                defaults are used and written to this path. Sections or keys
                missing from an existing file fall back to the defaults.

        Raises:
            ValueError: if the config file's top level is not a mapping.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            # No config yet: use the defaults and save them as an editable template.
            self.config = self._default_config()
            try:
                with open(config_path, 'w', encoding='utf-8') as f:
                    yaml.safe_dump(self.config, f, allow_unicode=True)
            except OSError as e:
                # The template is a convenience; the server can run without it.
                logger.warning(f"写入默认配置失败: {e}")
        else:
            if loaded is None:
                loaded = {}  # empty config file
            if not isinstance(loaded, dict):
                raise ValueError(f"Config file {config_path!r} must contain a mapping")
            self.config = self._merge_config(self._default_config(), loaded)

        # key -> set of live WebSocket connections registered under that key
        self.client_connections = defaultdict(set)

        # Path of the JSON file holding registered keys (from the config)
        self.keys_file = self.config['data']['keys_file']
        # key -> {'device_name', 'created_at', 'last_seen'}
        self.registered_keys = {}
        self.load_registered_keys()

    def load_registered_keys(self):
        """Load registered keys from ``self.keys_file`` into ``self.registered_keys``.

        A missing file is not an error (nothing is loaded). Read/parse
        failures, or a file whose top level is not a JSON object, are logged
        and leave ``self.registered_keys`` unchanged.
        """
        if not os.path.exists(self.keys_file):
            return
        try:
            with open(self.keys_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            logger.error(f"加载密钥文件失败: {e}")
            return
        if not isinstance(data, dict):
            logger.error("加载密钥文件失败: 文件内容不是 JSON 对象")
            return
        self.registered_keys = data
        logger.info(f"已加载 {len(self.registered_keys)} 个注册密钥")

    def save_registered_keys(self):
        """Persist ``self.registered_keys`` to ``self.keys_file`` as JSON.

        Creates the parent directory if needed and writes to a temporary file
        that is then moved into place with ``os.replace``, so an interrupted
        write does not truncate the existing keys file. Failures are logged,
        not raised.
        """
        tmp_path = f"{self.keys_file}.tmp"
        try:
            directory = os.path.dirname(self.keys_file)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.registered_keys, f, indent=2)
            os.replace(tmp_path, self.keys_file)
        except Exception as e:
            logger.error(f"保存密钥文件失败: {e}")

    def generate_key(self):
        """Return a new random, URL-safe client key (32 bytes of entropy)."""
        return secrets.token_urlsafe(32)

    async def register_client(self, websocket):
        """Perform the handshake on a new connection and return its key.

        Reads one JSON message from *websocket*:

        * ``{"action": "register", "device_name": ...}`` creates a new key,
          persists it, replies with it, and returns it.
        * ``{"action": "connect", "key": ...}`` re-attaches a known key and
          returns it, or replies with an error and returns ``None``.
        * Anything else gets an error reply and ``None`` is returned.

        On an unexpected failure the connection is closed and ``None`` is
        returned.
        """
        try:
            msg = await websocket.recv()
            data = json.loads(msg)
            action = data.get('action') if isinstance(data, dict) else None

            if action == 'register':
                device_name = data.get('device_name', 'unknown')
                key = self.generate_key()
                now = datetime.now().isoformat()
                self.registered_keys[key] = {
                    'device_name': device_name,
                    'created_at': now,
                    'last_seen': now,
                }
                self.client_connections[key].add(websocket)

                # Send the new key back to the client
                await websocket.send(json.dumps({
                    'status': 'success',
                    'key': key,
                    'message': '注册成功'
                }))
                logger.info(f"新客户端注册: {device_name}, key: {key[:8]}...")
                self.save_registered_keys()  # persist the newly registered key
                return key

            if action == 'connect':
                key = data.get('key')
                # Only string keys can be valid; this also avoids TypeError on
                # unhashable values during the dict lookup.
                if isinstance(key, str) and key in self.registered_keys:
                    self.client_connections[key].add(websocket)
                    self.registered_keys[key]['last_seen'] = datetime.now().isoformat()
                    await websocket.send(json.dumps({
                        'status': 'success',
                        'message': '连接成功'
                    }))
                    logger.info(f"客户端重连成功: {key[:8]}...")
                    return key
                await websocket.send(json.dumps({
                    'status': 'error',
                    'message': '无效的密钥'
                }))
                return None

            # Unknown action or non-object message: tell the client instead of
            # silently dropping the handshake.
            await websocket.send(json.dumps({
                'status': 'error',
                'message': '未知的操作'
            }))
            return None

        except Exception as e:
            logger.error(f"注册失败: {e}")
            await websocket.close()
            return None

    async def handle_websocket(self, websocket):
        """Serve one WebSocket connection until it closes or stops responding.

        After the handshake, a text ``"ping"`` is answered with ``"pong"``. If
        nothing arrives for 30 seconds the server sends ``"ping"`` and drops
        the connection unless some message arrives within 10 more seconds.
        The connection is always removed from ``client_connections`` on exit.
        """
        key = await self.register_client(websocket)
        if not key:
            return

        try:
            while True:
                try:
                    message = await asyncio.wait_for(websocket.recv(), timeout=30)
                except asyncio.TimeoutError:
                    # Idle for 30s: probe the client; any reply within 10s keeps it alive.
                    try:
                        await websocket.send("ping")
                        message = await asyncio.wait_for(websocket.recv(), timeout=10)
                    except Exception:
                        # Probe failed (timeout or closed socket): drop the connection.
                        break
                if message == "ping":
                    await websocket.send("pong")
        except (websockets.exceptions.ConnectionClosed, ConnectionError):
            pass
        finally:
            connections = self.client_connections.get(key, set())
            connections.discard(websocket)
            if not connections:
                # Remove the empty entry so disconnected keys do not accumulate.
                self.client_connections.pop(key, None)
                logger.info(f"客户端完全断开: {key[:8]}...")

    async def send_notification(self, key, notification):
        """Send *notification* (JSON-encoded) to every live connection for *key*.

        Connections whose send fails are removed from ``client_connections``.

        Returns:
            ``True`` if at least one connection received the notification,
            ``False`` otherwise (including when *key* has no live connections).
        """
        connections = self.client_connections.get(key)
        if not connections:
            return False

        payload = json.dumps(notification)  # encode once for all connections
        delivered = 0
        # Iterate over a snapshot: the set can change while we await sends.
        for websocket in list(connections):
            try:
                await websocket.send(payload)
            except Exception as e:
                logger.error(f"发送通知失败: {e}")
                connections.discard(websocket)
            else:
                delivered += 1
        return delivered > 0

    async def handle_http_push(self, request):
        """Handle ``POST /push/{key}`` by forwarding a notification to *key*'s clients.

        The body must be a JSON object with non-empty ``title`` and
        ``message``; ``type`` defaults to ``'info'``. Responds 404 for an
        unknown key or when no connection received the notification, 400 for
        a malformed body, and 500 on unexpected errors.
        """
        try:
            # The key comes from the URL path
            key = request.match_info['key']
            if key not in self.registered_keys:
                return web.Response(status=404, text='Invalid key')

            try:
                data = await request.json()
            except ValueError:
                # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
                return web.Response(status=400, text='Invalid JSON body')
            if not isinstance(data, dict):
                return web.Response(status=400, text='Invalid JSON body')

            if not data.get('title') or not data.get('message'):
                return web.Response(status=400, text='Missing title or message')

            notification = {
                'title': data['title'],
                'message': data['message'],
                'type': data.get('type', 'info'),
                'timestamp': datetime.now().isoformat()
            }

            if await self.send_notification(key, notification):
                return web.Response(text='Notification sent')
            return web.Response(status=404, text='No active connections')

        except Exception as e:
            logger.error(f"处理HTTP请求失败: {e}")
            return web.Response(status=500, text='Internal server error')

    async def start(self):
        """Start the HTTP/WebSocket server and serve until cancelled.

        Routes: ``GET /ws`` (WebSocket) and ``POST /push/{key}`` (HTTP push).
        TLS is enabled when ``config['security']['use_ssl']`` is true. The
        aiohttp runner is cleaned up when this coroutine exits.
        """
        security = self.config['security']
        ssl_context = None
        if security['use_ssl']:
            ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            ssl_context.load_cert_chain(security['ssl_cert'], security['ssl_key'])

        async def websocket_handler(request):
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            await self.handle_websocket(self._AiohttpWebSocketAdapter(ws))
            return ws

        app = web.Application()
        app.router.add_get('/ws', websocket_handler)  # WebSocket endpoint
        app.router.add_post('/push/{key}', self.handle_http_push)  # HTTP push endpoint

        runner = web.AppRunner(app)
        await runner.setup()
        host = self.config['websocket']['host']
        port = self.config['websocket']['port']
        try:
            # aiohttp has no SSLSite: TLS is enabled by passing ssl_context to
            # TCPSite (None means plain TCP).
            site = web.TCPSite(runner, host, port, ssl_context=ssl_context)
            await site.start()

            logger.info(f"服务器运行在 {host}:{port}")
            logger.info("WebSocket 路径: /ws")
            logger.info("HTTP 推送路径: /push/{key}")

            await asyncio.Future()  # run until cancelled
        finally:
            await runner.cleanup()
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Script entry point: build the server from ./config.yml and serve until stopped.
    asyncio.run(NotificationServer().start())
|