暂无描述
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

mock_client.go 6.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. package opencode
  import (
	"context"
	"fmt"
	"sort"
	"sync"
)
  7. // 确保 MockClient 实现 OpenCodeClient 接口
  8. var _ OpenCodeClient = (*MockClient)(nil)
  9. // MockClient opencode 客户端的模拟实现
  10. type MockClient struct {
  11. mu sync.RWMutex
  12. sessions map[string]*Session
  13. nextSessionID int
  14. port int
  15. baseURL string
  16. shouldFail bool
  17. failMessage string
  18. }
  19. // NewMockClient 创建新的模拟客户端
  20. func NewMockClient(port int) *MockClient {
  21. return &MockClient{
  22. sessions: make(map[string]*Session),
  23. nextSessionID: 1,
  24. port: port,
  25. baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
  26. shouldFail: false,
  27. }
  28. }
  29. // SetShouldFail 设置客户端是否应该失败
  30. func (m *MockClient) SetShouldFail(shouldFail bool, message string) {
  31. m.mu.Lock()
  32. defer m.mu.Unlock()
  33. m.shouldFail = shouldFail
  34. m.failMessage = message
  35. }
  36. // CreateSession 模拟创建会话
  37. func (m *MockClient) CreateSession(ctx context.Context, title string) (*Session, error) {
  38. m.mu.Lock()
  39. defer m.mu.Unlock()
  40. if m.shouldFail {
  41. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  42. }
  43. sessionID := fmt.Sprintf("ses-mock-%d", m.nextSessionID)
  44. m.nextSessionID++
  45. session := &Session{
  46. ID: sessionID,
  47. Title: title,
  48. CreatedAt: "2024-01-01T00:00:00Z",
  49. }
  50. m.sessions[sessionID] = session
  51. return session, nil
  52. }
  53. // SendPrompt 模拟发送消息(同步)
  54. func (m *MockClient) SendPrompt(ctx context.Context, sessionID string, prompt *PromptRequest) (*PromptResponse, error) {
  55. m.mu.RLock()
  56. defer m.mu.RUnlock()
  57. if m.shouldFail {
  58. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  59. }
  60. if _, exists := m.sessions[sessionID]; !exists {
  61. return nil, fmt.Errorf("会话不存在: %s", sessionID)
  62. }
  63. response := &PromptResponse{
  64. Info: AssistantMessage{
  65. ID: "msg-mock-001",
  66. Role: "assistant",
  67. SessionID: sessionID,
  68. Content: "这是模拟的响应:" + prompt.Parts[0].Text,
  69. Agent: "opencode",
  70. ModelID: "mock-model",
  71. ProviderID: "mock-provider",
  72. Tokens: TokenInfo{
  73. Input: 10,
  74. Output: 20,
  75. },
  76. Time: map[string]interface{}{
  77. "started": "2024-01-01T00:00:00Z",
  78. "finished": "2024-01-01T00:00:01Z",
  79. },
  80. },
  81. Parts: []interface{}{
  82. map[string]interface{}{
  83. "type": "text",
  84. "text": "这是模拟的响应:" + prompt.Parts[0].Text,
  85. },
  86. },
  87. }
  88. return response, nil
  89. }
  90. // SendPromptStream 模拟发送消息(流式)
  91. func (m *MockClient) SendPromptStream(ctx context.Context, sessionID string, prompt *PromptRequest) (<-chan string, error) {
  92. m.mu.RLock()
  93. defer m.mu.RUnlock()
  94. if m.shouldFail {
  95. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  96. }
  97. if _, exists := m.sessions[sessionID]; !exists {
  98. return nil, fmt.Errorf("会话不存在: %s", sessionID)
  99. }
  100. ch := make(chan string, 10)
  101. // 模拟流式响应
  102. go func() {
  103. defer close(ch)
  104. messages := []string{
  105. `{"type": "text", "text": "这是"}`,
  106. `{"type": "text", "text": "模拟的"}`,
  107. `{"type": "text", "text": "流式响应。"}`,
  108. `{"info": {"id": "msg-mock-stream-001", "role": "assistant"}}`,
  109. }
  110. for _, msg := range messages {
  111. select {
  112. case <-ctx.Done():
  113. return
  114. case ch <- msg:
  115. }
  116. }
  117. }()
  118. return ch, nil
  119. }
  120. // GetSession 模拟获取会话
  121. func (m *MockClient) GetSession(ctx context.Context, sessionID string) (*Session, error) {
  122. m.mu.RLock()
  123. defer m.mu.RUnlock()
  124. if m.shouldFail {
  125. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  126. }
  127. session, exists := m.sessions[sessionID]
  128. if !exists {
  129. return nil, fmt.Errorf("会话不存在: %s", sessionID)
  130. }
  131. return session, nil
  132. }
  133. // ListSessions 模拟获取会话列表
  134. func (m *MockClient) ListSessions(ctx context.Context) ([]Session, error) {
  135. m.mu.RLock()
  136. defer m.mu.RUnlock()
  137. if m.shouldFail {
  138. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  139. }
  140. sessions := make([]Session, 0, len(m.sessions))
  141. for _, session := range m.sessions {
  142. sessions = append(sessions, *session)
  143. }
  144. return sessions, nil
  145. }
  146. // GetBaseURL 获取基础URL
  147. func (m *MockClient) GetBaseURL() string {
  148. return m.baseURL
  149. }
  150. // GetSessionMessages 模拟获取会话消息
  151. func (m *MockClient) GetSessionMessages(ctx context.Context, sessionID string, limit int) ([]SessionMessage, error) {
  152. m.mu.RLock()
  153. defer m.mu.RUnlock()
  154. if m.shouldFail {
  155. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  156. }
  157. if _, exists := m.sessions[sessionID]; !exists {
  158. return nil, fmt.Errorf("会话不存在: %s", sessionID)
  159. }
  160. // 创建模拟消息
  161. messages := []SessionMessage{
  162. {
  163. Info: map[string]interface{}{
  164. "id": "msg-mock-001",
  165. "sessionID": sessionID,
  166. "role": "user",
  167. "time": map[string]interface{}{
  168. "created": float64(1770797040043),
  169. },
  170. "agent": "code-sql",
  171. "model": map[string]interface{}{
  172. "providerID": "deepseek",
  173. "modelID": "deepseek-reasoner",
  174. },
  175. },
  176. Parts: []map[string]interface{}{
  177. {
  178. "id": "prt-mock-001",
  179. "type": "text",
  180. "text": "编写查询昨天销售总额的sql代码。简单,用户没有提到的,保持最简单的方式返回内容",
  181. },
  182. },
  183. },
  184. {
  185. Info: map[string]interface{}{
  186. "id": "msg-mock-002",
  187. "sessionID": sessionID,
  188. "role": "assistant",
  189. "time": map[string]interface{}{
  190. "created": float64(1770797040072),
  191. "completed": float64(1770797045880),
  192. },
  193. "parentID": "msg-mock-001",
  194. "modelID": "deepseek-reasoner",
  195. "providerID": "deepseek",
  196. "mode": "code-sql",
  197. "agent": "code-sql",
  198. "path": map[string]interface{}{
  199. "cwd": "/Users/kenqdy/Documents/v-bdx-workspace",
  200. "root": "/",
  201. },
  202. "cost": 0.000514752,
  203. "tokens": map[string]interface{}{
  204. "input": 55,
  205. "output": 114,
  206. "reasoning": 68,
  207. "cache": map[string]interface{}{
  208. "read": 15104,
  209. "write": 0,
  210. },
  211. },
  212. "finish": "tool-calls",
  213. },
  214. Parts: []map[string]interface{}{
  215. {
  216. "id": "prt-mock-002",
  217. "type": "step-start",
  218. },
  219. {
  220. "id": "prt-mock-003",
  221. "type": "reasoning",
  222. "text": "用户想要查询昨天销售总额的SQL代码...",
  223. },
  224. },
  225. },
  226. }
  227. // 如果有限制,截取消息
  228. if limit > 0 && limit < len(messages) {
  229. messages = messages[:limit]
  230. }
  231. return messages, nil
  232. }
  233. // GetPort 获取端口
  234. func (m *MockClient) GetPort() int {
  235. return m.port
  236. }