| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428 |
- package event
-
- import (
- "bufio"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "sync"
- "time"
-
- "git.x2erp.com/qdy/go-base/logger"
- )
-
- // EventDispatcher 事件分发器 - 单例模式
- type EventDispatcher struct {
- mu sync.RWMutex
- baseURL string
- port int
- subscriptions map[string]map[chan string]struct{} // sessionID -> set of channels
- client *http.Client
- cancelFunc context.CancelFunc
- running bool
- }
-
- // NewEventDispatcher 创建新的事件分发器
- func NewEventDispatcher(baseURL string, port int) *EventDispatcher {
- return &EventDispatcher{
- baseURL: baseURL,
- port: port,
- subscriptions: make(map[string]map[chan string]struct{}),
- client: &http.Client{
- Timeout: 0, // 无超时限制,用于长连接
- },
- running: false,
- }
- }
-
- // Start 启动事件分发器,连接到opencode全局事件流
- func (ed *EventDispatcher) Start(ctx context.Context) error {
- ed.mu.Lock()
- if ed.running {
- ed.mu.Unlock()
- return fmt.Errorf("event dispatcher already running")
- }
-
- // 创建子上下文用于控制SSE连接
- sseCtx, cancel := context.WithCancel(ctx)
- ed.cancelFunc = cancel
- ed.running = true
- ed.mu.Unlock()
-
- // 启动SSE连接协程
- go ed.runSSEConnection(sseCtx)
-
- logger.Info(fmt.Sprintf("事件分发器已启动 baseURL=%s port=%d",
- ed.baseURL, ed.port))
- return nil
- }
-
- // Stop 停止事件分发器
- func (ed *EventDispatcher) Stop() {
- ed.mu.Lock()
- if !ed.running {
- ed.mu.Unlock()
- return
- }
-
- if ed.cancelFunc != nil {
- ed.cancelFunc()
- }
-
- // 清理所有订阅通道
- for sessionID, channels := range ed.subscriptions {
- for ch := range channels {
- close(ch)
- }
- delete(ed.subscriptions, sessionID)
- }
-
- ed.running = false
- ed.mu.Unlock()
-
- logger.Info("事件分发器已停止")
- }
-
- // Subscribe 订阅指定会话的事件
- func (ed *EventDispatcher) Subscribe(sessionID, userID string) (<-chan string, error) {
- ed.mu.Lock()
- defer ed.mu.Unlock()
-
- // 创建缓冲通道
- ch := make(chan string, 100)
-
- // 添加到订阅列表
- if _, exists := ed.subscriptions[sessionID]; !exists {
- ed.subscriptions[sessionID] = make(map[chan string]struct{})
- }
- ed.subscriptions[sessionID][ch] = struct{}{}
-
- logger.Debug(fmt.Sprintf("新订阅添加 sessionID=%s userID=%s totalSubscriptions=%d",
- sessionID, userID, len(ed.subscriptions[sessionID])))
-
- return ch, nil
- }
-
- // Unsubscribe 取消订阅指定会话的事件
- func (ed *EventDispatcher) Unsubscribe(sessionID string, ch <-chan string) {
- ed.mu.Lock()
- defer ed.mu.Unlock()
-
- if channels, exists := ed.subscriptions[sessionID]; exists {
- // 遍历查找对应的通道(因为ch是只读通道,无法直接作为key)
- var foundChan chan string
- for candidate := range channels {
- // 比较通道值
- if candidate == ch {
- foundChan = candidate
- break
- }
- }
-
- if foundChan != nil {
- close(foundChan)
- delete(channels, foundChan)
- logger.Debug(fmt.Sprintf("订阅已移除 sessionID=%s remainingSubscriptions=%d",
- sessionID, len(channels)))
- }
-
- // 如果没有订阅者了,清理该会话的映射
- if len(channels) == 0 {
- delete(ed.subscriptions, sessionID)
- }
- }
- }
-
- // buildSSEURL 构建SSE URL,避免端口重复
- func (ed *EventDispatcher) buildSSEURL() string {
- // 检查baseURL是否已包含端口
- base := ed.baseURL
- // 简单检查:如果baseURL已经包含端口号模式(冒号后跟数字),就不再加端口
- // 查找最后一个冒号的位置
- lastColon := -1
- for i := len(base) - 1; i >= 0; i-- {
- if base[i] == ':' {
- lastColon = i
- break
- }
- }
-
- if lastColon != -1 {
- // 检查冒号后是否都是数字(端口号)
- hasPort := true
- for i := lastColon + 1; i < len(base); i++ {
- if base[i] < '0' || base[i] > '9' {
- hasPort = false
- break
- }
- }
- if hasPort {
- // baseURL已有端口,直接拼接路径
- if strings.HasSuffix(base, "/") {
- return base + "global/event"
- }
- return base + "/global/event"
- }
- }
-
- // baseURL没有端口或端口格式不正确,添加端口
- if strings.HasSuffix(base, "/") {
- return fmt.Sprintf("%s:%d/global/event", strings.TrimSuffix(base, "/"), ed.port)
- }
- return fmt.Sprintf("%s:%d/global/event", base, ed.port)
- }
-
- // runSSEConnection 运行SSE连接,读取全局事件并分发
- func (ed *EventDispatcher) runSSEConnection(ctx context.Context) {
- // 构建SSE URL,避免重复端口
- url := ed.buildSSEURL()
-
- for {
- select {
- case <-ctx.Done():
- logger.Info("SSE连接停止(上下文取消)")
- return
- default:
- // 建立SSE连接
- logger.Info(fmt.Sprintf("正在连接SSE流 url=%s",
- url))
- if err := ed.connectAndProcessSSE(ctx, url); err != nil {
- logger.Error(fmt.Sprintf("SSE连接失败,5秒后重试 error=%s url=%s",
- err.Error(), url))
-
- select {
- case <-ctx.Done():
- return
- case <-time.After(5 * time.Second):
- continue
- }
- }
- }
- }
- }
-
- // connectAndProcessSSE 连接并处理SSE流
- func (ed *EventDispatcher) connectAndProcessSSE(ctx context.Context, url string) error {
- req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
- if err != nil {
- return fmt.Errorf("创建请求失败: %w", err)
- }
- req.Header.Set("Accept", "text/event-stream")
-
- resp, err := ed.client.Do(req)
- if err != nil {
- return fmt.Errorf("发送请求失败: %w", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("SSE请求失败,状态码: %d, 响应: %s", resp.StatusCode, string(body))
- }
-
- logger.Info(fmt.Sprintf("SSE连接已建立 url=%s",
- url))
-
- reader := bufio.NewReader(resp.Body)
- eventCount := 0
-
- for {
- select {
- case <-ctx.Done():
- return nil
- default:
- line, err := reader.ReadString('\n')
- if err != nil {
- if err == io.EOF {
- logger.Info(fmt.Sprintf("SSE流正常结束 totalEvents=%d",
- eventCount))
- } else if ctx.Err() != nil {
- logger.Info("SSE流上下文取消")
- } else {
- logger.Error(fmt.Sprintf("读取SSE流错误 error=%s",
- err.Error()))
- }
- return err
- }
-
- line = strings.TrimSpace(line)
- if line == "" {
- continue
- }
-
- if strings.HasPrefix(line, "data: ") {
- data := strings.TrimPrefix(line, "data: ")
- eventCount++
-
- // 分发事件
- ed.dispatchEvent(data)
-
- }
- }
- }
- }
-
- // dispatchEvent 分发事件到相关订阅者
- func (ed *EventDispatcher) dispatchEvent(data string) {
- // 解析事件数据获取sessionID
- sessionID := extractSessionIDFromEvent(data)
-
- if sessionID == "" {
- // 没有sessionID的事件(如全局心跳)丢弃,不广播给所有订阅者
- // 确保按sessionID严格隔离,避免多用户消息交叉
- return
- }
-
- // 只记录关键事件的路由日志,减少日志输出
- var eventMap map[string]interface{}
- if err := json.Unmarshal([]byte(data), &eventMap); err == nil {
- // 提取事件类型
- var eventType string
- if payload, ok := eventMap["payload"].(map[string]interface{}); ok {
- if t, ok := payload["type"].(string); ok {
- eventType = t
- }
- }
- // 只记录关键事件类型的路由信息
- switch eventType {
- case "session.status", "message.updated", "session.diff", "session.idle":
- logger.Debug(fmt.Sprintf("路由事件到会话 sessionID=%s type=%s",
- sessionID, eventType))
- }
- }
-
- // 只分发给订阅该会话的通道
- ed.mu.RLock()
- channels, exists := ed.subscriptions[sessionID]
- ed.mu.RUnlock()
-
- if !exists {
- // 没有该会话的订阅者,忽略事件
- return
- }
-
- // 发送事件到所有订阅该会话的通道
- ed.mu.RLock()
- for ch := range channels {
- select {
- case ch <- data:
- // 成功发送
- default:
- // 通道已满,丢弃事件并记录警告
- logger.Warn(fmt.Sprintf("事件通道已满,丢弃事件 sessionID=%s",
- sessionID))
- }
- }
- ed.mu.RUnlock()
- }
-
- // extractSessionIDFromEvent 从事件数据中提取sessionID
- func extractSessionIDFromEvent(data string) string {
- // 尝试解析为JSON
- var eventMap map[string]interface{}
- if err := json.Unmarshal([]byte(data), &eventMap); err != nil {
- logger.Error("无法解析事件JSON", "error", err.Error(), "dataPreview", safeSubstring(data, 0, 200))
- return ""
- }
-
- // 递归查找sessionID字段
- sessionID := findSessionIDRecursive(eventMap)
-
- return sessionID
- }
-
- // findSessionIDRecursive 递归查找sessionID字段
- func findSessionIDRecursive(data interface{}) string {
- switch v := data.(type) {
- case map[string]interface{}:
- // 检查当前层级的sessionID字段(支持多种命名变体)
- for _, key := range []string{"sessionID", "session_id", "sessionId"} {
- if val, ok := v[key]; ok {
- if str, ok := val.(string); ok && str != "" {
- return str
- }
- }
- }
-
- // 检查常见嵌套路径
- // 1. payload.properties.sessionID (session.status事件)
- if payload, ok := v["payload"].(map[string]interface{}); ok {
- if props, ok := payload["properties"].(map[string]interface{}); ok {
- if sessionID, ok := props["sessionID"].(string); ok && sessionID != "" {
- return sessionID
- }
- }
- }
-
- // 2. payload.properties.part.sessionID (message.part.updated事件)
- if payload, ok := v["payload"].(map[string]interface{}); ok {
- if props, ok := payload["properties"].(map[string]interface{}); ok {
- if part, ok := props["part"].(map[string]interface{}); ok {
- if sessionID, ok := part["sessionID"].(string); ok && sessionID != "" {
- return sessionID
- }
- }
- }
- }
-
- // 3. payload.properties.info.sessionID (message.updated事件)
- if payload, ok := v["payload"].(map[string]interface{}); ok {
- if props, ok := payload["properties"].(map[string]interface{}); ok {
- if info, ok := props["info"].(map[string]interface{}); ok {
- if sessionID, ok := info["sessionID"].(string); ok && sessionID != "" {
- return sessionID
- }
- }
- }
- }
-
- // 递归遍历所有值
- for _, value := range v {
- if result := findSessionIDRecursive(value); result != "" {
- return result
- }
- }
-
- case []interface{}:
- // 遍历数组
- for _, item := range v {
- if result := findSessionIDRecursive(item); result != "" {
- return result
- }
- }
- }
-
- return ""
- }
-
- // safeSubstring 安全的子字符串函数
- func safeSubstring(s string, start, length int) string {
- if start < 0 {
- start = 0
- }
- if start >= len(s) {
- return ""
- }
- end := start + length
- if end > len(s) {
- end = len(s)
- }
- return s[start:end]
- }
-
- // GetInstance 获取单例实例(线程安全)
- var (
- instance *EventDispatcher
- instanceOnce sync.Once
- )
-
- // GetEventDispatcher 获取事件分发器单例
- func GetEventDispatcher(baseURL string, port int) *EventDispatcher {
- instanceOnce.Do(func() {
- instance = NewEventDispatcher(baseURL, port)
- })
- return instance
- }
|