No Description
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

dispatcher.go 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. package event
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "time"
  12. "git.x2erp.com/qdy/go-base/logger"
  13. )
  14. // EventDispatcher 事件分发器 - 单例模式
  15. type EventDispatcher struct {
  16. mu sync.RWMutex
  17. baseURL string
  18. port int
  19. subscriptions map[string]map[chan string]struct{} // sessionID -> set of channels
  20. client *http.Client
  21. cancelFunc context.CancelFunc
  22. running bool
  23. }
  24. // NewEventDispatcher 创建新的事件分发器
  25. func NewEventDispatcher(baseURL string, port int) *EventDispatcher {
  26. return &EventDispatcher{
  27. baseURL: baseURL,
  28. port: port,
  29. subscriptions: make(map[string]map[chan string]struct{}),
  30. client: &http.Client{
  31. Timeout: 0, // 无超时限制,用于长连接
  32. },
  33. running: false,
  34. }
  35. }
  36. // Start 启动事件分发器,连接到opencode全局事件流
  37. func (ed *EventDispatcher) Start(ctx context.Context) error {
  38. ed.mu.Lock()
  39. if ed.running {
  40. ed.mu.Unlock()
  41. return fmt.Errorf("event dispatcher already running")
  42. }
  43. // 创建子上下文用于控制SSE连接
  44. sseCtx, cancel := context.WithCancel(ctx)
  45. ed.cancelFunc = cancel
  46. ed.running = true
  47. ed.mu.Unlock()
  48. // 启动SSE连接协程
  49. go ed.runSSEConnection(sseCtx)
  50. logger.Info(fmt.Sprintf("事件分发器已启动 baseURL=%s port=%d",
  51. ed.baseURL, ed.port))
  52. return nil
  53. }
  54. // Stop 停止事件分发器
  55. func (ed *EventDispatcher) Stop() {
  56. ed.mu.Lock()
  57. if !ed.running {
  58. ed.mu.Unlock()
  59. return
  60. }
  61. if ed.cancelFunc != nil {
  62. ed.cancelFunc()
  63. }
  64. // 清理所有订阅通道
  65. for sessionID, channels := range ed.subscriptions {
  66. for ch := range channels {
  67. close(ch)
  68. }
  69. delete(ed.subscriptions, sessionID)
  70. }
  71. ed.running = false
  72. ed.mu.Unlock()
  73. logger.Info("事件分发器已停止")
  74. }
  75. // Subscribe 订阅指定会话的事件
  76. func (ed *EventDispatcher) Subscribe(sessionID, userID string) (<-chan string, error) {
  77. ed.mu.Lock()
  78. defer ed.mu.Unlock()
  79. // 创建缓冲通道
  80. ch := make(chan string, 100)
  81. // 添加到订阅列表
  82. if _, exists := ed.subscriptions[sessionID]; !exists {
  83. ed.subscriptions[sessionID] = make(map[chan string]struct{})
  84. }
  85. ed.subscriptions[sessionID][ch] = struct{}{}
  86. logger.Debug(fmt.Sprintf("新订阅添加 sessionID=%s userID=%s totalSubscriptions=%d",
  87. sessionID, userID, len(ed.subscriptions[sessionID])))
  88. return ch, nil
  89. }
  90. // Unsubscribe 取消订阅指定会话的事件
  91. func (ed *EventDispatcher) Unsubscribe(sessionID string, ch <-chan string) {
  92. ed.mu.Lock()
  93. defer ed.mu.Unlock()
  94. if channels, exists := ed.subscriptions[sessionID]; exists {
  95. // 遍历查找对应的通道(因为ch是只读通道,无法直接作为key)
  96. var foundChan chan string
  97. for candidate := range channels {
  98. // 比较通道值
  99. if candidate == ch {
  100. foundChan = candidate
  101. break
  102. }
  103. }
  104. if foundChan != nil {
  105. close(foundChan)
  106. delete(channels, foundChan)
  107. logger.Debug(fmt.Sprintf("订阅已移除 sessionID=%s remainingSubscriptions=%d",
  108. sessionID, len(channels)))
  109. }
  110. // 如果没有订阅者了,清理该会话的映射
  111. if len(channels) == 0 {
  112. delete(ed.subscriptions, sessionID)
  113. }
  114. }
  115. }
  116. // buildSSEURL 构建SSE URL,避免端口重复
  117. func (ed *EventDispatcher) buildSSEURL() string {
  118. // 检查baseURL是否已包含端口
  119. base := ed.baseURL
  120. // 简单检查:如果baseURL已经包含端口号模式(冒号后跟数字),就不再加端口
  121. // 查找最后一个冒号的位置
  122. lastColon := -1
  123. for i := len(base) - 1; i >= 0; i-- {
  124. if base[i] == ':' {
  125. lastColon = i
  126. break
  127. }
  128. }
  129. if lastColon != -1 {
  130. // 检查冒号后是否都是数字(端口号)
  131. hasPort := true
  132. for i := lastColon + 1; i < len(base); i++ {
  133. if base[i] < '0' || base[i] > '9' {
  134. hasPort = false
  135. break
  136. }
  137. }
  138. if hasPort {
  139. // baseURL已有端口,直接拼接路径
  140. if strings.HasSuffix(base, "/") {
  141. return base + "global/event"
  142. }
  143. return base + "/global/event"
  144. }
  145. }
  146. // baseURL没有端口或端口格式不正确,添加端口
  147. if strings.HasSuffix(base, "/") {
  148. return fmt.Sprintf("%s:%d/global/event", strings.TrimSuffix(base, "/"), ed.port)
  149. }
  150. return fmt.Sprintf("%s:%d/global/event", base, ed.port)
  151. }
  152. // runSSEConnection 运行SSE连接,读取全局事件并分发
  153. func (ed *EventDispatcher) runSSEConnection(ctx context.Context) {
  154. // 构建SSE URL,避免重复端口
  155. url := ed.buildSSEURL()
  156. for {
  157. select {
  158. case <-ctx.Done():
  159. logger.Info("SSE连接停止(上下文取消)")
  160. return
  161. default:
  162. // 建立SSE连接
  163. logger.Info(fmt.Sprintf("正在连接SSE流 url=%s",
  164. url))
  165. if err := ed.connectAndProcessSSE(ctx, url); err != nil {
  166. logger.Error(fmt.Sprintf("SSE连接失败,5秒后重试 error=%s url=%s",
  167. err.Error(), url))
  168. select {
  169. case <-ctx.Done():
  170. return
  171. case <-time.After(5 * time.Second):
  172. continue
  173. }
  174. }
  175. }
  176. }
  177. }
  178. // connectAndProcessSSE 连接并处理SSE流
  179. func (ed *EventDispatcher) connectAndProcessSSE(ctx context.Context, url string) error {
  180. req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  181. if err != nil {
  182. return fmt.Errorf("创建请求失败: %w", err)
  183. }
  184. req.Header.Set("Accept", "text/event-stream")
  185. resp, err := ed.client.Do(req)
  186. if err != nil {
  187. return fmt.Errorf("发送请求失败: %w", err)
  188. }
  189. defer resp.Body.Close()
  190. if resp.StatusCode != http.StatusOK {
  191. body, _ := io.ReadAll(resp.Body)
  192. return fmt.Errorf("SSE请求失败,状态码: %d, 响应: %s", resp.StatusCode, string(body))
  193. }
  194. logger.Info(fmt.Sprintf("SSE连接已建立 url=%s",
  195. url))
  196. reader := bufio.NewReader(resp.Body)
  197. eventCount := 0
  198. for {
  199. select {
  200. case <-ctx.Done():
  201. return nil
  202. default:
  203. line, err := reader.ReadString('\n')
  204. if err != nil {
  205. if err == io.EOF {
  206. logger.Info(fmt.Sprintf("SSE流正常结束 totalEvents=%d",
  207. eventCount))
  208. } else if ctx.Err() != nil {
  209. logger.Info("SSE流上下文取消")
  210. } else {
  211. logger.Error(fmt.Sprintf("读取SSE流错误 error=%s",
  212. err.Error()))
  213. }
  214. return err
  215. }
  216. line = strings.TrimSpace(line)
  217. if line == "" {
  218. continue
  219. }
  220. if strings.HasPrefix(line, "data: ") {
  221. data := strings.TrimPrefix(line, "data: ")
  222. eventCount++
  223. // 分发事件
  224. ed.dispatchEvent(data)
  225. }
  226. }
  227. }
  228. }
  229. // dispatchEvent 分发事件到相关订阅者
  230. func (ed *EventDispatcher) dispatchEvent(data string) {
  231. // 解析事件数据获取sessionID
  232. sessionID := extractSessionIDFromEvent(data)
  233. if sessionID == "" {
  234. // 没有sessionID的事件(如全局心跳)丢弃,不广播给所有订阅者
  235. // 确保按sessionID严格隔离,避免多用户消息交叉
  236. return
  237. }
  238. // 只记录关键事件的路由日志,减少日志输出
  239. var eventMap map[string]interface{}
  240. if err := json.Unmarshal([]byte(data), &eventMap); err == nil {
  241. // 提取事件类型
  242. var eventType string
  243. if payload, ok := eventMap["payload"].(map[string]interface{}); ok {
  244. if t, ok := payload["type"].(string); ok {
  245. eventType = t
  246. }
  247. }
  248. // 只记录关键事件类型的路由信息
  249. switch eventType {
  250. case "session.status", "message.updated", "session.diff", "session.idle":
  251. logger.Debug(fmt.Sprintf("路由事件到会话 sessionID=%s type=%s",
  252. sessionID, eventType))
  253. }
  254. }
  255. // 只分发给订阅该会话的通道
  256. ed.mu.RLock()
  257. channels, exists := ed.subscriptions[sessionID]
  258. ed.mu.RUnlock()
  259. if !exists {
  260. // 没有该会话的订阅者,忽略事件
  261. return
  262. }
  263. // 发送事件到所有订阅该会话的通道
  264. ed.mu.RLock()
  265. for ch := range channels {
  266. select {
  267. case ch <- data:
  268. // 成功发送
  269. default:
  270. // 通道已满,丢弃事件并记录警告
  271. logger.Warn(fmt.Sprintf("事件通道已满,丢弃事件 sessionID=%s",
  272. sessionID))
  273. }
  274. }
  275. ed.mu.RUnlock()
  276. }
  277. // extractSessionIDFromEvent 从事件数据中提取sessionID
  278. func extractSessionIDFromEvent(data string) string {
  279. // 尝试解析为JSON
  280. var eventMap map[string]interface{}
  281. if err := json.Unmarshal([]byte(data), &eventMap); err != nil {
  282. logger.Error("无法解析事件JSON", "error", err.Error(), "dataPreview", safeSubstring(data, 0, 200))
  283. return ""
  284. }
  285. // 递归查找sessionID字段
  286. sessionID := findSessionIDRecursive(eventMap)
  287. return sessionID
  288. }
  289. // findSessionIDRecursive 递归查找sessionID字段
  290. func findSessionIDRecursive(data interface{}) string {
  291. switch v := data.(type) {
  292. case map[string]interface{}:
  293. // 检查当前层级的sessionID字段(支持多种命名变体)
  294. for _, key := range []string{"sessionID", "session_id", "sessionId"} {
  295. if val, ok := v[key]; ok {
  296. if str, ok := val.(string); ok && str != "" {
  297. return str
  298. }
  299. }
  300. }
  301. // 检查常见嵌套路径
  302. // 1. payload.properties.sessionID (session.status事件)
  303. if payload, ok := v["payload"].(map[string]interface{}); ok {
  304. if props, ok := payload["properties"].(map[string]interface{}); ok {
  305. if sessionID, ok := props["sessionID"].(string); ok && sessionID != "" {
  306. return sessionID
  307. }
  308. }
  309. }
  310. // 2. payload.properties.part.sessionID (message.part.updated事件)
  311. if payload, ok := v["payload"].(map[string]interface{}); ok {
  312. if props, ok := payload["properties"].(map[string]interface{}); ok {
  313. if part, ok := props["part"].(map[string]interface{}); ok {
  314. if sessionID, ok := part["sessionID"].(string); ok && sessionID != "" {
  315. return sessionID
  316. }
  317. }
  318. }
  319. }
  320. // 3. payload.properties.info.sessionID (message.updated事件)
  321. if payload, ok := v["payload"].(map[string]interface{}); ok {
  322. if props, ok := payload["properties"].(map[string]interface{}); ok {
  323. if info, ok := props["info"].(map[string]interface{}); ok {
  324. if sessionID, ok := info["sessionID"].(string); ok && sessionID != "" {
  325. return sessionID
  326. }
  327. }
  328. }
  329. }
  330. // 递归遍历所有值
  331. for _, value := range v {
  332. if result := findSessionIDRecursive(value); result != "" {
  333. return result
  334. }
  335. }
  336. case []interface{}:
  337. // 遍历数组
  338. for _, item := range v {
  339. if result := findSessionIDRecursive(item); result != "" {
  340. return result
  341. }
  342. }
  343. }
  344. return ""
  345. }
  346. // safeSubstring 安全的子字符串函数
  347. func safeSubstring(s string, start, length int) string {
  348. if start < 0 {
  349. start = 0
  350. }
  351. if start >= len(s) {
  352. return ""
  353. }
  354. end := start + length
  355. if end > len(s) {
  356. end = len(s)
  357. }
  358. return s[start:end]
  359. }
  360. // GetInstance 获取单例实例(线程安全)
  361. var (
  362. instance *EventDispatcher
  363. instanceOnce sync.Once
  364. )
  365. // GetEventDispatcher 获取事件分发器单例
  366. func GetEventDispatcher(baseURL string, port int) *EventDispatcher {
  367. instanceOnce.Do(func() {
  368. instance = NewEventDispatcher(baseURL, port)
  369. })
  370. return instance
  371. }