Açıklama Yok
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

mock_client.go 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package opencode
  2. import (
  3. "context"
  4. "fmt"
  5. "sync"
  6. )
  7. // 确保 MockClient 实现 OpenCodeClient 接口
  8. var _ OpenCodeClient = (*MockClient)(nil)
  9. // MockClient opencode 客户端的模拟实现
  10. type MockClient struct {
  11. mu sync.RWMutex
  12. sessions map[string]*Session
  13. nextSessionID int
  14. port int
  15. baseURL string
  16. shouldFail bool
  17. failMessage string
  18. }
  19. // NewMockClient 创建新的模拟客户端
  20. func NewMockClient(port int) *MockClient {
  21. return &MockClient{
  22. sessions: make(map[string]*Session),
  23. nextSessionID: 1,
  24. port: port,
  25. baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
  26. shouldFail: false,
  27. }
  28. }
  29. // SetShouldFail 设置客户端是否应该失败
  30. func (m *MockClient) SetShouldFail(shouldFail bool, message string) {
  31. m.mu.Lock()
  32. defer m.mu.Unlock()
  33. m.shouldFail = shouldFail
  34. m.failMessage = message
  35. }
  36. // CreateSession 模拟创建会话
  37. func (m *MockClient) CreateSession(ctx context.Context, title string) (*Session, error) {
  38. m.mu.Lock()
  39. defer m.mu.Unlock()
  40. if m.shouldFail {
  41. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  42. }
  43. sessionID := fmt.Sprintf("ses-mock-%d", m.nextSessionID)
  44. m.nextSessionID++
  45. session := &Session{
  46. ID: sessionID,
  47. Title: title,
  48. CreatedAt: "2024-01-01T00:00:00Z",
  49. }
  50. m.sessions[sessionID] = session
  51. return session, nil
  52. }
  53. // SendPrompt 模拟发送消息(同步)
  54. func (m *MockClient) SendPrompt(ctx context.Context, sessionID string, prompt *PromptRequest) (*PromptResponse, error) {
  55. m.mu.RLock()
  56. defer m.mu.RUnlock()
  57. if m.shouldFail {
  58. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  59. }
  60. if _, exists := m.sessions[sessionID]; !exists {
  61. return nil, fmt.Errorf("会话不存在: %s", sessionID)
  62. }
  63. response := &PromptResponse{
  64. Info: AssistantMessage{
  65. ID: "msg-mock-001",
  66. Role: "assistant",
  67. SessionID: sessionID,
  68. Content: "这是模拟的响应:" + prompt.Parts[0].Text,
  69. Agent: "opencode",
  70. ModelID: "mock-model",
  71. ProviderID: "mock-provider",
  72. Tokens: TokenInfo{
  73. Input: 10,
  74. Output: 20,
  75. },
  76. Time: map[string]interface{}{
  77. "started": "2024-01-01T00:00:00Z",
  78. "finished": "2024-01-01T00:00:01Z",
  79. },
  80. },
  81. Parts: []interface{}{
  82. map[string]interface{}{
  83. "type": "text",
  84. "text": "这是模拟的响应:" + prompt.Parts[0].Text,
  85. },
  86. },
  87. }
  88. return response, nil
  89. }
  90. // SendPromptStream 模拟发送消息(流式)
  91. func (m *MockClient) SendPromptStream(ctx context.Context, sessionID string, prompt *PromptRequest) (<-chan string, error) {
  92. m.mu.RLock()
  93. defer m.mu.RUnlock()
  94. if m.shouldFail {
  95. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  96. }
  97. if _, exists := m.sessions[sessionID]; !exists {
  98. return nil, fmt.Errorf("会话不存在: %s", sessionID)
  99. }
  100. ch := make(chan string, 10)
  101. // 模拟流式响应
  102. go func() {
  103. defer close(ch)
  104. messages := []string{
  105. `{"type": "text", "text": "这是"}`,
  106. `{"type": "text", "text": "模拟的"}`,
  107. `{"type": "text", "text": "流式响应。"}`,
  108. `{"info": {"id": "msg-mock-stream-001", "role": "assistant"}}`,
  109. }
  110. for _, msg := range messages {
  111. select {
  112. case <-ctx.Done():
  113. return
  114. case ch <- msg:
  115. }
  116. }
  117. }()
  118. return ch, nil
  119. }
  120. // GetSession 模拟获取会话
  121. func (m *MockClient) GetSession(ctx context.Context, sessionID string) (*Session, error) {
  122. m.mu.RLock()
  123. defer m.mu.RUnlock()
  124. if m.shouldFail {
  125. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  126. }
  127. session, exists := m.sessions[sessionID]
  128. if !exists {
  129. return nil, fmt.Errorf("会话不存在: %s", sessionID)
  130. }
  131. return session, nil
  132. }
  133. // ListSessions 模拟获取会话列表
  134. func (m *MockClient) ListSessions(ctx context.Context) ([]Session, error) {
  135. m.mu.RLock()
  136. defer m.mu.RUnlock()
  137. if m.shouldFail {
  138. return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
  139. }
  140. sessions := make([]Session, 0, len(m.sessions))
  141. for _, session := range m.sessions {
  142. sessions = append(sessions, *session)
  143. }
  144. return sessions, nil
  145. }
  146. // GetBaseURL 获取基础URL
  147. func (m *MockClient) GetBaseURL() string {
  148. return m.baseURL
  149. }
  150. // GetPort 获取端口
  151. func (m *MockClient) GetPort() int {
  152. return m.port
  153. }