package opencode

import (
	"context"
	"fmt"
	"sort"
	"sync"
)

// Compile-time check that MockClient implements the OpenCodeClient interface.
var _ OpenCodeClient = (*MockClient)(nil)

// MockClient is an in-memory mock implementation of OpenCodeClient.
//
// Sessions live in a map guarded by mu and all responses are canned data.
// Calling SetShouldFail(true, msg) makes every session/prompt method return
// an error built from msg, which lets callers exercise error paths.
type MockClient struct {
	mu            sync.RWMutex
	sessions      map[string]*Session
	nextSessionID int
	port          int
	baseURL       string
	shouldFail    bool
	failMessage   string
}

// NewMockClient returns a MockClient with no sessions whose base URL is
// http://127.0.0.1:<port>. The first created session gets ID "ses-mock-1".
func NewMockClient(port int) *MockClient {
	return &MockClient{
		sessions:      make(map[string]*Session),
		nextSessionID: 1,
		port:          port,
		baseURL:       fmt.Sprintf("http://127.0.0.1:%d", port),
		shouldFail:    false,
	}
}

// SetShouldFail toggles failure mode. While shouldFail is true, CreateSession,
// SendPrompt, SendPromptStream, GetSession, ListSessions and
// GetSessionMessages return an error that includes message.
func (m *MockClient) SetShouldFail(shouldFail bool, message string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.shouldFail = shouldFail
	m.failMessage = message
}

// CreateSession stores a new session with the next sequential ID
// ("ses-mock-<n>"), the given title and a fixed CreatedAt timestamp.
// The returned pointer is the same one held in the client's session map.
func (m *MockClient) CreateSession(ctx context.Context, title string) (*Session, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.shouldFail {
		return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
	}

	sessionID := fmt.Sprintf("ses-mock-%d", m.nextSessionID)
	m.nextSessionID++

	session := &Session{
		ID:        sessionID,
		Title:     title,
		CreatedAt: "2024-01-01T00:00:00Z",
	}
	m.sessions[sessionID] = session
	return session, nil
}

// SendPrompt simulates a synchronous prompt. For an existing session it
// returns a canned assistant message whose content echoes the text of the
// prompt's first part. A nil prompt or one with no parts is echoed as an
// empty string. It returns an error if failure mode is on or the session
// does not exist.
func (m *MockClient) SendPrompt(ctx context.Context, sessionID string, prompt *PromptRequest) (*PromptResponse, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.shouldFail {
		return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
	}
	if _, exists := m.sessions[sessionID]; !exists {
		return nil, fmt.Errorf("会话不存在: %s", sessionID)
	}

	// Guard the Parts[0] access: indexing unconditionally panics on a nil
	// request or an empty Parts slice.
	var promptText string
	if prompt != nil && len(prompt.Parts) > 0 {
		promptText = prompt.Parts[0].Text
	}
	content := "这是模拟的响应:" + promptText

	response := &PromptResponse{
		Info: AssistantMessage{
			ID:         "msg-mock-001",
			Role:       "assistant",
			SessionID:  sessionID,
			Content:    content,
			Agent:      "opencode",
			ModelID:    "mock-model",
			ProviderID: "mock-provider",
			Tokens: TokenInfo{
				Input:  10,
				Output: 20,
			},
			Time: map[string]interface{}{
				"started":  "2024-01-01T00:00:00Z",
				"finished": "2024-01-01T00:00:01Z",
			},
		},
		Parts: []interface{}{
			map[string]interface{}{
				"type": "text",
				"text": content,
			},
		},
	}
	return response, nil
}

// SendPromptStream simulates a streaming prompt. For an existing session it
// returns a channel that receives four canned JSON strings (three text
// fragments followed by an info object). The channel is closed after the
// last message, or earlier if ctx is cancelled.
func (m *MockClient) SendPromptStream(ctx context.Context, sessionID string, prompt *PromptRequest) (<-chan string, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.shouldFail {
		return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
	}
	if _, exists := m.sessions[sessionID]; !exists {
		return nil, fmt.Errorf("会话不存在: %s", sessionID)
	}

	// The buffer (10) exceeds the number of canned messages (4), so the
	// producer goroutine never blocks even if nobody reads.
	ch := make(chan string, 10)

	go func() {
		defer close(ch)

		messages := []string{
			`{"type": "text", "text": "这是"}`,
			`{"type": "text", "text": "模拟的"}`,
			`{"type": "text", "text": "流式响应。"}`,
			`{"info": {"id": "msg-mock-stream-001", "role": "assistant"}}`,
		}
		for _, msg := range messages {
			select {
			case <-ctx.Done():
				return
			case ch <- msg:
			}
		}
	}()

	return ch, nil
}

// GetSession returns the stored session with the given ID. The returned
// pointer is shared with the client's session map. It returns an error if
// failure mode is on or the session does not exist.
func (m *MockClient) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.shouldFail {
		return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
	}

	session, exists := m.sessions[sessionID]
	if !exists {
		return nil, fmt.Errorf("会话不存在: %s", sessionID)
	}
	return session, nil
}

// ListSessions returns copies of all stored sessions in creation order.
// It returns an error if failure mode is on.
func (m *MockClient) ListSessions(ctx context.Context) ([]Session, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.shouldFail {
		return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
	}

	sessions := make([]Session, 0, len(m.sessions))
	for _, session := range m.sessions {
		sessions = append(sessions, *session)
	}

	// Map iteration order is randomized, so sort for a stable result. IDs are
	// "ses-mock-<n>", so comparing by length and then lexically yields numeric
	// (creation) order: "ses-mock-2" sorts before "ses-mock-10".
	sort.Slice(sessions, func(i, j int) bool {
		a, b := sessions[i].ID, sessions[j].ID
		if len(a) != len(b) {
			return len(a) < len(b)
		}
		return a < b
	})
	return sessions, nil
}

// GetBaseURL returns the base URL configured at construction time.
func (m *MockClient) GetBaseURL() string {
	return m.baseURL
}

// GetSessionMessages returns two canned messages for an existing session: a
// user prompt and an assistant reply that references it via parentID. When
// limit > 0, at most the first limit messages are returned. It returns an
// error if failure mode is on or the session does not exist.
func (m *MockClient) GetSessionMessages(ctx context.Context, sessionID string, limit int) ([]SessionMessage, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.shouldFail {
		return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
	}
	if _, exists := m.sessions[sessionID]; !exists {
		return nil, fmt.Errorf("会话不存在: %s", sessionID)
	}

	// Canned fixture messages. NOTE(review): the values (timestamps, cwd,
	// token counts) look like they were captured from a real session.
	messages := []SessionMessage{
		{
			Info: map[string]interface{}{
				"id":        "msg-mock-001",
				"sessionID": sessionID,
				"role":      "user",
				"time": map[string]interface{}{
					"created": float64(1770797040043),
				},
				"agent": "code-sql",
				"model": map[string]interface{}{
					"providerID": "deepseek",
					"modelID":    "deepseek-reasoner",
				},
			},
			Parts: []map[string]interface{}{
				{
					"id":   "prt-mock-001",
					"type": "text",
					"text": "编写查询昨天销售总额的sql代码。简单,用户没有提到的,保持最简单的方式返回内容",
				},
			},
		},
		{
			Info: map[string]interface{}{
				"id":        "msg-mock-002",
				"sessionID": sessionID,
				"role":      "assistant",
				"time": map[string]interface{}{
					"created":   float64(1770797040072),
					"completed": float64(1770797045880),
				},
				"parentID":   "msg-mock-001",
				"modelID":    "deepseek-reasoner",
				"providerID": "deepseek",
				"mode":       "code-sql",
				"agent":      "code-sql",
				"path": map[string]interface{}{
					"cwd":  "/Users/kenqdy/Documents/v-bdx-workspace",
					"root": "/",
				},
				"cost": 0.000514752,
				"tokens": map[string]interface{}{
					"input":     55,
					"output":    114,
					"reasoning": 68,
					"cache": map[string]interface{}{
						"read":  15104,
						"write": 0,
					},
				},
				"finish": "tool-calls",
			},
			Parts: []map[string]interface{}{
				{
					"id":   "prt-mock-002",
					"type": "step-start",
				},
				{
					"id":   "prt-mock-003",
					"type": "reasoning",
					"text": "用户想要查询昨天销售总额的SQL代码...",
				},
			},
		},
	}

	// Apply the optional limit by keeping only the leading messages.
	if limit > 0 && limit < len(messages) {
		messages = messages[:limit]
	}
	return messages, nil
}

// GetPort returns the port configured at construction time.
func (m *MockClient) GetPort() int {
	return m.port
}