Ei kuvausta
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

dispatcher.go 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. package event
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "time"
  12. "git.x2erp.com/qdy/go-base/logger"
  13. )
  14. // EventDispatcher 事件分发器 - 单例模式
  15. type EventDispatcher struct {
  16. mu sync.RWMutex
  17. baseURL string
  18. port int
  19. subscriptions map[string]map[chan string]struct{} // sessionID -> set of channels
  20. sessionUserCache *SessionUserCache // sessionID -> userID 映射缓存(用于用户验证)
  21. client *http.Client
  22. cancelFunc context.CancelFunc
  23. running bool
  24. }
  25. // EventData opencode事件数据结构 - 匹配实际事件格式
  26. type EventData struct {
  27. Directory string `json:"directory,omitempty"`
  28. Payload map[string]interface{} `json:"payload"`
  29. }
  30. // PayloadData payload内部结构(辅助类型)
  31. type PayloadData struct {
  32. Type string `json:"type"`
  33. Properties map[string]interface{} `json:"properties,omitempty"`
  34. }
  35. // NewEventDispatcher 创建新的事件分发器
  36. func NewEventDispatcher(baseURL string, port int) *EventDispatcher {
  37. return &EventDispatcher{
  38. baseURL: baseURL,
  39. port: port,
  40. subscriptions: make(map[string]map[chan string]struct{}),
  41. sessionUserCache: NewSessionUserCache(20 * time.Minute),
  42. client: &http.Client{
  43. Timeout: 0, // 无超时限制,用于长连接
  44. },
  45. running: false,
  46. }
  47. }
  48. // Start 启动事件分发器,连接到opencode全局事件流
  49. func (ed *EventDispatcher) Start(ctx context.Context) error {
  50. ed.mu.Lock()
  51. if ed.running {
  52. ed.mu.Unlock()
  53. return fmt.Errorf("event dispatcher already running")
  54. }
  55. // 创建子上下文用于控制SSE连接
  56. sseCtx, cancel := context.WithCancel(ctx)
  57. ed.cancelFunc = cancel
  58. ed.running = true
  59. ed.mu.Unlock()
  60. // 启动SSE连接协程
  61. go ed.runSSEConnection(sseCtx)
  62. logger.Info(fmt.Sprintf("事件分发器已启动 baseURL=%s port=%d",
  63. ed.baseURL, ed.port))
  64. return nil
  65. }
  66. // Stop 停止事件分发器
  67. func (ed *EventDispatcher) Stop() {
  68. ed.mu.Lock()
  69. if !ed.running {
  70. ed.mu.Unlock()
  71. return
  72. }
  73. if ed.cancelFunc != nil {
  74. ed.cancelFunc()
  75. }
  76. // 清理所有订阅通道
  77. for sessionID, channels := range ed.subscriptions {
  78. for ch := range channels {
  79. close(ch)
  80. }
  81. delete(ed.subscriptions, sessionID)
  82. }
  83. ed.running = false
  84. ed.mu.Unlock()
  85. logger.Info("事件分发器已停止")
  86. }
  87. // Subscribe 订阅指定会话的事件
  88. func (ed *EventDispatcher) Subscribe(sessionID, userID string) (<-chan string, error) {
  89. ed.mu.Lock()
  90. defer ed.mu.Unlock()
  91. // 缓存会话-用户映射(用于未来需要用户验证时)
  92. ed.sessionUserCache.Set(sessionID, userID)
  93. // 创建缓冲通道
  94. ch := make(chan string, 100)
  95. // 添加到订阅列表
  96. if _, exists := ed.subscriptions[sessionID]; !exists {
  97. ed.subscriptions[sessionID] = make(map[chan string]struct{})
  98. }
  99. ed.subscriptions[sessionID][ch] = struct{}{}
  100. logger.Debug(fmt.Sprintf("新订阅添加 sessionID=%s userID=%s totalSubscriptions=%d",
  101. sessionID, userID, len(ed.subscriptions[sessionID])))
  102. return ch, nil
  103. }
  104. // Unsubscribe 取消订阅指定会话的事件
  105. func (ed *EventDispatcher) Unsubscribe(sessionID string, ch <-chan string) {
  106. ed.mu.Lock()
  107. defer ed.mu.Unlock()
  108. if channels, exists := ed.subscriptions[sessionID]; exists {
  109. // 遍历查找对应的通道(因为ch是只读通道,无法直接作为key)
  110. var foundChan chan string
  111. for candidate := range channels {
  112. // 比较通道值
  113. if candidate == ch {
  114. foundChan = candidate
  115. break
  116. }
  117. }
  118. if foundChan != nil {
  119. close(foundChan)
  120. delete(channels, foundChan)
  121. logger.Debug(fmt.Sprintf("订阅已移除 sessionID=%s remainingSubscriptions=%d",
  122. sessionID, len(channels)))
  123. }
  124. // 如果没有订阅者了,清理该会话的映射
  125. if len(channels) == 0 {
  126. delete(ed.subscriptions, sessionID)
  127. ed.sessionUserCache.Delete(sessionID)
  128. }
  129. }
  130. }
  131. // RegisterSession 注册会话(前端调用SendPromptStream时调用)
  132. func (ed *EventDispatcher) RegisterSession(sessionID, userID string) {
  133. ed.sessionUserCache.Set(sessionID, userID)
  134. logger.Debug(fmt.Sprintf("会话已注册 sessionID=%s userID=%s",
  135. sessionID, userID))
  136. }
  137. // buildSSEURL 构建SSE URL,避免端口重复
  138. func (ed *EventDispatcher) buildSSEURL() string {
  139. // 检查baseURL是否已包含端口
  140. base := ed.baseURL
  141. // 简单检查:如果baseURL已经包含端口号模式(冒号后跟数字),就不再加端口
  142. // 查找最后一个冒号的位置
  143. lastColon := -1
  144. for i := len(base) - 1; i >= 0; i-- {
  145. if base[i] == ':' {
  146. lastColon = i
  147. break
  148. }
  149. }
  150. if lastColon != -1 {
  151. // 检查冒号后是否都是数字(端口号)
  152. hasPort := true
  153. for i := lastColon + 1; i < len(base); i++ {
  154. if base[i] < '0' || base[i] > '9' {
  155. hasPort = false
  156. break
  157. }
  158. }
  159. if hasPort {
  160. // baseURL已有端口,直接拼接路径
  161. if strings.HasSuffix(base, "/") {
  162. return base + "global/event"
  163. }
  164. return base + "/global/event"
  165. }
  166. }
  167. // baseURL没有端口或端口格式不正确,添加端口
  168. if strings.HasSuffix(base, "/") {
  169. return fmt.Sprintf("%s:%d/global/event", strings.TrimSuffix(base, "/"), ed.port)
  170. }
  171. return fmt.Sprintf("%s:%d/global/event", base, ed.port)
  172. }
  173. // runSSEConnection 运行SSE连接,读取全局事件并分发
  174. func (ed *EventDispatcher) runSSEConnection(ctx context.Context) {
  175. // 构建SSE URL,避免重复端口
  176. url := ed.buildSSEURL()
  177. for {
  178. select {
  179. case <-ctx.Done():
  180. logger.Info("SSE连接停止(上下文取消)")
  181. return
  182. default:
  183. // 建立SSE连接
  184. logger.Info(fmt.Sprintf("正在连接SSE流 url=%s",
  185. url))
  186. if err := ed.connectAndProcessSSE(ctx, url); err != nil {
  187. logger.Error(fmt.Sprintf("SSE连接失败,5秒后重试 error=%s url=%s",
  188. err.Error(), url))
  189. select {
  190. case <-ctx.Done():
  191. return
  192. case <-time.After(5 * time.Second):
  193. continue
  194. }
  195. }
  196. }
  197. }
  198. }
  199. // connectAndProcessSSE 连接并处理SSE流
  200. func (ed *EventDispatcher) connectAndProcessSSE(ctx context.Context, url string) error {
  201. req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  202. if err != nil {
  203. return fmt.Errorf("创建请求失败: %w", err)
  204. }
  205. req.Header.Set("Accept", "text/event-stream")
  206. resp, err := ed.client.Do(req)
  207. if err != nil {
  208. return fmt.Errorf("发送请求失败: %w", err)
  209. }
  210. defer resp.Body.Close()
  211. if resp.StatusCode != http.StatusOK {
  212. body, _ := io.ReadAll(resp.Body)
  213. return fmt.Errorf("SSE请求失败,状态码: %d, 响应: %s", resp.StatusCode, string(body))
  214. }
  215. logger.Info(fmt.Sprintf("SSE连接已建立 url=%s",
  216. url))
  217. reader := bufio.NewReader(resp.Body)
  218. eventCount := 0
  219. for {
  220. select {
  221. case <-ctx.Done():
  222. return nil
  223. default:
  224. line, err := reader.ReadString('\n')
  225. if err != nil {
  226. if err == io.EOF {
  227. logger.Info(fmt.Sprintf("SSE流正常结束 totalEvents=%d",
  228. eventCount))
  229. } else if ctx.Err() != nil {
  230. logger.Info("SSE流上下文取消")
  231. } else {
  232. logger.Error(fmt.Sprintf("读取SSE流错误 error=%s",
  233. err.Error()))
  234. }
  235. return err
  236. }
  237. line = strings.TrimSpace(line)
  238. if line == "" {
  239. continue
  240. }
  241. if strings.HasPrefix(line, "data: ") {
  242. data := strings.TrimPrefix(line, "data: ")
  243. eventCount++
  244. // 分发事件
  245. ed.dispatchEvent(data)
  246. if eventCount%100 == 0 {
  247. logger.Debug(fmt.Sprintf("事件处理统计 totalEvents=%d activeSessions=%d",
  248. eventCount, len(ed.subscriptions)))
  249. }
  250. }
  251. }
  252. }
  253. }
  254. // dispatchEvent 分发事件到相关订阅者
  255. func (ed *EventDispatcher) dispatchEvent(data string) {
  256. // 解析事件数据获取sessionID
  257. sessionID := extractSessionIDFromEvent(data)
  258. if sessionID == "" {
  259. // 没有sessionID的事件(如全局心跳)分发给所有订阅者
  260. logger.Debug(fmt.Sprintf("广播全局事件 dataPreview=%s", safeSubstring(data, 0, 100)))
  261. ed.broadcastToAll(data)
  262. return
  263. }
  264. logger.Debug(fmt.Sprintf("路由事件到会话 sessionID=%s dataPreview=%s",
  265. sessionID, safeSubstring(data, 0, 100)))
  266. // 只分发给订阅该会话的通道
  267. ed.mu.RLock()
  268. channels, exists := ed.subscriptions[sessionID]
  269. ed.mu.RUnlock()
  270. if !exists {
  271. // 没有该会话的订阅者,忽略事件
  272. logger.Debug(fmt.Sprintf("忽略事件,无订阅者 sessionID=%s", sessionID))
  273. return
  274. }
  275. // 发送事件到所有订阅该会话的通道
  276. ed.mu.RLock()
  277. for ch := range channels {
  278. select {
  279. case ch <- data:
  280. // 成功发送
  281. default:
  282. // 通道已满,丢弃事件并记录警告
  283. logger.Warn(fmt.Sprintf("事件通道已满,丢弃事件 sessionID=%s",
  284. sessionID))
  285. }
  286. }
  287. ed.mu.RUnlock()
  288. }
  289. // broadcastToAll 广播事件给所有订阅者(用于全局事件如心跳)
  290. func (ed *EventDispatcher) broadcastToAll(data string) {
  291. logger.Debug(fmt.Sprintf("广播事件给所有订阅者 dataPreview=%s", safeSubstring(data, 0, 100)))
  292. ed.mu.RLock()
  293. defer ed.mu.RUnlock()
  294. for sessionID, channels := range ed.subscriptions {
  295. for ch := range channels {
  296. select {
  297. case ch <- data:
  298. // 成功发送
  299. default:
  300. // 通道已满,丢弃事件
  301. logger.Warn(fmt.Sprintf("事件通道已满,丢弃全局事件 sessionID=%s",
  302. sessionID))
  303. }
  304. }
  305. }
  306. }
  307. // extractSessionIDFromEvent 从事件数据中提取sessionID
  308. func extractSessionIDFromEvent(data string) string {
  309. // 尝试解析为JSON
  310. var eventMap map[string]interface{}
  311. if err := json.Unmarshal([]byte(data), &eventMap); err != nil {
  312. logger.Debug(fmt.Sprintf("无法解析事件JSON error=%s dataPreview=%s",
  313. err.Error(), safeSubstring(data, 0, 200)))
  314. return ""
  315. }
  316. // 添加调试日志,显示完整事件结构(仅调试时启用)
  317. debugMode := false
  318. if debugMode {
  319. eventJSON, _ := json.MarshalIndent(eventMap, "", " ")
  320. logger.Debug(fmt.Sprintf("事件数据结构 eventStructure=%s",
  321. string(eventJSON)))
  322. }
  323. // 递归查找sessionID字段
  324. sessionID := findSessionIDRecursive(eventMap)
  325. if sessionID == "" {
  326. logger.Debug(fmt.Sprintf("未找到sessionID字段 eventType=%s dataPreview=%s",
  327. getEventType(eventMap), safeSubstring(data, 0, 100)))
  328. } else {
  329. logger.Debug(fmt.Sprintf("成功提取sessionID sessionID=%s eventType=%s",
  330. sessionID, getEventType(eventMap)))
  331. }
  332. return sessionID
  333. }
  334. // findSessionIDRecursive 递归查找sessionID字段
  335. func findSessionIDRecursive(data interface{}) string {
  336. switch v := data.(type) {
  337. case map[string]interface{}:
  338. // 检查当前层级的sessionID字段(支持多种命名变体)
  339. for _, key := range []string{"sessionID", "session_id", "sessionId"} {
  340. if val, ok := v[key]; ok {
  341. if str, ok := val.(string); ok && str != "" {
  342. return str
  343. }
  344. }
  345. }
  346. // 检查常见嵌套路径
  347. // 1. payload.properties.sessionID (session.status事件)
  348. if payload, ok := v["payload"].(map[string]interface{}); ok {
  349. if props, ok := payload["properties"].(map[string]interface{}); ok {
  350. if sessionID, ok := props["sessionID"].(string); ok && sessionID != "" {
  351. return sessionID
  352. }
  353. }
  354. }
  355. // 2. payload.properties.part.sessionID (message.part.updated事件)
  356. if payload, ok := v["payload"].(map[string]interface{}); ok {
  357. if props, ok := payload["properties"].(map[string]interface{}); ok {
  358. if part, ok := props["part"].(map[string]interface{}); ok {
  359. if sessionID, ok := part["sessionID"].(string); ok && sessionID != "" {
  360. return sessionID
  361. }
  362. }
  363. }
  364. }
  365. // 3. payload.properties.info.sessionID (message.updated事件)
  366. if payload, ok := v["payload"].(map[string]interface{}); ok {
  367. if props, ok := payload["properties"].(map[string]interface{}); ok {
  368. if info, ok := props["info"].(map[string]interface{}); ok {
  369. if sessionID, ok := info["sessionID"].(string); ok && sessionID != "" {
  370. return sessionID
  371. }
  372. }
  373. }
  374. }
  375. // 递归遍历所有值
  376. for _, value := range v {
  377. if result := findSessionIDRecursive(value); result != "" {
  378. return result
  379. }
  380. }
  381. case []interface{}:
  382. // 遍历数组
  383. for _, item := range v {
  384. if result := findSessionIDRecursive(item); result != "" {
  385. return result
  386. }
  387. }
  388. }
  389. return ""
  390. }
  391. // getEventType 获取事件类型
  392. func getEventType(eventMap map[string]interface{}) string {
  393. if payload, ok := eventMap["payload"].(map[string]interface{}); ok {
  394. if eventType, ok := payload["type"].(string); ok {
  395. return eventType
  396. }
  397. }
  398. return "unknown"
  399. }
  400. // safeSubstring 安全的子字符串函数
  401. func safeSubstring(s string, start, length int) string {
  402. if start < 0 {
  403. start = 0
  404. }
  405. if start >= len(s) {
  406. return ""
  407. }
  408. end := start + length
  409. if end > len(s) {
  410. end = len(s)
  411. }
  412. return s[start:end]
  413. }
  414. // GetInstance 获取单例实例(线程安全)
  415. var (
  416. instance *EventDispatcher
  417. instanceOnce sync.Once
  418. )
  419. // GetEventDispatcher 获取事件分发器单例
  420. func GetEventDispatcher(baseURL string, port int) *EventDispatcher {
  421. instanceOnce.Do(func() {
  422. instance = NewEventDispatcher(baseURL, port)
  423. })
  424. return instance
  425. }