|
|
@@ -14,426 +14,23 @@ import (
|
|
14
|
14
|
"git.x2erp.com/qdy/go-base/logger"
|
|
15
|
15
|
)
|
|
16
|
16
|
|
|
17
|
|
-// MessageType 消息类型枚举
|
|
18
|
|
-type MessageType string
|
|
19
|
|
-
|
|
20
|
|
-const (
|
|
21
|
|
- MessageTypeThinking MessageType = "thinking" // 思考过程
|
|
22
|
|
- MessageTypeTool MessageType = "tool" // 工具调用
|
|
23
|
|
- MessageTypeReply MessageType = "reply" // 最终回复
|
|
24
|
|
- MessageTypeUnknown MessageType = "unknown" // 未知类型
|
|
25
|
|
-)
|
|
26
|
|
-
|
|
27
|
|
-// CompletionHook 完成钩子接口,用于消息完成时的处理(如保存到数据库)
|
|
28
|
|
-type CompletionHook interface {
|
|
29
|
|
- OnMessageComplete(sessionID string, messageID string, completeText string, eventType string, metadata map[string]interface{})
|
|
30
|
|
-}
|
|
31
|
|
-
|
|
32
|
|
-// MessageState 消息状态,用于跟踪单个消息的增量内容
|
|
33
|
|
-type MessageState struct {
|
|
34
|
|
- SessionID string
|
|
35
|
|
- MessageID string
|
|
36
|
|
- StartTime time.Time
|
|
37
|
|
- LastUpdate time.Time
|
|
38
|
|
- Metadata map[string]interface{}
|
|
39
|
|
-
|
|
40
|
|
- // 分离的缓冲区,用于不同类型的内容
|
|
41
|
|
- ReasoningBuffer strings.Builder // 思考内容
|
|
42
|
|
- ReplyBuffer strings.Builder // 回答内容
|
|
43
|
|
- ToolBuffer strings.Builder // 工具调用
|
|
44
|
|
-
|
|
45
|
|
- // 类型完成状态跟踪
|
|
46
|
|
- CompletedTypes map[string]bool // 已完成的类型: "reasoning", "text", "tool"
|
|
47
|
|
- HookTriggered map[string]bool // 钩子触发状态,按类型记录
|
|
48
|
|
-
|
|
49
|
|
- // 当前活跃类型(用于跟踪正在处理的内容)
|
|
50
|
|
- CurrentType string // 当前正在处理的类型
|
|
51
|
|
-}
|
|
52
|
|
-
|
|
53
|
|
-// MessageAggregator 消息聚合器,负责合并增量内容并检测完成状态
|
|
54
|
|
-type MessageAggregator struct {
|
|
55
|
|
- mu sync.RWMutex
|
|
56
|
|
- messages map[string]*MessageState // key: sessionID_messageID
|
|
57
|
|
- hooks []CompletionHook
|
|
58
|
|
- OnMessageCompleteFunc func(sessionID string, messageID string, completeText string, messageType MessageType, metadata map[string]interface{}) // 消息完成回调函数
|
|
59
|
|
- OnEventProcessedFunc func(sessionID string, eventType string, eventData string, eventMap map[string]interface{}) // 事件处理回调函数
|
|
60
|
|
-}
|
|
61
|
|
-
|
|
62
|
|
-// NewMessageAggregator 创建新的消息聚合器
|
|
63
|
|
-func NewMessageAggregator() *MessageAggregator {
|
|
64
|
|
- return &MessageAggregator{
|
|
65
|
|
- messages: make(map[string]*MessageState),
|
|
66
|
|
- hooks: make([]CompletionHook, 0),
|
|
67
|
|
- OnMessageCompleteFunc: nil,
|
|
68
|
|
- OnEventProcessedFunc: nil,
|
|
69
|
|
- }
|
|
70
|
|
-}
|
|
71
|
|
-
|
|
72
|
|
-// RegisterHook 注册完成钩子
|
|
73
|
|
-func (ma *MessageAggregator) RegisterHook(hook CompletionHook) {
|
|
74
|
|
- ma.mu.Lock()
|
|
75
|
|
- defer ma.mu.Unlock()
|
|
76
|
|
- ma.hooks = append(ma.hooks, hook)
|
|
77
|
|
-}
|
|
78
|
|
-
|
|
79
|
|
-// ProcessEvent 处理事件,合并增量内容并检测完成状态
|
|
80
|
|
-func (ma *MessageAggregator) ProcessEvent(eventData string, sessionID string) {
|
|
81
|
|
- // 解析事件数据
|
|
82
|
|
- var eventMap map[string]interface{}
|
|
83
|
|
- if err := json.Unmarshal([]byte(eventData), &eventMap); err != nil {
|
|
84
|
|
- logger.Debug(fmt.Sprintf("无法解析事件JSON error=%s dataPreview=%s",
|
|
85
|
|
- err.Error(), safeSubstring(eventData, 0, 200)))
|
|
86
|
|
- return
|
|
87
|
|
- }
|
|
88
|
|
-
|
|
89
|
|
- eventType := getEventType(eventMap)
|
|
90
|
|
-
|
|
91
|
|
- // 诊断日志:记录处理的事件类型
|
|
92
|
|
- logger.Debug(fmt.Sprintf("🔍 ProcessEvent: sessionID=%s eventType=%s dataPreview=%s",
|
|
93
|
|
- sessionID, eventType, safeSubstring(eventData, 0, 100)))
|
|
94
|
|
-
|
|
95
|
|
- // 调用事件处理回调函数(如果有设置)
|
|
96
|
|
- if ma.OnEventProcessedFunc != nil {
|
|
97
|
|
- ma.OnEventProcessedFunc(sessionID, eventType, eventData, eventMap)
|
|
98
|
|
- }
|
|
99
|
|
-
|
|
100
|
|
- // 只处理 message.part.updated, message.updated, session.status 事件
|
|
101
|
|
- if eventType == "message.part.updated" {
|
|
102
|
|
- ma.handleMessagePartUpdated(eventMap, sessionID, eventData)
|
|
103
|
|
- } else if eventType == "message.updated" {
|
|
104
|
|
- ma.handleMessageUpdated(eventMap, sessionID)
|
|
105
|
|
- } else if eventType == "session.status" {
|
|
106
|
|
- ma.handleSessionStatus(eventMap, sessionID)
|
|
107
|
|
- } else if eventType == "session.idle" {
|
|
108
|
|
- ma.handleSessionIdle(sessionID)
|
|
109
|
|
- }
|
|
110
|
|
- // 其他事件类型忽略
|
|
111
|
|
-}
|
|
112
|
|
-
|
|
113
|
|
-// handleMessagePartUpdated 处理 message.part.updated 事件
|
|
114
|
|
-func (ma *MessageAggregator) handleMessagePartUpdated(eventMap map[string]interface{}, sessionID string, eventData string) {
|
|
115
|
|
- // 提取消息部分信息
|
|
116
|
|
- payload, _ := eventMap["payload"].(map[string]interface{})
|
|
117
|
|
- props, _ := payload["properties"].(map[string]interface{})
|
|
118
|
|
- part, _ := props["part"].(map[string]interface{})
|
|
119
|
|
-
|
|
120
|
|
- messageID, _ := part["messageID"].(string)
|
|
121
|
|
- partType, _ := part["type"].(string)
|
|
122
|
|
-
|
|
123
|
|
- if messageID == "" || (partType != "text" && partType != "reasoning" && partType != "step-finish" && partType != "tool") {
|
|
124
|
|
- return
|
|
125
|
|
- }
|
|
126
|
|
-
|
|
127
|
|
- ma.mu.Lock()
|
|
128
|
|
- defer ma.mu.Unlock()
|
|
129
|
|
-
|
|
130
|
|
- key := sessionID + "_" + messageID
|
|
131
|
|
- state, exists := ma.messages[key]
|
|
132
|
|
-
|
|
133
|
|
- if !exists {
|
|
134
|
|
- // step-finish 事件不应该创建新状态,它应该总是跟随在text/reasoning事件之后
|
|
135
|
|
- if partType == "step-finish" {
|
|
136
|
|
- logger.Debug(fmt.Sprintf("忽略step-finish事件,无对应消息状态 sessionID=%s messageID=%s",
|
|
137
|
|
- sessionID, messageID))
|
|
138
|
|
- return
|
|
139
|
|
- }
|
|
140
|
|
-
|
|
141
|
|
- // 新消息,初始化状态
|
|
142
|
|
- state = &MessageState{
|
|
143
|
|
- SessionID: sessionID,
|
|
144
|
|
- MessageID: messageID,
|
|
145
|
|
- StartTime: time.Now(),
|
|
146
|
|
- LastUpdate: time.Now(),
|
|
147
|
|
- Metadata: make(map[string]interface{}),
|
|
148
|
|
- CompletedTypes: make(map[string]bool),
|
|
149
|
|
- HookTriggered: make(map[string]bool),
|
|
150
|
|
- CurrentType: partType,
|
|
151
|
|
- }
|
|
152
|
|
- ma.messages[key] = state
|
|
153
|
|
- logger.Debug(fmt.Sprintf("开始跟踪新消息 sessionID=%s messageID=%s type=%s",
|
|
154
|
|
- sessionID, messageID, partType))
|
|
155
|
|
- }
|
|
156
|
|
-
|
|
157
|
|
- // 更新增量内容 - 根据类型写入不同的缓冲区
|
|
158
|
|
- state.CurrentType = partType
|
|
159
|
|
-
|
|
160
|
|
- if text, ok := part["text"].(string); ok && text != "" {
|
|
161
|
|
- // 根据类型选择缓冲区
|
|
162
|
|
- switch partType {
|
|
163
|
|
- case "reasoning":
|
|
164
|
|
- state.ReasoningBuffer.WriteString(text)
|
|
165
|
|
- case "text":
|
|
166
|
|
- state.ReplyBuffer.WriteString(text)
|
|
167
|
|
- case "tool":
|
|
168
|
|
- // 工具调用可能有text字段,也可能通过其他方式处理
|
|
169
|
|
- state.ToolBuffer.WriteString(text)
|
|
170
|
|
- }
|
|
171
|
|
- state.LastUpdate = time.Now()
|
|
172
|
|
-
|
|
173
|
|
- // 记录增量合并日志(仅调试)- 已禁用以减少日志量
|
|
174
|
|
- // logger.Debug(fmt.Sprintf("合并增量内容 sessionID=%s messageID=%s type=%s deltaLength=%d",
|
|
175
|
|
- // sessionID, messageID, partType, len(text)))
|
|
176
|
|
- } else if partType == "tool" {
|
|
177
|
|
- // 处理工具调用事件,提取工具信息
|
|
178
|
|
- if name, ok := part["name"].(string); ok && name != "" {
|
|
179
|
|
- toolText := fmt.Sprintf("[工具调用: %s", name)
|
|
180
|
|
- if args, ok := part["arguments"].(string); ok && args != "" {
|
|
181
|
|
- // 参数可能是JSON字符串,可以尝试美化
|
|
182
|
|
- toolText += fmt.Sprintf(" 参数: %s", args)
|
|
183
|
|
- }
|
|
184
|
|
- toolText += "]"
|
|
185
|
|
- state.ToolBuffer.WriteString(toolText)
|
|
186
|
|
- state.LastUpdate = time.Now()
|
|
187
|
|
- logger.Debug(fmt.Sprintf("合并工具调用内容 sessionID=%s messageID=%s toolName=%s",
|
|
188
|
|
- sessionID, messageID, name))
|
|
189
|
|
- }
|
|
190
|
|
- }
|
|
191
|
|
-
|
|
192
|
|
- // 检查是否为 step-finish
|
|
193
|
|
- if partType == "step-finish" {
|
|
194
|
|
- // 标记当前类型为完成
|
|
195
|
|
- if state.CurrentType != "" {
|
|
196
|
|
- state.CompletedTypes[state.CurrentType] = true
|
|
197
|
|
- logger.Info(fmt.Sprintf("步骤完成 sessionID=%s messageID=%s type=%s",
|
|
198
|
|
- sessionID, messageID, state.CurrentType))
|
|
199
|
|
-
|
|
200
|
|
- // 触发该类型的完成钩子
|
|
201
|
|
- ma.triggerTypeCompletionHooks(state, state.CurrentType)
|
|
202
|
|
- } else {
|
|
203
|
|
- logger.Warn(fmt.Sprintf("step-finish事件无当前类型 sessionID=%s messageID=%s",
|
|
204
|
|
- sessionID, messageID))
|
|
205
|
|
- }
|
|
206
|
|
- }
|
|
207
|
|
-}
|
|
208
|
|
-
|
|
209
|
|
-// handleMessageUpdated 处理 message.updated 事件
|
|
210
|
|
-func (ma *MessageAggregator) handleMessageUpdated(eventMap map[string]interface{}, sessionID string) {
|
|
211
|
|
- payload, _ := eventMap["payload"].(map[string]interface{})
|
|
212
|
|
- props, _ := payload["properties"].(map[string]interface{})
|
|
213
|
|
- info, _ := props["info"].(map[string]interface{})
|
|
214
|
|
-
|
|
215
|
|
- messageID, _ := info["id"].(string)
|
|
216
|
|
- finishReason, _ := info["finish"].(string)
|
|
217
|
|
-
|
|
218
|
|
- if messageID == "" || finishReason == "" {
|
|
219
|
|
- return
|
|
220
|
|
- }
|
|
221
|
|
-
|
|
222
|
|
- key := sessionID + "_" + messageID
|
|
223
|
|
- ma.mu.Lock()
|
|
224
|
|
- defer ma.mu.Unlock()
|
|
225
|
|
-
|
|
226
|
|
- if state, exists := ma.messages[key]; exists {
|
|
227
|
|
- logger.Info(fmt.Sprintf("消息完成 sessionID=%s messageID=%s finishReason=%s",
|
|
228
|
|
- sessionID, messageID, finishReason))
|
|
229
|
|
-
|
|
230
|
|
- // 触发完成钩子(处理所有类型的内容)
|
|
231
|
|
- ma.triggerCompletionHooks(state)
|
|
232
|
|
-
|
|
233
|
|
- // 清理完成的消息状态
|
|
234
|
|
- delete(ma.messages, key)
|
|
235
|
|
- }
|
|
236
|
|
-}
|
|
237
|
|
-
|
|
238
|
|
-// handleSessionStatus 处理 session.status 事件
|
|
239
|
|
-func (ma *MessageAggregator) handleSessionStatus(eventMap map[string]interface{}, sessionID string) {
|
|
240
|
|
- payload, _ := eventMap["payload"].(map[string]interface{})
|
|
241
|
|
- props, _ := payload["properties"].(map[string]interface{})
|
|
242
|
|
- status, _ := props["status"].(map[string]interface{})
|
|
243
|
|
-
|
|
244
|
|
- statusType, _ := status["type"].(string)
|
|
245
|
|
- if statusType == "idle" {
|
|
246
|
|
- logger.Info(fmt.Sprintf("会话进入空闲状态 sessionID=%s", sessionID))
|
|
247
|
|
- // 可以清理该会话的所有消息状态
|
|
248
|
|
- ma.cleanupSession(sessionID)
|
|
249
|
|
- }
|
|
250
|
|
-}
|
|
251
|
|
-
|
|
252
|
|
-// handleSessionIdle 处理 session.idle 事件
|
|
253
|
|
-func (ma *MessageAggregator) handleSessionIdle(sessionID string) {
|
|
254
|
|
- logger.Info(fmt.Sprintf("会话空闲事件 sessionID=%s", sessionID))
|
|
255
|
|
- ma.cleanupSession(sessionID)
|
|
256
|
|
-}
|
|
257
|
|
-
|
|
258
|
|
-// cleanupSession 清理指定会话的所有消息状态
|
|
259
|
|
-func (ma *MessageAggregator) cleanupSession(sessionID string) {
|
|
260
|
|
- ma.mu.Lock()
|
|
261
|
|
- defer ma.mu.Unlock()
|
|
262
|
|
-
|
|
263
|
|
- keysToDelete := make([]string, 0)
|
|
264
|
|
- for key, state := range ma.messages {
|
|
265
|
|
- if state.SessionID == sessionID {
|
|
266
|
|
- keysToDelete = append(keysToDelete, key)
|
|
267
|
|
-
|
|
268
|
|
- // 检查是否有未触发的类型内容,强制触发完成钩子
|
|
269
|
|
- hasUnfinishedContent := false
|
|
270
|
|
-
|
|
271
|
|
- // 检查每种类型是否有内容但钩子未触发
|
|
272
|
|
- if state.ReasoningBuffer.Len() > 0 {
|
|
273
|
|
- if triggered, exists := state.HookTriggered["reasoning"]; !exists || !triggered {
|
|
274
|
|
- hasUnfinishedContent = true
|
|
275
|
|
- }
|
|
276
|
|
- }
|
|
277
|
|
- if state.ReplyBuffer.Len() > 0 {
|
|
278
|
|
- if triggered, exists := state.HookTriggered["text"]; !exists || !triggered {
|
|
279
|
|
- hasUnfinishedContent = true
|
|
280
|
|
- }
|
|
281
|
|
- }
|
|
282
|
|
- if state.ToolBuffer.Len() > 0 {
|
|
283
|
|
- if triggered, exists := state.HookTriggered["tool"]; !exists || !triggered {
|
|
284
|
|
- hasUnfinishedContent = true
|
|
285
|
|
- }
|
|
286
|
|
- }
|
|
287
|
|
-
|
|
288
|
|
- if hasUnfinishedContent {
|
|
289
|
|
- logger.Warn(fmt.Sprintf("强制完成未完成消息 sessionID=%s messageID=%s",
|
|
290
|
|
- sessionID, state.MessageID))
|
|
291
|
|
- ma.triggerCompletionHooks(state)
|
|
292
|
|
- }
|
|
293
|
|
- }
|
|
294
|
|
- }
|
|
295
|
|
-
|
|
296
|
|
- for _, key := range keysToDelete {
|
|
297
|
|
- delete(ma.messages, key)
|
|
298
|
|
- }
|
|
299
|
|
-
|
|
300
|
|
- if len(keysToDelete) > 0 {
|
|
301
|
|
- logger.Info(fmt.Sprintf("清理会话消息状态 sessionID=%s cleanedCount=%d",
|
|
302
|
|
- sessionID, len(keysToDelete)))
|
|
303
|
|
- }
|
|
304
|
|
-}
|
|
305
|
|
-
|
|
306
|
|
-// triggerTypeCompletionHooks 触发特定类型的完成钩子
|
|
307
|
|
-func (ma *MessageAggregator) triggerTypeCompletionHooks(state *MessageState, partType string) {
|
|
308
|
|
- // 避免重复触发
|
|
309
|
|
- if triggered, exists := state.HookTriggered[partType]; exists && triggered {
|
|
310
|
|
- logger.Debug(fmt.Sprintf("钩子已触发过,跳过 sessionID=%s messageID=%s type=%s",
|
|
311
|
|
- state.SessionID, state.MessageID, partType))
|
|
312
|
|
- return
|
|
313
|
|
- }
|
|
314
|
|
-
|
|
315
|
|
- // 获取该类型的缓冲区内容
|
|
316
|
|
- var completeText string
|
|
317
|
|
- var textLength int
|
|
318
|
|
-
|
|
319
|
|
- switch partType {
|
|
320
|
|
- case "reasoning":
|
|
321
|
|
- completeText = state.ReasoningBuffer.String()
|
|
322
|
|
- textLength = len(completeText)
|
|
323
|
|
- case "text":
|
|
324
|
|
- completeText = state.ReplyBuffer.String()
|
|
325
|
|
- textLength = len(completeText)
|
|
326
|
|
- case "tool":
|
|
327
|
|
- completeText = state.ToolBuffer.String()
|
|
328
|
|
- textLength = len(completeText)
|
|
329
|
|
- default:
|
|
330
|
|
- logger.Warn(fmt.Sprintf("未知类型,跳过钩子触发 sessionID=%s messageID=%s type=%s",
|
|
331
|
|
- state.SessionID, state.MessageID, partType))
|
|
332
|
|
- return
|
|
333
|
|
- }
|
|
334
|
|
-
|
|
335
|
|
- duration := time.Since(state.StartTime)
|
|
336
|
|
-
|
|
337
|
|
- // 记录完成事件
|
|
338
|
|
- logger.Info(fmt.Sprintf("🔔 类型完成总结 sessionID=%s messageID=%s type=%s textLength=%d duration=%v",
|
|
339
|
|
- state.SessionID, state.MessageID, partType, textLength, duration))
|
|
340
|
|
-
|
|
341
|
|
- if textLength > 0 {
|
|
342
|
|
- // 记录文本预览(前200字符)
|
|
343
|
|
- preview := completeText
|
|
344
|
|
- if len(preview) > 200 {
|
|
345
|
|
- preview = preview[:200] + "..."
|
|
346
|
|
- }
|
|
347
|
|
- logger.Info(fmt.Sprintf("📝 类型文本预览: %s", preview))
|
|
348
|
|
- } else {
|
|
349
|
|
- logger.Info("📭 空文本完成事件")
|
|
350
|
|
- }
|
|
351
|
|
-
|
|
352
|
|
- // 转换事件类型为消息类型枚举
|
|
353
|
|
- messageType := convertToMessageType(partType)
|
|
354
|
|
-
|
|
355
|
|
- // 调用消息完成回调函数(如果设置)- 使用新的类型化接口
|
|
356
|
|
- // 注意:这里调用现有的OnMessageCompleteFunc,传入特定类型的内容
|
|
357
|
|
- if ma.OnMessageCompleteFunc != nil {
|
|
358
|
|
- ma.OnMessageCompleteFunc(state.SessionID, state.MessageID, completeText, messageType, state.Metadata)
|
|
359
|
|
- }
|
|
360
|
|
-
|
|
361
|
|
- // 调用所有注册的钩子(保持向后兼容)
|
|
362
|
|
- for _, hook := range ma.hooks {
|
|
363
|
|
- hook.OnMessageComplete(state.SessionID, state.MessageID, completeText, partType, state.Metadata)
|
|
364
|
|
- }
|
|
365
|
|
-
|
|
366
|
|
- // 标记该类型的钩子已触发
|
|
367
|
|
- state.HookTriggered[partType] = true
|
|
368
|
|
- logger.Debug(fmt.Sprintf("✅ 类型钩子标记为已触发 sessionID=%s messageID=%s type=%s",
|
|
369
|
|
- state.SessionID, state.MessageID, partType))
|
|
370
|
|
-}
|
|
371
|
|
-
|
|
372
|
|
-// triggerCompletionHooks 触发完成钩子(向后兼容,触发所有类型)
|
|
373
|
|
-func (ma *MessageAggregator) triggerCompletionHooks(state *MessageState) {
|
|
374
|
|
- // 遍历所有支持的类型,触发各自的完成钩子
|
|
375
|
|
- types := []string{"reasoning", "text", "tool"}
|
|
376
|
|
- hasAnyContent := false
|
|
377
|
|
-
|
|
378
|
|
- for _, partType := range types {
|
|
379
|
|
- // 检查该类型是否有内容
|
|
380
|
|
- var hasContent bool
|
|
381
|
|
- switch partType {
|
|
382
|
|
- case "reasoning":
|
|
383
|
|
- hasContent = state.ReasoningBuffer.Len() > 0
|
|
384
|
|
- case "text":
|
|
385
|
|
- hasContent = state.ReplyBuffer.Len() > 0
|
|
386
|
|
- case "tool":
|
|
387
|
|
- hasContent = state.ToolBuffer.Len() > 0
|
|
388
|
|
- }
|
|
389
|
|
-
|
|
390
|
|
- if hasContent {
|
|
391
|
|
- hasAnyContent = true
|
|
392
|
|
- // 触发该类型的完成钩子(函数内部会检查是否已触发)
|
|
393
|
|
- ma.triggerTypeCompletionHooks(state, partType)
|
|
394
|
|
- }
|
|
395
|
|
- }
|
|
396
|
|
-
|
|
397
|
|
- // 如果没有内容,至少记录一个事件(保持向后兼容)
|
|
398
|
|
- if !hasAnyContent {
|
|
399
|
|
- logger.Info(fmt.Sprintf("📭 空消息完成事件 sessionID=%s messageID=%s",
|
|
400
|
|
- state.SessionID, state.MessageID))
|
|
401
|
|
- }
|
|
402
|
|
-}
|
|
403
|
|
-
|
|
404
|
17
|
// EventDispatcher 事件分发器 - 单例模式
|
|
405
|
18
|
type EventDispatcher struct {
|
|
406
|
|
- mu sync.RWMutex
|
|
407
|
|
- baseURL string
|
|
408
|
|
- port int
|
|
409
|
|
- subscriptions map[string]map[chan string]struct{} // sessionID -> set of channels
|
|
410
|
|
- sessionUserCache *SessionUserCache // sessionID -> userID 映射缓存(用于用户验证)
|
|
411
|
|
- client *http.Client
|
|
412
|
|
- cancelFunc context.CancelFunc
|
|
413
|
|
- running bool
|
|
414
|
|
- messageAggregator *MessageAggregator // 消息聚合器
|
|
415
|
|
-}
|
|
416
|
|
-
|
|
417
|
|
-// EventData opencode事件数据结构 - 匹配实际事件格式
|
|
418
|
|
-type EventData struct {
|
|
419
|
|
- Directory string `json:"directory,omitempty"`
|
|
420
|
|
- Payload map[string]interface{} `json:"payload"`
|
|
421
|
|
-}
|
|
422
|
|
-
|
|
423
|
|
-// PayloadData payload内部结构(辅助类型)
|
|
424
|
|
-type PayloadData struct {
|
|
425
|
|
- Type string `json:"type"`
|
|
426
|
|
- Properties map[string]interface{} `json:"properties,omitempty"`
|
|
|
19
|
+ mu sync.RWMutex
|
|
|
20
|
+ baseURL string
|
|
|
21
|
+ port int
|
|
|
22
|
+ subscriptions map[string]map[chan string]struct{} // sessionID -> set of channels
|
|
|
23
|
+ client *http.Client
|
|
|
24
|
+ cancelFunc context.CancelFunc
|
|
|
25
|
+ running bool
|
|
427
|
26
|
}
|
|
428
|
27
|
|
|
429
|
28
|
// NewEventDispatcher 创建新的事件分发器
|
|
430
|
29
|
func NewEventDispatcher(baseURL string, port int) *EventDispatcher {
|
|
431
|
30
|
return &EventDispatcher{
|
|
432
|
|
- baseURL: baseURL,
|
|
433
|
|
- port: port,
|
|
434
|
|
- subscriptions: make(map[string]map[chan string]struct{}),
|
|
435
|
|
- sessionUserCache: NewSessionUserCache(20 * time.Minute),
|
|
436
|
|
- messageAggregator: NewMessageAggregator(),
|
|
|
31
|
+ baseURL: baseURL,
|
|
|
32
|
+ port: port,
|
|
|
33
|
+ subscriptions: make(map[string]map[chan string]struct{}),
|
|
437
|
34
|
client: &http.Client{
|
|
438
|
35
|
Timeout: 0, // 无超时限制,用于长连接
|
|
439
|
36
|
},
|
|
|
@@ -494,9 +91,6 @@ func (ed *EventDispatcher) Subscribe(sessionID, userID string) (<-chan string, e
|
|
494
|
91
|
ed.mu.Lock()
|
|
495
|
92
|
defer ed.mu.Unlock()
|
|
496
|
93
|
|
|
497
|
|
- // 缓存会话-用户映射(用于未来需要用户验证时)
|
|
498
|
|
- ed.sessionUserCache.Set(sessionID, userID)
|
|
499
|
|
-
|
|
500
|
94
|
// 创建缓冲通道
|
|
501
|
95
|
ch := make(chan string, 100)
|
|
502
|
96
|
|
|
|
@@ -538,18 +132,10 @@ func (ed *EventDispatcher) Unsubscribe(sessionID string, ch <-chan string) {
|
|
538
|
132
|
// 如果没有订阅者了,清理该会话的映射
|
|
539
|
133
|
if len(channels) == 0 {
|
|
540
|
134
|
delete(ed.subscriptions, sessionID)
|
|
541
|
|
- ed.sessionUserCache.Delete(sessionID)
|
|
542
|
135
|
}
|
|
543
|
136
|
}
|
|
544
|
137
|
}
|
|
545
|
138
|
|
|
546
|
|
-// RegisterSession 注册会话(前端调用SendPromptStream时调用)
|
|
547
|
|
-func (ed *EventDispatcher) RegisterSession(sessionID, userID string) {
|
|
548
|
|
- ed.sessionUserCache.Set(sessionID, userID)
|
|
549
|
|
- logger.Debug(fmt.Sprintf("会话已注册 sessionID=%s userID=%s",
|
|
550
|
|
- sessionID, userID))
|
|
551
|
|
-}
|
|
552
|
|
-
|
|
553
|
139
|
// buildSSEURL 构建SSE URL,避免端口重复
|
|
554
|
140
|
func (ed *EventDispatcher) buildSSEURL() string {
|
|
555
|
141
|
// 检查baseURL是否已包含端口
|
|
|
@@ -674,10 +260,6 @@ func (ed *EventDispatcher) connectAndProcessSSE(ctx context.Context, url string)
|
|
674
|
260
|
// 分发事件
|
|
675
|
261
|
ed.dispatchEvent(data)
|
|
676
|
262
|
|
|
677
|
|
- if eventCount%100 == 0 {
|
|
678
|
|
- //logger.Debug(fmt.Sprintf("事件处理统计 totalEvents=%d activeSessions=%d",
|
|
679
|
|
- // eventCount, len(ed.subscriptions)))
|
|
680
|
|
- }
|
|
681
|
263
|
}
|
|
682
|
264
|
}
|
|
683
|
265
|
}
|
|
|
@@ -688,32 +270,28 @@ func (ed *EventDispatcher) dispatchEvent(data string) {
|
|
688
|
270
|
// 解析事件数据获取sessionID
|
|
689
|
271
|
sessionID := extractSessionIDFromEvent(data)
|
|
690
|
272
|
|
|
691
|
|
- // 处理事件聚合(无论是否有sessionID都处理)
|
|
692
|
|
- if ed.messageAggregator != nil && sessionID != "" {
|
|
693
|
|
- ed.messageAggregator.ProcessEvent(data, sessionID)
|
|
694
|
|
- }
|
|
695
|
|
-
|
|
696
|
273
|
if sessionID == "" {
|
|
697
|
|
- // 没有sessionID的事件(如全局心跳)分发给所有订阅者
|
|
698
|
|
- //logger.Debug(fmt.Sprintf("广播全局事件 dataPreview=%s", safeSubstring(data, 0, 100)))
|
|
699
|
|
- ed.broadcastToAll(data)
|
|
|
274
|
+ // 没有sessionID的事件(如全局心跳)丢弃,不广播给所有订阅者
|
|
|
275
|
+ // 确保按sessionID严格隔离,避免多用户消息交叉
|
|
700
|
276
|
return
|
|
701
|
277
|
}
|
|
702
|
278
|
|
|
703
|
|
- // 只记录非增量文本事件的路由日志,减少日志量
|
|
704
|
|
- shouldLog := true
|
|
|
279
|
+ // 只记录关键事件的路由日志,减少日志输出
|
|
705
|
280
|
var eventMap map[string]interface{}
|
|
706
|
281
|
if err := json.Unmarshal([]byte(data), &eventMap); err == nil {
|
|
707
|
|
- eventType := getEventType(eventMap)
|
|
708
|
|
- // message.part.updated 是最频繁的事件,跳过其路由日志
|
|
709
|
|
- if eventType == "message.part.updated" {
|
|
710
|
|
- shouldLog = false
|
|
|
282
|
+ // 提取事件类型
|
|
|
283
|
+ var eventType string
|
|
|
284
|
+ if payload, ok := eventMap["payload"].(map[string]interface{}); ok {
|
|
|
285
|
+ if t, ok := payload["type"].(string); ok {
|
|
|
286
|
+ eventType = t
|
|
|
287
|
+ }
|
|
|
288
|
+ }
|
|
|
289
|
+ // 只记录关键事件类型的路由信息
|
|
|
290
|
+ switch eventType {
|
|
|
291
|
+ case "session.status", "message.updated", "session.diff", "session.idle":
|
|
|
292
|
+ logger.Debug(fmt.Sprintf("路由事件到会话 sessionID=%s type=%s",
|
|
|
293
|
+ sessionID, eventType))
|
|
711
|
294
|
}
|
|
712
|
|
- }
|
|
713
|
|
-
|
|
714
|
|
- if shouldLog {
|
|
715
|
|
- logger.Debug(fmt.Sprintf("路由事件到会话 sessionID=%s dataPreview=%s",
|
|
716
|
|
- sessionID, safeSubstring(data, 0, 100)))
|
|
717
|
295
|
}
|
|
718
|
296
|
|
|
719
|
297
|
// 只分发给订阅该会话的通道
|
|
|
@@ -723,7 +301,6 @@ func (ed *EventDispatcher) dispatchEvent(data string) {
|
|
723
|
301
|
|
|
724
|
302
|
if !exists {
|
|
725
|
303
|
// 没有该会话的订阅者,忽略事件
|
|
726
|
|
- //logger.Debug(fmt.Sprintf("忽略事件,无订阅者 sessionID=%s", sessionID))
|
|
727
|
304
|
return
|
|
728
|
305
|
}
|
|
729
|
306
|
|
|
|
@@ -742,52 +319,18 @@ func (ed *EventDispatcher) dispatchEvent(data string) {
|
|
742
|
319
|
ed.mu.RUnlock()
|
|
743
|
320
|
}
|
|
744
|
321
|
|
|
745
|
|
-// broadcastToAll 广播事件给所有订阅者(用于全局事件如心跳)
|
|
746
|
|
-func (ed *EventDispatcher) broadcastToAll(data string) {
|
|
747
|
|
- //logger.Debug(fmt.Sprintf("广播事件给所有订阅者 dataPreview=%s", safeSubstring(data, 0, 100)))
|
|
748
|
|
- ed.mu.RLock()
|
|
749
|
|
- defer ed.mu.RUnlock()
|
|
750
|
|
-
|
|
751
|
|
- for sessionID, channels := range ed.subscriptions {
|
|
752
|
|
- for ch := range channels {
|
|
753
|
|
- select {
|
|
754
|
|
- case ch <- data:
|
|
755
|
|
- // 成功发送
|
|
756
|
|
- default:
|
|
757
|
|
- // 通道已满,丢弃事件
|
|
758
|
|
- logger.Warn(fmt.Sprintf("事件通道已满,丢弃全局事件 sessionID=%s",
|
|
759
|
|
- sessionID))
|
|
760
|
|
- }
|
|
761
|
|
- }
|
|
762
|
|
- }
|
|
763
|
|
-}
|
|
764
|
|
-
|
|
765
|
322
|
// extractSessionIDFromEvent 从事件数据中提取sessionID
|
|
766
|
323
|
func extractSessionIDFromEvent(data string) string {
|
|
767
|
324
|
// 尝试解析为JSON
|
|
768
|
325
|
var eventMap map[string]interface{}
|
|
769
|
326
|
if err := json.Unmarshal([]byte(data), &eventMap); err != nil {
|
|
770
|
|
- logger.ErrorC(fmt.Sprintf("无法解析事件JSON error=%s dataPreview=%s",
|
|
771
|
|
- err.Error(), safeSubstring(data, 0, 200)))
|
|
|
327
|
+ logger.Error("无法解析事件JSON", "error", err.Error(), "dataPreview", safeSubstring(data, 0, 200))
|
|
772
|
328
|
return ""
|
|
773
|
329
|
}
|
|
774
|
330
|
|
|
775
|
|
- // 添加调试日志,显示完整事件结构(仅调试时启用)
|
|
776
|
|
- debugMode := true
|
|
777
|
|
- if debugMode {
|
|
778
|
|
- eventJSON, _ := json.MarshalIndent(eventMap, "", " ")
|
|
779
|
|
- logger.Debug(fmt.Sprintf("事件数据结构 eventStructure=%s",
|
|
780
|
|
- string(eventJSON)))
|
|
781
|
|
- }
|
|
782
|
|
-
|
|
783
|
331
|
// 递归查找sessionID字段
|
|
784
|
332
|
sessionID := findSessionIDRecursive(eventMap)
|
|
785
|
333
|
|
|
786
|
|
- if sessionID == "" {
|
|
787
|
|
- //logger.Debug(fmt.Sprintf("未找到sessionID字段 eventType=%s dataPreview=%s",
|
|
788
|
|
- // getEventType(eventMap), safeSubstring(data, 0, 100)))
|
|
789
|
|
- }
|
|
790
|
|
-
|
|
791
|
334
|
return sessionID
|
|
792
|
335
|
}
|
|
793
|
336
|
|
|
|
@@ -855,37 +398,6 @@ func findSessionIDRecursive(data interface{}) string {
|
|
855
|
398
|
return ""
|
|
856
|
399
|
}
|
|
857
|
400
|
|
|
858
|
|
-// getEventType 获取事件类型
|
|
859
|
|
-func getEventType(eventMap map[string]interface{}) string {
|
|
860
|
|
- if payload, ok := eventMap["payload"].(map[string]interface{}); ok {
|
|
861
|
|
- if eventType, ok := payload["type"].(string); ok {
|
|
862
|
|
- return eventType
|
|
863
|
|
- }
|
|
864
|
|
- }
|
|
865
|
|
- return "unknown"
|
|
866
|
|
-}
|
|
867
|
|
-
|
|
868
|
|
-// convertToMessageType 将事件类型转换为消息类型枚举
|
|
869
|
|
-func convertToMessageType(eventType string) MessageType {
|
|
870
|
|
- switch eventType {
|
|
871
|
|
- case "reasoning":
|
|
872
|
|
- return MessageTypeThinking
|
|
873
|
|
- case "text":
|
|
874
|
|
- return MessageTypeReply
|
|
875
|
|
- case "tool":
|
|
876
|
|
- return MessageTypeTool
|
|
877
|
|
- default:
|
|
878
|
|
- // 尝试识别其他类型
|
|
879
|
|
- if strings.Contains(eventType, "reasoning") || strings.Contains(eventType, "thinking") {
|
|
880
|
|
- return MessageTypeThinking
|
|
881
|
|
- }
|
|
882
|
|
- if strings.Contains(eventType, "tool") || strings.Contains(eventType, "function") {
|
|
883
|
|
- return MessageTypeTool
|
|
884
|
|
- }
|
|
885
|
|
- return MessageTypeUnknown
|
|
886
|
|
- }
|
|
887
|
|
-}
|
|
888
|
|
-
|
|
889
|
401
|
// safeSubstring 安全的子字符串函数
|
|
890
|
402
|
func safeSubstring(s string, start, length int) string {
|
|
891
|
403
|
if start < 0 {
|
|
|
@@ -901,40 +413,6 @@ func safeSubstring(s string, start, length int) string {
|
|
901
|
413
|
return s[start:end]
|
|
902
|
414
|
}
|
|
903
|
415
|
|
|
904
|
|
-// getHookTriggerSource 获取钩子触发源(用于调试)
|
|
905
|
|
-func getHookTriggerSource(state *MessageState) string {
|
|
906
|
|
- // 检查是否有任何类型的钩子已触发
|
|
907
|
|
- for _, triggered := range state.HookTriggered {
|
|
908
|
|
- if triggered {
|
|
909
|
|
- return "completed"
|
|
910
|
|
- }
|
|
911
|
|
- }
|
|
912
|
|
- return "unknown"
|
|
913
|
|
-}
|
|
914
|
|
-
|
|
915
|
|
-// RegisterHook 注册完成钩子
|
|
916
|
|
-func (ed *EventDispatcher) RegisterHook(hook CompletionHook) {
|
|
917
|
|
- if ed.messageAggregator != nil {
|
|
918
|
|
- ed.messageAggregator.RegisterHook(hook)
|
|
919
|
|
- }
|
|
920
|
|
-}
|
|
921
|
|
-
|
|
922
|
|
-// SetOnMessageCompleteFunc 设置消息完成回调函数
|
|
923
|
|
-func (ed *EventDispatcher) SetOnMessageCompleteFunc(f func(sessionID string, messageID string, completeText string, messageType MessageType, metadata map[string]interface{})) {
|
|
924
|
|
- if ed.messageAggregator != nil {
|
|
925
|
|
- ed.messageAggregator.OnMessageCompleteFunc = f
|
|
926
|
|
- logger.Info("✅ 消息完成回调函数已设置")
|
|
927
|
|
- }
|
|
928
|
|
-}
|
|
929
|
|
-
|
|
930
|
|
-// SetOnEventProcessedFunc 设置事件处理回调函数
|
|
931
|
|
-func (ed *EventDispatcher) SetOnEventProcessedFunc(f func(sessionID string, eventType string, eventData string, eventMap map[string]interface{})) {
|
|
932
|
|
- if ed.messageAggregator != nil {
|
|
933
|
|
- ed.messageAggregator.OnEventProcessedFunc = f
|
|
934
|
|
- logger.Info("✅ 事件处理回调函数已设置")
|
|
935
|
|
- }
|
|
936
|
|
-}
|
|
937
|
|
-
|
|
938
|
416
|
// GetInstance 获取单例实例(线程安全)
|
|
939
|
417
|
var (
|
|
940
|
418
|
instance *EventDispatcher
|
|
|
@@ -948,30 +426,3 @@ func GetEventDispatcher(baseURL string, port int) *EventDispatcher {
|
|
948
|
426
|
})
|
|
949
|
427
|
return instance
|
|
950
|
428
|
}
|
|
951
|
|
-
|
|
952
|
|
-// DiagnosticHook 诊断钩子实现(用于调试和测试)
|
|
953
|
|
-type DiagnosticHook struct{}
|
|
954
|
|
-
|
|
955
|
|
-func (h *DiagnosticHook) OnMessageComplete(sessionID string, messageID string, completeText string, eventType string, metadata map[string]interface{}) {
|
|
956
|
|
- logger.Info(fmt.Sprintf("🔍 诊断钩子触发: session=%s message=%s type=%s textLength=%d",
|
|
957
|
|
- sessionID, messageID, eventType, len(completeText)))
|
|
958
|
|
-
|
|
959
|
|
- if len(completeText) > 0 {
|
|
960
|
|
- preview := completeText
|
|
961
|
|
- if len(preview) > 150 {
|
|
962
|
|
- preview = preview[:150] + "..."
|
|
963
|
|
- }
|
|
964
|
|
- logger.Info(fmt.Sprintf("📋 诊断钩子文本预览: %s", preview))
|
|
965
|
|
- } else {
|
|
966
|
|
- logger.Info("📭 诊断钩子: 空文本")
|
|
967
|
|
- }
|
|
968
|
|
-
|
|
969
|
|
- // 记录元数据(如果有)
|
|
970
|
|
- if metadata != nil && len(metadata) > 0 {
|
|
971
|
|
- logger.Info(fmt.Sprintf("📊 诊断钩子元数据: %+v", metadata))
|
|
972
|
|
- }
|
|
973
|
|
-}
|
|
974
|
|
-
|
|
975
|
|
-// 使用示例:
|
|
976
|
|
-// dispatcher := GetEventDispatcher("http://localhost", 3000)
|
|
977
|
|
-// dispatcher.RegisterHook(&DiagnosticHook{})
|