package event import ( "bufio" "context" "encoding/json" "fmt" "io" "net/http" "strings" "sync" "time" "git.x2erp.com/qdy/go-base/logger" ) // EventDispatcher 事件分发器 - 单例模式 type EventDispatcher struct { mu sync.RWMutex baseURL string port int subscriptions map[string]map[chan string]struct{} // sessionID -> set of channels client *http.Client cancelFunc context.CancelFunc running bool } // NewEventDispatcher 创建新的事件分发器 func NewEventDispatcher(baseURL string, port int) *EventDispatcher { return &EventDispatcher{ baseURL: baseURL, port: port, subscriptions: make(map[string]map[chan string]struct{}), client: &http.Client{ Timeout: 0, // 无超时限制,用于长连接 }, running: false, } } // Start 启动事件分发器,连接到opencode全局事件流 func (ed *EventDispatcher) Start(ctx context.Context) error { ed.mu.Lock() if ed.running { ed.mu.Unlock() return fmt.Errorf("event dispatcher already running") } // 创建子上下文用于控制SSE连接 sseCtx, cancel := context.WithCancel(ctx) ed.cancelFunc = cancel ed.running = true ed.mu.Unlock() // 启动SSE连接协程 go ed.runSSEConnection(sseCtx) logger.Info(fmt.Sprintf("事件分发器已启动 baseURL=%s port=%d", ed.baseURL, ed.port)) return nil } // Stop 停止事件分发器 func (ed *EventDispatcher) Stop() { ed.mu.Lock() if !ed.running { ed.mu.Unlock() return } if ed.cancelFunc != nil { ed.cancelFunc() } // 清理所有订阅通道 for sessionID, channels := range ed.subscriptions { for ch := range channels { close(ch) } delete(ed.subscriptions, sessionID) } ed.running = false ed.mu.Unlock() logger.Info("事件分发器已停止") } // Subscribe 订阅指定会话的事件 func (ed *EventDispatcher) Subscribe(sessionID, userID string) (<-chan string, error) { ed.mu.Lock() defer ed.mu.Unlock() // 创建缓冲通道 ch := make(chan string, 100) // 添加到订阅列表 if _, exists := ed.subscriptions[sessionID]; !exists { ed.subscriptions[sessionID] = make(map[chan string]struct{}) } ed.subscriptions[sessionID][ch] = struct{}{} logger.Debug(fmt.Sprintf("新订阅添加 sessionID=%s userID=%s totalSubscriptions=%d", sessionID, userID, len(ed.subscriptions[sessionID]))) return ch, nil } // Unsubscribe 取消订阅指定会话的事件 func (ed *EventDispatcher) Unsubscribe(sessionID string, ch <-chan string) { ed.mu.Lock() defer ed.mu.Unlock() if channels, exists := ed.subscriptions[sessionID]; exists { // 遍历查找对应的通道(因为ch是只读通道,无法直接作为key) var foundChan chan string for candidate := range channels { // 比较通道值 if candidate == ch { foundChan = candidate break } } if foundChan != nil { close(foundChan) delete(channels, foundChan) logger.Debug(fmt.Sprintf("订阅已移除 sessionID=%s remainingSubscriptions=%d", sessionID, len(channels))) } // 如果没有订阅者了,清理该会话的映射 if len(channels) == 0 { delete(ed.subscriptions, sessionID) } } } // buildSSEURL 构建SSE URL,避免端口重复 func (ed *EventDispatcher) buildSSEURL() string { // 检查baseURL是否已包含端口 base := ed.baseURL // 简单检查:如果baseURL已经包含端口号模式(冒号后跟数字),就不再加端口 // 查找最后一个冒号的位置 lastColon := -1 for i := len(base) - 1; i >= 0; i-- { if base[i] == ':' { lastColon = i break } } if lastColon != -1 { // 检查冒号后是否都是数字(端口号) hasPort := true for i := lastColon + 1; i < len(base); i++ { if base[i] < '0' || base[i] > '9' { hasPort = false break } } if hasPort { // baseURL已有端口,直接拼接路径 if strings.HasSuffix(base, "/") { return base + "global/event" } return base + "/global/event" } } // baseURL没有端口或端口格式不正确,添加端口 if strings.HasSuffix(base, "/") { return fmt.Sprintf("%s:%d/global/event", strings.TrimSuffix(base, "/"), ed.port) } return fmt.Sprintf("%s:%d/global/event", base, ed.port) } // runSSEConnection 运行SSE连接,读取全局事件并分发 func (ed *EventDispatcher) runSSEConnection(ctx context.Context) { // 构建SSE URL,避免重复端口 url := ed.buildSSEURL() for { select { case <-ctx.Done(): logger.Info("SSE连接停止(上下文取消)") return default: // 建立SSE连接 logger.Info(fmt.Sprintf("正在连接SSE流 url=%s", url)) if err := ed.connectAndProcessSSE(ctx, url); err != nil { logger.Error(fmt.Sprintf("SSE连接失败,5秒后重试 error=%s url=%s", err.Error(), url)) select { case <-ctx.Done(): return case <-time.After(5 * time.Second): continue } } } } } // connectAndProcessSSE 连接并处理SSE流 func (ed *EventDispatcher) connectAndProcessSSE(ctx context.Context, url string) error { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return fmt.Errorf("创建请求失败: %w", err) } req.Header.Set("Accept", "text/event-stream") resp, err := ed.client.Do(req) if err != nil { return fmt.Errorf("发送请求失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return fmt.Errorf("SSE请求失败,状态码: %d, 响应: %s", resp.StatusCode, string(body)) } logger.Info(fmt.Sprintf("SSE连接已建立 url=%s", url)) reader := bufio.NewReader(resp.Body) eventCount := 0 for { select { case <-ctx.Done(): return nil default: line, err := reader.ReadString('\n') if err != nil { if err == io.EOF { logger.Info(fmt.Sprintf("SSE流正常结束 totalEvents=%d", eventCount)) } else if ctx.Err() != nil { logger.Info("SSE流上下文取消") } else { logger.Error(fmt.Sprintf("读取SSE流错误 error=%s", err.Error())) } return err } line = strings.TrimSpace(line) if line == "" { continue } if strings.HasPrefix(line, "data: ") { data := strings.TrimPrefix(line, "data: ") eventCount++ // 分发事件 ed.dispatchEvent(data) } } } } // dispatchEvent 分发事件到相关订阅者 func (ed *EventDispatcher) dispatchEvent(data string) { // 解析事件数据获取sessionID sessionID := extractSessionIDFromEvent(data) if sessionID == "" { // 没有sessionID的事件(如全局心跳)丢弃,不广播给所有订阅者 // 确保按sessionID严格隔离,避免多用户消息交叉 return } // 只记录关键事件的路由日志,减少日志输出 var eventMap map[string]interface{} if err := json.Unmarshal([]byte(data), &eventMap); err == nil { // 提取事件类型 var eventType string if payload, ok := eventMap["payload"].(map[string]interface{}); ok { if t, ok := payload["type"].(string); ok { eventType = t } } // 只记录关键事件类型的路由信息 switch eventType { case "session.status", "message.updated", "session.diff", "session.idle": logger.Debug(fmt.Sprintf("路由事件到会话 sessionID=%s type=%s", sessionID, eventType)) } } // 只分发给订阅该会话的通道 ed.mu.RLock() channels, exists := ed.subscriptions[sessionID] ed.mu.RUnlock() if !exists { // 没有该会话的订阅者,忽略事件 return } // 发送事件到所有订阅该会话的通道 ed.mu.RLock() for ch := range channels { select { case ch <- data: // 成功发送 default: // 通道已满,丢弃事件并记录警告 logger.Warn(fmt.Sprintf("事件通道已满,丢弃事件 sessionID=%s", sessionID)) } } ed.mu.RUnlock() } // extractSessionIDFromEvent 从事件数据中提取sessionID func extractSessionIDFromEvent(data string) string { // 尝试解析为JSON var eventMap map[string]interface{} if err := json.Unmarshal([]byte(data), &eventMap); err != nil { logger.Error("无法解析事件JSON", "error", err.Error(), "dataPreview", safeSubstring(data, 0, 200)) return "" } // 递归查找sessionID字段 sessionID := findSessionIDRecursive(eventMap) return sessionID } // findSessionIDRecursive 递归查找sessionID字段 func findSessionIDRecursive(data interface{}) string { switch v := data.(type) { case map[string]interface{}: // 检查当前层级的sessionID字段(支持多种命名变体) for _, key := range []string{"sessionID", "session_id", "sessionId"} { if val, ok := v[key]; ok { if str, ok := val.(string); ok && str != "" { return str } } } // 检查常见嵌套路径 // 1. payload.properties.sessionID (session.status事件) if payload, ok := v["payload"].(map[string]interface{}); ok { if props, ok := payload["properties"].(map[string]interface{}); ok { if sessionID, ok := props["sessionID"].(string); ok && sessionID != "" { return sessionID } } } // 2. payload.properties.part.sessionID (message.part.updated事件) if payload, ok := v["payload"].(map[string]interface{}); ok { if props, ok := payload["properties"].(map[string]interface{}); ok { if part, ok := props["part"].(map[string]interface{}); ok { if sessionID, ok := part["sessionID"].(string); ok && sessionID != "" { return sessionID } } } } // 3. payload.properties.info.sessionID (message.updated事件) if payload, ok := v["payload"].(map[string]interface{}); ok { if props, ok := payload["properties"].(map[string]interface{}); ok { if info, ok := props["info"].(map[string]interface{}); ok { if sessionID, ok := info["sessionID"].(string); ok && sessionID != "" { return sessionID } } } } // 递归遍历所有值 for _, value := range v { if result := findSessionIDRecursive(value); result != "" { return result } } case []interface{}: // 遍历数组 for _, item := range v { if result := findSessionIDRecursive(item); result != "" { return result } } } return "" } // safeSubstring 安全的子字符串函数 func safeSubstring(s string, start, length int) string { if start < 0 { start = 0 } if start >= len(s) { return "" } end := start + length if end > len(s) { end = len(s) } return s[start:end] } // GetInstance 获取单例实例(线程安全) var ( instance *EventDispatcher instanceOnce sync.Once ) // GetEventDispatcher 获取事件分发器单例 func GetEventDispatcher(baseURL string, port int) *EventDispatcher { instanceOnce.Do(func() { instance = NewEventDispatcher(baseURL, port) }) return instance }