설명 없음
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

dispatcher.go 29KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977
  1. package event
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "time"
  12. "git.x2erp.com/qdy/go-base/logger"
  13. )
  14. // MessageType 消息类型枚举
  15. type MessageType string
  16. const (
  17. MessageTypeThinking MessageType = "thinking" // 思考过程
  18. MessageTypeTool MessageType = "tool" // 工具调用
  19. MessageTypeReply MessageType = "reply" // 最终回复
  20. MessageTypeUnknown MessageType = "unknown" // 未知类型
  21. )
  22. // CompletionHook 完成钩子接口,用于消息完成时的处理(如保存到数据库)
  23. type CompletionHook interface {
  24. OnMessageComplete(sessionID string, messageID string, completeText string, eventType string, metadata map[string]interface{})
  25. }
  26. // MessageState 消息状态,用于跟踪单个消息的增量内容
  27. type MessageState struct {
  28. SessionID string
  29. MessageID string
  30. StartTime time.Time
  31. LastUpdate time.Time
  32. Metadata map[string]interface{}
  33. // 分离的缓冲区,用于不同类型的内容
  34. ReasoningBuffer strings.Builder // 思考内容
  35. ReplyBuffer strings.Builder // 回答内容
  36. ToolBuffer strings.Builder // 工具调用
  37. // 类型完成状态跟踪
  38. CompletedTypes map[string]bool // 已完成的类型: "reasoning", "text", "tool"
  39. HookTriggered map[string]bool // 钩子触发状态,按类型记录
  40. // 当前活跃类型(用于跟踪正在处理的内容)
  41. CurrentType string // 当前正在处理的类型
  42. }
  43. // MessageAggregator 消息聚合器,负责合并增量内容并检测完成状态
  44. type MessageAggregator struct {
  45. mu sync.RWMutex
  46. messages map[string]*MessageState // key: sessionID_messageID
  47. hooks []CompletionHook
  48. OnMessageCompleteFunc func(sessionID string, messageID string, completeText string, messageType MessageType, metadata map[string]interface{}) // 消息完成回调函数
  49. OnEventProcessedFunc func(sessionID string, eventType string, eventData string, eventMap map[string]interface{}) // 事件处理回调函数
  50. }
  51. // NewMessageAggregator 创建新的消息聚合器
  52. func NewMessageAggregator() *MessageAggregator {
  53. return &MessageAggregator{
  54. messages: make(map[string]*MessageState),
  55. hooks: make([]CompletionHook, 0),
  56. OnMessageCompleteFunc: nil,
  57. OnEventProcessedFunc: nil,
  58. }
  59. }
  60. // RegisterHook 注册完成钩子
  61. func (ma *MessageAggregator) RegisterHook(hook CompletionHook) {
  62. ma.mu.Lock()
  63. defer ma.mu.Unlock()
  64. ma.hooks = append(ma.hooks, hook)
  65. }
  66. // ProcessEvent 处理事件,合并增量内容并检测完成状态
  67. func (ma *MessageAggregator) ProcessEvent(eventData string, sessionID string) {
  68. // 解析事件数据
  69. var eventMap map[string]interface{}
  70. if err := json.Unmarshal([]byte(eventData), &eventMap); err != nil {
  71. logger.Debug(fmt.Sprintf("无法解析事件JSON error=%s dataPreview=%s",
  72. err.Error(), safeSubstring(eventData, 0, 200)))
  73. return
  74. }
  75. eventType := getEventType(eventMap)
  76. // 诊断日志:记录处理的事件类型
  77. logger.Debug(fmt.Sprintf("🔍 ProcessEvent: sessionID=%s eventType=%s dataPreview=%s",
  78. sessionID, eventType, safeSubstring(eventData, 0, 100)))
  79. // 调用事件处理回调函数(如果有设置)
  80. if ma.OnEventProcessedFunc != nil {
  81. ma.OnEventProcessedFunc(sessionID, eventType, eventData, eventMap)
  82. }
  83. // 只处理 message.part.updated, message.updated, session.status 事件
  84. if eventType == "message.part.updated" {
  85. ma.handleMessagePartUpdated(eventMap, sessionID, eventData)
  86. } else if eventType == "message.updated" {
  87. ma.handleMessageUpdated(eventMap, sessionID)
  88. } else if eventType == "session.status" {
  89. ma.handleSessionStatus(eventMap, sessionID)
  90. } else if eventType == "session.idle" {
  91. ma.handleSessionIdle(sessionID)
  92. }
  93. // 其他事件类型忽略
  94. }
  95. // handleMessagePartUpdated 处理 message.part.updated 事件
  96. func (ma *MessageAggregator) handleMessagePartUpdated(eventMap map[string]interface{}, sessionID string, eventData string) {
  97. // 提取消息部分信息
  98. payload, _ := eventMap["payload"].(map[string]interface{})
  99. props, _ := payload["properties"].(map[string]interface{})
  100. part, _ := props["part"].(map[string]interface{})
  101. messageID, _ := part["messageID"].(string)
  102. partType, _ := part["type"].(string)
  103. if messageID == "" || (partType != "text" && partType != "reasoning" && partType != "step-finish" && partType != "tool") {
  104. return
  105. }
  106. ma.mu.Lock()
  107. defer ma.mu.Unlock()
  108. key := sessionID + "_" + messageID
  109. state, exists := ma.messages[key]
  110. if !exists {
  111. // step-finish 事件不应该创建新状态,它应该总是跟随在text/reasoning事件之后
  112. if partType == "step-finish" {
  113. logger.Debug(fmt.Sprintf("忽略step-finish事件,无对应消息状态 sessionID=%s messageID=%s",
  114. sessionID, messageID))
  115. return
  116. }
  117. // 新消息,初始化状态
  118. state = &MessageState{
  119. SessionID: sessionID,
  120. MessageID: messageID,
  121. StartTime: time.Now(),
  122. LastUpdate: time.Now(),
  123. Metadata: make(map[string]interface{}),
  124. CompletedTypes: make(map[string]bool),
  125. HookTriggered: make(map[string]bool),
  126. CurrentType: partType,
  127. }
  128. ma.messages[key] = state
  129. logger.Debug(fmt.Sprintf("开始跟踪新消息 sessionID=%s messageID=%s type=%s",
  130. sessionID, messageID, partType))
  131. }
  132. // 更新增量内容 - 根据类型写入不同的缓冲区
  133. state.CurrentType = partType
  134. if text, ok := part["text"].(string); ok && text != "" {
  135. // 根据类型选择缓冲区
  136. switch partType {
  137. case "reasoning":
  138. state.ReasoningBuffer.WriteString(text)
  139. case "text":
  140. state.ReplyBuffer.WriteString(text)
  141. case "tool":
  142. // 工具调用可能有text字段,也可能通过其他方式处理
  143. state.ToolBuffer.WriteString(text)
  144. }
  145. state.LastUpdate = time.Now()
  146. // 记录增量合并日志(仅调试)- 已禁用以减少日志量
  147. // logger.Debug(fmt.Sprintf("合并增量内容 sessionID=%s messageID=%s type=%s deltaLength=%d",
  148. // sessionID, messageID, partType, len(text)))
  149. } else if partType == "tool" {
  150. // 处理工具调用事件,提取工具信息
  151. if name, ok := part["name"].(string); ok && name != "" {
  152. toolText := fmt.Sprintf("[工具调用: %s", name)
  153. if args, ok := part["arguments"].(string); ok && args != "" {
  154. // 参数可能是JSON字符串,可以尝试美化
  155. toolText += fmt.Sprintf(" 参数: %s", args)
  156. }
  157. toolText += "]"
  158. state.ToolBuffer.WriteString(toolText)
  159. state.LastUpdate = time.Now()
  160. logger.Debug(fmt.Sprintf("合并工具调用内容 sessionID=%s messageID=%s toolName=%s",
  161. sessionID, messageID, name))
  162. }
  163. }
  164. // 检查是否为 step-finish
  165. if partType == "step-finish" {
  166. // 标记当前类型为完成
  167. if state.CurrentType != "" {
  168. state.CompletedTypes[state.CurrentType] = true
  169. logger.Info(fmt.Sprintf("步骤完成 sessionID=%s messageID=%s type=%s",
  170. sessionID, messageID, state.CurrentType))
  171. // 触发该类型的完成钩子
  172. ma.triggerTypeCompletionHooks(state, state.CurrentType)
  173. } else {
  174. logger.Warn(fmt.Sprintf("step-finish事件无当前类型 sessionID=%s messageID=%s",
  175. sessionID, messageID))
  176. }
  177. }
  178. }
  179. // handleMessageUpdated 处理 message.updated 事件
  180. func (ma *MessageAggregator) handleMessageUpdated(eventMap map[string]interface{}, sessionID string) {
  181. payload, _ := eventMap["payload"].(map[string]interface{})
  182. props, _ := payload["properties"].(map[string]interface{})
  183. info, _ := props["info"].(map[string]interface{})
  184. messageID, _ := info["id"].(string)
  185. finishReason, _ := info["finish"].(string)
  186. if messageID == "" || finishReason == "" {
  187. return
  188. }
  189. key := sessionID + "_" + messageID
  190. ma.mu.Lock()
  191. defer ma.mu.Unlock()
  192. if state, exists := ma.messages[key]; exists {
  193. logger.Info(fmt.Sprintf("消息完成 sessionID=%s messageID=%s finishReason=%s",
  194. sessionID, messageID, finishReason))
  195. // 触发完成钩子(处理所有类型的内容)
  196. ma.triggerCompletionHooks(state)
  197. // 清理完成的消息状态
  198. delete(ma.messages, key)
  199. }
  200. }
  201. // handleSessionStatus 处理 session.status 事件
  202. func (ma *MessageAggregator) handleSessionStatus(eventMap map[string]interface{}, sessionID string) {
  203. payload, _ := eventMap["payload"].(map[string]interface{})
  204. props, _ := payload["properties"].(map[string]interface{})
  205. status, _ := props["status"].(map[string]interface{})
  206. statusType, _ := status["type"].(string)
  207. if statusType == "idle" {
  208. logger.Info(fmt.Sprintf("会话进入空闲状态 sessionID=%s", sessionID))
  209. // 可以清理该会话的所有消息状态
  210. ma.cleanupSession(sessionID)
  211. }
  212. }
  213. // handleSessionIdle 处理 session.idle 事件
  214. func (ma *MessageAggregator) handleSessionIdle(sessionID string) {
  215. logger.Info(fmt.Sprintf("会话空闲事件 sessionID=%s", sessionID))
  216. ma.cleanupSession(sessionID)
  217. }
  218. // cleanupSession 清理指定会话的所有消息状态
  219. func (ma *MessageAggregator) cleanupSession(sessionID string) {
  220. ma.mu.Lock()
  221. defer ma.mu.Unlock()
  222. keysToDelete := make([]string, 0)
  223. for key, state := range ma.messages {
  224. if state.SessionID == sessionID {
  225. keysToDelete = append(keysToDelete, key)
  226. // 检查是否有未触发的类型内容,强制触发完成钩子
  227. hasUnfinishedContent := false
  228. // 检查每种类型是否有内容但钩子未触发
  229. if state.ReasoningBuffer.Len() > 0 {
  230. if triggered, exists := state.HookTriggered["reasoning"]; !exists || !triggered {
  231. hasUnfinishedContent = true
  232. }
  233. }
  234. if state.ReplyBuffer.Len() > 0 {
  235. if triggered, exists := state.HookTriggered["text"]; !exists || !triggered {
  236. hasUnfinishedContent = true
  237. }
  238. }
  239. if state.ToolBuffer.Len() > 0 {
  240. if triggered, exists := state.HookTriggered["tool"]; !exists || !triggered {
  241. hasUnfinishedContent = true
  242. }
  243. }
  244. if hasUnfinishedContent {
  245. logger.Warn(fmt.Sprintf("强制完成未完成消息 sessionID=%s messageID=%s",
  246. sessionID, state.MessageID))
  247. ma.triggerCompletionHooks(state)
  248. }
  249. }
  250. }
  251. for _, key := range keysToDelete {
  252. delete(ma.messages, key)
  253. }
  254. if len(keysToDelete) > 0 {
  255. logger.Info(fmt.Sprintf("清理会话消息状态 sessionID=%s cleanedCount=%d",
  256. sessionID, len(keysToDelete)))
  257. }
  258. }
  259. // triggerTypeCompletionHooks 触发特定类型的完成钩子
  260. func (ma *MessageAggregator) triggerTypeCompletionHooks(state *MessageState, partType string) {
  261. // 避免重复触发
  262. if triggered, exists := state.HookTriggered[partType]; exists && triggered {
  263. logger.Debug(fmt.Sprintf("钩子已触发过,跳过 sessionID=%s messageID=%s type=%s",
  264. state.SessionID, state.MessageID, partType))
  265. return
  266. }
  267. // 获取该类型的缓冲区内容
  268. var completeText string
  269. var textLength int
  270. switch partType {
  271. case "reasoning":
  272. completeText = state.ReasoningBuffer.String()
  273. textLength = len(completeText)
  274. case "text":
  275. completeText = state.ReplyBuffer.String()
  276. textLength = len(completeText)
  277. case "tool":
  278. completeText = state.ToolBuffer.String()
  279. textLength = len(completeText)
  280. default:
  281. logger.Warn(fmt.Sprintf("未知类型,跳过钩子触发 sessionID=%s messageID=%s type=%s",
  282. state.SessionID, state.MessageID, partType))
  283. return
  284. }
  285. duration := time.Since(state.StartTime)
  286. // 记录完成事件
  287. logger.Info(fmt.Sprintf("🔔 类型完成总结 sessionID=%s messageID=%s type=%s textLength=%d duration=%v",
  288. state.SessionID, state.MessageID, partType, textLength, duration))
  289. if textLength > 0 {
  290. // 记录文本预览(前200字符)
  291. preview := completeText
  292. if len(preview) > 200 {
  293. preview = preview[:200] + "..."
  294. }
  295. logger.Info(fmt.Sprintf("📝 类型文本预览: %s", preview))
  296. } else {
  297. logger.Info("📭 空文本完成事件")
  298. }
  299. // 转换事件类型为消息类型枚举
  300. messageType := convertToMessageType(partType)
  301. // 调用消息完成回调函数(如果设置)- 使用新的类型化接口
  302. // 注意:这里调用现有的OnMessageCompleteFunc,传入特定类型的内容
  303. if ma.OnMessageCompleteFunc != nil {
  304. ma.OnMessageCompleteFunc(state.SessionID, state.MessageID, completeText, messageType, state.Metadata)
  305. }
  306. // 调用所有注册的钩子(保持向后兼容)
  307. for _, hook := range ma.hooks {
  308. hook.OnMessageComplete(state.SessionID, state.MessageID, completeText, partType, state.Metadata)
  309. }
  310. // 标记该类型的钩子已触发
  311. state.HookTriggered[partType] = true
  312. logger.Debug(fmt.Sprintf("✅ 类型钩子标记为已触发 sessionID=%s messageID=%s type=%s",
  313. state.SessionID, state.MessageID, partType))
  314. }
  315. // triggerCompletionHooks 触发完成钩子(向后兼容,触发所有类型)
  316. func (ma *MessageAggregator) triggerCompletionHooks(state *MessageState) {
  317. // 遍历所有支持的类型,触发各自的完成钩子
  318. types := []string{"reasoning", "text", "tool"}
  319. hasAnyContent := false
  320. for _, partType := range types {
  321. // 检查该类型是否有内容
  322. var hasContent bool
  323. switch partType {
  324. case "reasoning":
  325. hasContent = state.ReasoningBuffer.Len() > 0
  326. case "text":
  327. hasContent = state.ReplyBuffer.Len() > 0
  328. case "tool":
  329. hasContent = state.ToolBuffer.Len() > 0
  330. }
  331. if hasContent {
  332. hasAnyContent = true
  333. // 触发该类型的完成钩子(函数内部会检查是否已触发)
  334. ma.triggerTypeCompletionHooks(state, partType)
  335. }
  336. }
  337. // 如果没有内容,至少记录一个事件(保持向后兼容)
  338. if !hasAnyContent {
  339. logger.Info(fmt.Sprintf("📭 空消息完成事件 sessionID=%s messageID=%s",
  340. state.SessionID, state.MessageID))
  341. }
  342. }
  343. // EventDispatcher 事件分发器 - 单例模式
  344. type EventDispatcher struct {
  345. mu sync.RWMutex
  346. baseURL string
  347. port int
  348. subscriptions map[string]map[chan string]struct{} // sessionID -> set of channels
  349. sessionUserCache *SessionUserCache // sessionID -> userID 映射缓存(用于用户验证)
  350. client *http.Client
  351. cancelFunc context.CancelFunc
  352. running bool
  353. messageAggregator *MessageAggregator // 消息聚合器
  354. }
  355. // EventData opencode事件数据结构 - 匹配实际事件格式
  356. type EventData struct {
  357. Directory string `json:"directory,omitempty"`
  358. Payload map[string]interface{} `json:"payload"`
  359. }
  360. // PayloadData payload内部结构(辅助类型)
  361. type PayloadData struct {
  362. Type string `json:"type"`
  363. Properties map[string]interface{} `json:"properties,omitempty"`
  364. }
  365. // NewEventDispatcher 创建新的事件分发器
  366. func NewEventDispatcher(baseURL string, port int) *EventDispatcher {
  367. return &EventDispatcher{
  368. baseURL: baseURL,
  369. port: port,
  370. subscriptions: make(map[string]map[chan string]struct{}),
  371. sessionUserCache: NewSessionUserCache(20 * time.Minute),
  372. messageAggregator: NewMessageAggregator(),
  373. client: &http.Client{
  374. Timeout: 0, // 无超时限制,用于长连接
  375. },
  376. running: false,
  377. }
  378. }
  379. // Start 启动事件分发器,连接到opencode全局事件流
  380. func (ed *EventDispatcher) Start(ctx context.Context) error {
  381. ed.mu.Lock()
  382. if ed.running {
  383. ed.mu.Unlock()
  384. return fmt.Errorf("event dispatcher already running")
  385. }
  386. // 创建子上下文用于控制SSE连接
  387. sseCtx, cancel := context.WithCancel(ctx)
  388. ed.cancelFunc = cancel
  389. ed.running = true
  390. ed.mu.Unlock()
  391. // 启动SSE连接协程
  392. go ed.runSSEConnection(sseCtx)
  393. logger.Info(fmt.Sprintf("事件分发器已启动 baseURL=%s port=%d",
  394. ed.baseURL, ed.port))
  395. return nil
  396. }
  397. // Stop 停止事件分发器
  398. func (ed *EventDispatcher) Stop() {
  399. ed.mu.Lock()
  400. if !ed.running {
  401. ed.mu.Unlock()
  402. return
  403. }
  404. if ed.cancelFunc != nil {
  405. ed.cancelFunc()
  406. }
  407. // 清理所有订阅通道
  408. for sessionID, channels := range ed.subscriptions {
  409. for ch := range channels {
  410. close(ch)
  411. }
  412. delete(ed.subscriptions, sessionID)
  413. }
  414. ed.running = false
  415. ed.mu.Unlock()
  416. logger.Info("事件分发器已停止")
  417. }
  418. // Subscribe 订阅指定会话的事件
  419. func (ed *EventDispatcher) Subscribe(sessionID, userID string) (<-chan string, error) {
  420. ed.mu.Lock()
  421. defer ed.mu.Unlock()
  422. // 缓存会话-用户映射(用于未来需要用户验证时)
  423. ed.sessionUserCache.Set(sessionID, userID)
  424. // 创建缓冲通道
  425. ch := make(chan string, 100)
  426. // 添加到订阅列表
  427. if _, exists := ed.subscriptions[sessionID]; !exists {
  428. ed.subscriptions[sessionID] = make(map[chan string]struct{})
  429. }
  430. ed.subscriptions[sessionID][ch] = struct{}{}
  431. logger.Debug(fmt.Sprintf("新订阅添加 sessionID=%s userID=%s totalSubscriptions=%d",
  432. sessionID, userID, len(ed.subscriptions[sessionID])))
  433. return ch, nil
  434. }
  435. // Unsubscribe 取消订阅指定会话的事件
  436. func (ed *EventDispatcher) Unsubscribe(sessionID string, ch <-chan string) {
  437. ed.mu.Lock()
  438. defer ed.mu.Unlock()
  439. if channels, exists := ed.subscriptions[sessionID]; exists {
  440. // 遍历查找对应的通道(因为ch是只读通道,无法直接作为key)
  441. var foundChan chan string
  442. for candidate := range channels {
  443. // 比较通道值
  444. if candidate == ch {
  445. foundChan = candidate
  446. break
  447. }
  448. }
  449. if foundChan != nil {
  450. close(foundChan)
  451. delete(channels, foundChan)
  452. logger.Debug(fmt.Sprintf("订阅已移除 sessionID=%s remainingSubscriptions=%d",
  453. sessionID, len(channels)))
  454. }
  455. // 如果没有订阅者了,清理该会话的映射
  456. if len(channels) == 0 {
  457. delete(ed.subscriptions, sessionID)
  458. ed.sessionUserCache.Delete(sessionID)
  459. }
  460. }
  461. }
  462. // RegisterSession 注册会话(前端调用SendPromptStream时调用)
  463. func (ed *EventDispatcher) RegisterSession(sessionID, userID string) {
  464. ed.sessionUserCache.Set(sessionID, userID)
  465. logger.Debug(fmt.Sprintf("会话已注册 sessionID=%s userID=%s",
  466. sessionID, userID))
  467. }
  468. // buildSSEURL 构建SSE URL,避免端口重复
  469. func (ed *EventDispatcher) buildSSEURL() string {
  470. // 检查baseURL是否已包含端口
  471. base := ed.baseURL
  472. // 简单检查:如果baseURL已经包含端口号模式(冒号后跟数字),就不再加端口
  473. // 查找最后一个冒号的位置
  474. lastColon := -1
  475. for i := len(base) - 1; i >= 0; i-- {
  476. if base[i] == ':' {
  477. lastColon = i
  478. break
  479. }
  480. }
  481. if lastColon != -1 {
  482. // 检查冒号后是否都是数字(端口号)
  483. hasPort := true
  484. for i := lastColon + 1; i < len(base); i++ {
  485. if base[i] < '0' || base[i] > '9' {
  486. hasPort = false
  487. break
  488. }
  489. }
  490. if hasPort {
  491. // baseURL已有端口,直接拼接路径
  492. if strings.HasSuffix(base, "/") {
  493. return base + "global/event"
  494. }
  495. return base + "/global/event"
  496. }
  497. }
  498. // baseURL没有端口或端口格式不正确,添加端口
  499. if strings.HasSuffix(base, "/") {
  500. return fmt.Sprintf("%s:%d/global/event", strings.TrimSuffix(base, "/"), ed.port)
  501. }
  502. return fmt.Sprintf("%s:%d/global/event", base, ed.port)
  503. }
  504. // runSSEConnection 运行SSE连接,读取全局事件并分发
  505. func (ed *EventDispatcher) runSSEConnection(ctx context.Context) {
  506. // 构建SSE URL,避免重复端口
  507. url := ed.buildSSEURL()
  508. for {
  509. select {
  510. case <-ctx.Done():
  511. logger.Info("SSE连接停止(上下文取消)")
  512. return
  513. default:
  514. // 建立SSE连接
  515. logger.Info(fmt.Sprintf("正在连接SSE流 url=%s",
  516. url))
  517. if err := ed.connectAndProcessSSE(ctx, url); err != nil {
  518. logger.Error(fmt.Sprintf("SSE连接失败,5秒后重试 error=%s url=%s",
  519. err.Error(), url))
  520. select {
  521. case <-ctx.Done():
  522. return
  523. case <-time.After(5 * time.Second):
  524. continue
  525. }
  526. }
  527. }
  528. }
  529. }
  530. // connectAndProcessSSE 连接并处理SSE流
  531. func (ed *EventDispatcher) connectAndProcessSSE(ctx context.Context, url string) error {
  532. req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  533. if err != nil {
  534. return fmt.Errorf("创建请求失败: %w", err)
  535. }
  536. req.Header.Set("Accept", "text/event-stream")
  537. resp, err := ed.client.Do(req)
  538. if err != nil {
  539. return fmt.Errorf("发送请求失败: %w", err)
  540. }
  541. defer resp.Body.Close()
  542. if resp.StatusCode != http.StatusOK {
  543. body, _ := io.ReadAll(resp.Body)
  544. return fmt.Errorf("SSE请求失败,状态码: %d, 响应: %s", resp.StatusCode, string(body))
  545. }
  546. logger.Info(fmt.Sprintf("SSE连接已建立 url=%s",
  547. url))
  548. reader := bufio.NewReader(resp.Body)
  549. eventCount := 0
  550. for {
  551. select {
  552. case <-ctx.Done():
  553. return nil
  554. default:
  555. line, err := reader.ReadString('\n')
  556. if err != nil {
  557. if err == io.EOF {
  558. logger.Info(fmt.Sprintf("SSE流正常结束 totalEvents=%d",
  559. eventCount))
  560. } else if ctx.Err() != nil {
  561. logger.Info("SSE流上下文取消")
  562. } else {
  563. logger.Error(fmt.Sprintf("读取SSE流错误 error=%s",
  564. err.Error()))
  565. }
  566. return err
  567. }
  568. line = strings.TrimSpace(line)
  569. if line == "" {
  570. continue
  571. }
  572. if strings.HasPrefix(line, "data: ") {
  573. data := strings.TrimPrefix(line, "data: ")
  574. eventCount++
  575. // 分发事件
  576. ed.dispatchEvent(data)
  577. if eventCount%100 == 0 {
  578. //logger.Debug(fmt.Sprintf("事件处理统计 totalEvents=%d activeSessions=%d",
  579. // eventCount, len(ed.subscriptions)))
  580. }
  581. }
  582. }
  583. }
  584. }
  585. // dispatchEvent 分发事件到相关订阅者
  586. func (ed *EventDispatcher) dispatchEvent(data string) {
  587. // 解析事件数据获取sessionID
  588. sessionID := extractSessionIDFromEvent(data)
  589. // 处理事件聚合(无论是否有sessionID都处理)
  590. if ed.messageAggregator != nil && sessionID != "" {
  591. ed.messageAggregator.ProcessEvent(data, sessionID)
  592. }
  593. if sessionID == "" {
  594. // 没有sessionID的事件(如全局心跳)分发给所有订阅者
  595. //logger.Debug(fmt.Sprintf("广播全局事件 dataPreview=%s", safeSubstring(data, 0, 100)))
  596. ed.broadcastToAll(data)
  597. return
  598. }
  599. // 只记录非增量文本事件的路由日志,减少日志量
  600. shouldLog := true
  601. var eventMap map[string]interface{}
  602. if err := json.Unmarshal([]byte(data), &eventMap); err == nil {
  603. eventType := getEventType(eventMap)
  604. // message.part.updated 是最频繁的事件,跳过其路由日志
  605. if eventType == "message.part.updated" {
  606. shouldLog = false
  607. }
  608. }
  609. if shouldLog {
  610. logger.Debug(fmt.Sprintf("路由事件到会话 sessionID=%s dataPreview=%s",
  611. sessionID, safeSubstring(data, 0, 100)))
  612. }
  613. // 只分发给订阅该会话的通道
  614. ed.mu.RLock()
  615. channels, exists := ed.subscriptions[sessionID]
  616. ed.mu.RUnlock()
  617. if !exists {
  618. // 没有该会话的订阅者,忽略事件
  619. //logger.Debug(fmt.Sprintf("忽略事件,无订阅者 sessionID=%s", sessionID))
  620. return
  621. }
  622. // 发送事件到所有订阅该会话的通道
  623. ed.mu.RLock()
  624. for ch := range channels {
  625. select {
  626. case ch <- data:
  627. // 成功发送
  628. default:
  629. // 通道已满,丢弃事件并记录警告
  630. logger.Warn(fmt.Sprintf("事件通道已满,丢弃事件 sessionID=%s",
  631. sessionID))
  632. }
  633. }
  634. ed.mu.RUnlock()
  635. }
  636. // broadcastToAll 广播事件给所有订阅者(用于全局事件如心跳)
  637. func (ed *EventDispatcher) broadcastToAll(data string) {
  638. //logger.Debug(fmt.Sprintf("广播事件给所有订阅者 dataPreview=%s", safeSubstring(data, 0, 100)))
  639. ed.mu.RLock()
  640. defer ed.mu.RUnlock()
  641. for sessionID, channels := range ed.subscriptions {
  642. for ch := range channels {
  643. select {
  644. case ch <- data:
  645. // 成功发送
  646. default:
  647. // 通道已满,丢弃事件
  648. logger.Warn(fmt.Sprintf("事件通道已满,丢弃全局事件 sessionID=%s",
  649. sessionID))
  650. }
  651. }
  652. }
  653. }
  654. // extractSessionIDFromEvent 从事件数据中提取sessionID
  655. func extractSessionIDFromEvent(data string) string {
  656. // 尝试解析为JSON
  657. var eventMap map[string]interface{}
  658. if err := json.Unmarshal([]byte(data), &eventMap); err != nil {
  659. logger.ErrorC(fmt.Sprintf("无法解析事件JSON error=%s dataPreview=%s",
  660. err.Error(), safeSubstring(data, 0, 200)))
  661. return ""
  662. }
  663. // 添加调试日志,显示完整事件结构(仅调试时启用)
  664. debugMode := true
  665. if debugMode {
  666. eventJSON, _ := json.MarshalIndent(eventMap, "", " ")
  667. logger.Debug(fmt.Sprintf("事件数据结构 eventStructure=%s",
  668. string(eventJSON)))
  669. }
  670. // 递归查找sessionID字段
  671. sessionID := findSessionIDRecursive(eventMap)
  672. if sessionID == "" {
  673. //logger.Debug(fmt.Sprintf("未找到sessionID字段 eventType=%s dataPreview=%s",
  674. // getEventType(eventMap), safeSubstring(data, 0, 100)))
  675. }
  676. return sessionID
  677. }
  678. // findSessionIDRecursive 递归查找sessionID字段
  679. func findSessionIDRecursive(data interface{}) string {
  680. switch v := data.(type) {
  681. case map[string]interface{}:
  682. // 检查当前层级的sessionID字段(支持多种命名变体)
  683. for _, key := range []string{"sessionID", "session_id", "sessionId"} {
  684. if val, ok := v[key]; ok {
  685. if str, ok := val.(string); ok && str != "" {
  686. return str
  687. }
  688. }
  689. }
  690. // 检查常见嵌套路径
  691. // 1. payload.properties.sessionID (session.status事件)
  692. if payload, ok := v["payload"].(map[string]interface{}); ok {
  693. if props, ok := payload["properties"].(map[string]interface{}); ok {
  694. if sessionID, ok := props["sessionID"].(string); ok && sessionID != "" {
  695. return sessionID
  696. }
  697. }
  698. }
  699. // 2. payload.properties.part.sessionID (message.part.updated事件)
  700. if payload, ok := v["payload"].(map[string]interface{}); ok {
  701. if props, ok := payload["properties"].(map[string]interface{}); ok {
  702. if part, ok := props["part"].(map[string]interface{}); ok {
  703. if sessionID, ok := part["sessionID"].(string); ok && sessionID != "" {
  704. return sessionID
  705. }
  706. }
  707. }
  708. }
  709. // 3. payload.properties.info.sessionID (message.updated事件)
  710. if payload, ok := v["payload"].(map[string]interface{}); ok {
  711. if props, ok := payload["properties"].(map[string]interface{}); ok {
  712. if info, ok := props["info"].(map[string]interface{}); ok {
  713. if sessionID, ok := info["sessionID"].(string); ok && sessionID != "" {
  714. return sessionID
  715. }
  716. }
  717. }
  718. }
  719. // 递归遍历所有值
  720. for _, value := range v {
  721. if result := findSessionIDRecursive(value); result != "" {
  722. return result
  723. }
  724. }
  725. case []interface{}:
  726. // 遍历数组
  727. for _, item := range v {
  728. if result := findSessionIDRecursive(item); result != "" {
  729. return result
  730. }
  731. }
  732. }
  733. return ""
  734. }
  735. // getEventType 获取事件类型
  736. func getEventType(eventMap map[string]interface{}) string {
  737. if payload, ok := eventMap["payload"].(map[string]interface{}); ok {
  738. if eventType, ok := payload["type"].(string); ok {
  739. return eventType
  740. }
  741. }
  742. return "unknown"
  743. }
  744. // convertToMessageType 将事件类型转换为消息类型枚举
  745. func convertToMessageType(eventType string) MessageType {
  746. switch eventType {
  747. case "reasoning":
  748. return MessageTypeThinking
  749. case "text":
  750. return MessageTypeReply
  751. case "tool":
  752. return MessageTypeTool
  753. default:
  754. // 尝试识别其他类型
  755. if strings.Contains(eventType, "reasoning") || strings.Contains(eventType, "thinking") {
  756. return MessageTypeThinking
  757. }
  758. if strings.Contains(eventType, "tool") || strings.Contains(eventType, "function") {
  759. return MessageTypeTool
  760. }
  761. return MessageTypeUnknown
  762. }
  763. }
  764. // safeSubstring 安全的子字符串函数
  765. func safeSubstring(s string, start, length int) string {
  766. if start < 0 {
  767. start = 0
  768. }
  769. if start >= len(s) {
  770. return ""
  771. }
  772. end := start + length
  773. if end > len(s) {
  774. end = len(s)
  775. }
  776. return s[start:end]
  777. }
  778. // getHookTriggerSource 获取钩子触发源(用于调试)
  779. func getHookTriggerSource(state *MessageState) string {
  780. // 检查是否有任何类型的钩子已触发
  781. for _, triggered := range state.HookTriggered {
  782. if triggered {
  783. return "completed"
  784. }
  785. }
  786. return "unknown"
  787. }
  788. // RegisterHook 注册完成钩子
  789. func (ed *EventDispatcher) RegisterHook(hook CompletionHook) {
  790. if ed.messageAggregator != nil {
  791. ed.messageAggregator.RegisterHook(hook)
  792. }
  793. }
  794. // SetOnMessageCompleteFunc 设置消息完成回调函数
  795. func (ed *EventDispatcher) SetOnMessageCompleteFunc(f func(sessionID string, messageID string, completeText string, messageType MessageType, metadata map[string]interface{})) {
  796. if ed.messageAggregator != nil {
  797. ed.messageAggregator.OnMessageCompleteFunc = f
  798. logger.Info("✅ 消息完成回调函数已设置")
  799. }
  800. }
  801. // SetOnEventProcessedFunc 设置事件处理回调函数
  802. func (ed *EventDispatcher) SetOnEventProcessedFunc(f func(sessionID string, eventType string, eventData string, eventMap map[string]interface{})) {
  803. if ed.messageAggregator != nil {
  804. ed.messageAggregator.OnEventProcessedFunc = f
  805. logger.Info("✅ 事件处理回调函数已设置")
  806. }
  807. }
  808. // GetInstance 获取单例实例(线程安全)
  809. var (
  810. instance *EventDispatcher
  811. instanceOnce sync.Once
  812. )
  813. // GetEventDispatcher 获取事件分发器单例
  814. func GetEventDispatcher(baseURL string, port int) *EventDispatcher {
  815. instanceOnce.Do(func() {
  816. instance = NewEventDispatcher(baseURL, port)
  817. })
  818. return instance
  819. }
  820. // DiagnosticHook 诊断钩子实现(用于调试和测试)
  821. type DiagnosticHook struct{}
  822. func (h *DiagnosticHook) OnMessageComplete(sessionID string, messageID string, completeText string, eventType string, metadata map[string]interface{}) {
  823. logger.Info(fmt.Sprintf("🔍 诊断钩子触发: session=%s message=%s type=%s textLength=%d",
  824. sessionID, messageID, eventType, len(completeText)))
  825. if len(completeText) > 0 {
  826. preview := completeText
  827. if len(preview) > 150 {
  828. preview = preview[:150] + "..."
  829. }
  830. logger.Info(fmt.Sprintf("📋 诊断钩子文本预览: %s", preview))
  831. } else {
  832. logger.Info("📭 诊断钩子: 空文本")
  833. }
  834. // 记录元数据(如果有)
  835. if metadata != nil && len(metadata) > 0 {
  836. logger.Info(fmt.Sprintf("📊 诊断钩子元数据: %+v", metadata))
  837. }
  838. }
  839. // 使用示例:
  840. // dispatcher := GetEventDispatcher("http://localhost", 3000)
  841. // dispatcher.RegisterHook(&DiagnosticHook{})