| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- package opencode
-
import (
	"context"
	"fmt"
	"sort"
	"sync"
)
-
- // 确保 MockClient 实现 OpenCodeClient 接口
- var _ OpenCodeClient = (*MockClient)(nil)
-
- // MockClient opencode 客户端的模拟实现
- type MockClient struct {
- mu sync.RWMutex
- sessions map[string]*Session
- nextSessionID int
- port int
- baseURL string
- shouldFail bool
- failMessage string
- }
-
- // NewMockClient 创建新的模拟客户端
- func NewMockClient(port int) *MockClient {
- return &MockClient{
- sessions: make(map[string]*Session),
- nextSessionID: 1,
- port: port,
- baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
- shouldFail: false,
- }
- }
-
- // SetShouldFail 设置客户端是否应该失败
- func (m *MockClient) SetShouldFail(shouldFail bool, message string) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.shouldFail = shouldFail
- m.failMessage = message
- }
-
- // CreateSession 模拟创建会话
- func (m *MockClient) CreateSession(ctx context.Context, title string) (*Session, error) {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- if m.shouldFail {
- return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
- }
-
- sessionID := fmt.Sprintf("ses-mock-%d", m.nextSessionID)
- m.nextSessionID++
-
- session := &Session{
- ID: sessionID,
- Title: title,
- CreatedAt: "2024-01-01T00:00:00Z",
- }
-
- m.sessions[sessionID] = session
- return session, nil
- }
-
- // SendPrompt 模拟发送消息(同步)
- func (m *MockClient) SendPrompt(ctx context.Context, sessionID string, prompt *PromptRequest) (*PromptResponse, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- if m.shouldFail {
- return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
- }
-
- if _, exists := m.sessions[sessionID]; !exists {
- return nil, fmt.Errorf("会话不存在: %s", sessionID)
- }
-
- response := &PromptResponse{
- Info: AssistantMessage{
- ID: "msg-mock-001",
- Role: "assistant",
- SessionID: sessionID,
- Content: "这是模拟的响应:" + prompt.Parts[0].Text,
- Agent: "opencode",
- ModelID: "mock-model",
- ProviderID: "mock-provider",
- Tokens: TokenInfo{
- Input: 10,
- Output: 20,
- },
- Time: map[string]interface{}{
- "started": "2024-01-01T00:00:00Z",
- "finished": "2024-01-01T00:00:01Z",
- },
- },
- Parts: []interface{}{
- map[string]interface{}{
- "type": "text",
- "text": "这是模拟的响应:" + prompt.Parts[0].Text,
- },
- },
- }
-
- return response, nil
- }
-
- // SendPromptStream 模拟发送消息(流式)
- func (m *MockClient) SendPromptStream(ctx context.Context, sessionID string, prompt *PromptRequest) (<-chan string, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- if m.shouldFail {
- return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
- }
-
- if _, exists := m.sessions[sessionID]; !exists {
- return nil, fmt.Errorf("会话不存在: %s", sessionID)
- }
-
- ch := make(chan string, 10)
-
- // 模拟流式响应
- go func() {
- defer close(ch)
-
- messages := []string{
- `{"type": "text", "text": "这是"}`,
- `{"type": "text", "text": "模拟的"}`,
- `{"type": "text", "text": "流式响应。"}`,
- `{"info": {"id": "msg-mock-stream-001", "role": "assistant"}}`,
- }
-
- for _, msg := range messages {
- select {
- case <-ctx.Done():
- return
- case ch <- msg:
- }
- }
- }()
-
- return ch, nil
- }
-
- // GetSession 模拟获取会话
- func (m *MockClient) GetSession(ctx context.Context, sessionID string) (*Session, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- if m.shouldFail {
- return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
- }
-
- session, exists := m.sessions[sessionID]
- if !exists {
- return nil, fmt.Errorf("会话不存在: %s", sessionID)
- }
-
- return session, nil
- }
-
- // ListSessions 模拟获取会话列表
- func (m *MockClient) ListSessions(ctx context.Context) ([]Session, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- if m.shouldFail {
- return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
- }
-
- sessions := make([]Session, 0, len(m.sessions))
- for _, session := range m.sessions {
- sessions = append(sessions, *session)
- }
-
- return sessions, nil
- }
-
- // GetBaseURL 获取基础URL
- func (m *MockClient) GetBaseURL() string {
- return m.baseURL
- }
-
- // GetSessionMessages 模拟获取会话消息
- func (m *MockClient) GetSessionMessages(ctx context.Context, sessionID string, limit int) ([]SessionMessage, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- if m.shouldFail {
- return nil, fmt.Errorf("模拟错误: %s", m.failMessage)
- }
-
- if _, exists := m.sessions[sessionID]; !exists {
- return nil, fmt.Errorf("会话不存在: %s", sessionID)
- }
-
- // 创建模拟消息
- messages := []SessionMessage{
- {
- Info: map[string]interface{}{
- "id": "msg-mock-001",
- "sessionID": sessionID,
- "role": "user",
- "time": map[string]interface{}{
- "created": float64(1770797040043),
- },
- "agent": "code-sql",
- "model": map[string]interface{}{
- "providerID": "deepseek",
- "modelID": "deepseek-reasoner",
- },
- },
- Parts: []map[string]interface{}{
- {
- "id": "prt-mock-001",
- "type": "text",
- "text": "编写查询昨天销售总额的sql代码。简单,用户没有提到的,保持最简单的方式返回内容",
- },
- },
- },
- {
- Info: map[string]interface{}{
- "id": "msg-mock-002",
- "sessionID": sessionID,
- "role": "assistant",
- "time": map[string]interface{}{
- "created": float64(1770797040072),
- "completed": float64(1770797045880),
- },
- "parentID": "msg-mock-001",
- "modelID": "deepseek-reasoner",
- "providerID": "deepseek",
- "mode": "code-sql",
- "agent": "code-sql",
- "path": map[string]interface{}{
- "cwd": "/Users/kenqdy/Documents/v-bdx-workspace",
- "root": "/",
- },
- "cost": 0.000514752,
- "tokens": map[string]interface{}{
- "input": 55,
- "output": 114,
- "reasoning": 68,
- "cache": map[string]interface{}{
- "read": 15104,
- "write": 0,
- },
- },
- "finish": "tool-calls",
- },
- Parts: []map[string]interface{}{
- {
- "id": "prt-mock-002",
- "type": "step-start",
- },
- {
- "id": "prt-mock-003",
- "type": "reasoning",
- "text": "用户想要查询昨天销售总额的SQL代码...",
- },
- },
- },
- }
-
- // 如果有限制,截取消息
- if limit > 0 && limit < len(messages) {
- messages = messages[:limit]
- }
-
- return messages, nil
- }
-
- // GetPort 获取端口
- func (m *MockClient) GetPort() int {
- return m.port
- }
|