| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- package factory
-
- import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "time"
- )
-
- // 极简HTTP客户端实现
- type SimpleHTTPClient struct {
- config *ModelConfig
- client *http.Client
- }
-
- // NewSimpleHTTPClient 创建客户端(每次请求创建新的)
- func NewSimpleHTTPClient(config *ModelConfig) *SimpleHTTPClient {
- return &SimpleHTTPClient{
- config: config,
- client: &http.Client{
- Timeout: time.Duration(config.Timeout) * time.Second,
- },
- }
- }
-
- // Chat 一次性返回(阻塞式)
- func (c *SimpleHTTPClient) Chat(ctx context.Context, messages []ChatMessage) (string, error) {
- // 构建请求
- reqBody := map[string]interface{}{
- "model": c.config.Model,
- "messages": messages,
- "max_tokens": c.config.MaxTokens,
- "temperature": c.config.Temperature,
- "stream": false,
- }
-
- // 根据不同提供商调整请求格式
- reqBody = c.adjustRequestForProvider(reqBody)
-
- // 发送请求
- resp, err := c.doRequest(ctx, reqBody)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
-
- // 解析响应
- var result map[string]interface{}
- if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- return "", fmt.Errorf("解析响应失败: %v", err)
- }
-
- // 提取回复内容
- return c.extractContent(result), nil
- }
-
- // ChatStream 流式返回
- func (c *SimpleHTTPClient) ChatStream(ctx context.Context, messages []ChatMessage) (<-chan string, error) {
- ch := make(chan string)
-
- go func() {
- defer close(ch)
-
- // 构建请求
- reqBody := map[string]interface{}{
- "model": c.config.Model,
- "messages": messages,
- "max_tokens": c.config.MaxTokens,
- "temperature": c.config.Temperature,
- "stream": true,
- }
-
- reqBody = c.adjustRequestForProvider(reqBody)
-
- // 发送请求
- resp, err := c.doRequest(ctx, reqBody)
- if err != nil {
- ch <- fmt.Sprintf("错误: %v", err)
- return
- }
- defer resp.Body.Close()
-
- // 流式读取
- reader := c.createStreamReader(resp.Body)
- for {
- select {
- case <-ctx.Done():
- return
- default:
- line, err := reader()
- if err != nil {
- if err != io.EOF {
- ch <- fmt.Sprintf("读取错误: %v", err)
- }
- return
- }
- if line != "" {
- ch <- line
- }
- }
- }
- }()
-
- return ch, nil
- }
-
- // doRequest 执行HTTP请求
- func (c *SimpleHTTPClient) doRequest(ctx context.Context, body map[string]interface{}) (*http.Response, error) {
- // 序列化请求体
- jsonData, err := json.Marshal(body)
- if err != nil {
- return nil, fmt.Errorf("序列化请求失败: %v", err)
- }
-
- // 确定端点URL
- endpoint := c.getEndpointURL()
-
- // 创建请求
- req, err := http.NewRequestWithContext(ctx, "POST", endpoint, strings.NewReader(string(jsonData)))
- if err != nil {
- return nil, fmt.Errorf("创建请求失败: %v", err)
- }
-
- // 设置请求头
- c.setHeaders(req)
-
- // 发送请求
- resp, err := c.client.Do(req)
- if err != nil {
- return nil, fmt.Errorf("请求失败: %v", err)
- }
-
- // 检查状态码
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- resp.Body.Close()
- return nil, fmt.Errorf("API错误: %s, 响应: %s", resp.Status, string(body))
- }
-
- return resp, nil
- }
-
- // getEndpointURL 获取API端点URL
- func (c *SimpleHTTPClient) getEndpointURL() string {
- switch c.config.Provider {
- case "openai":
- return c.config.BaseURL + "/chat/completions"
- case "deepseek":
- return c.config.BaseURL + "/chat/completions"
- case "claude":
- return c.config.BaseURL + "/messages"
- default:
- return c.config.BaseURL + "/chat/completions"
- }
- }
-
- // setHeaders 设置请求头
- func (c *SimpleHTTPClient) setHeaders(req *http.Request) {
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
-
- // 提供商特定头
- switch c.config.Provider {
- case "claude":
- req.Header.Set("x-api-key", c.config.APIKey)
- req.Header.Set("anthropic-version", "2023-06-01")
- case "openai", "deepseek":
- // 标准头
- }
- }
-
- // adjustRequestForProvider 调整请求格式
- func (c *SimpleHTTPClient) adjustRequestForProvider(reqBody map[string]interface{}) map[string]interface{} {
- switch c.config.Provider {
- case "claude":
- // Claude使用不同的格式
- return map[string]interface{}{
- "model": c.config.Model,
- "messages": reqBody["messages"],
- "max_tokens": c.config.MaxTokens,
- }
- default:
- return reqBody
- }
- }
-
- // extractContent 从响应中提取内容
- func (c *SimpleHTTPClient) extractContent(resp map[string]interface{}) string {
- switch c.config.Provider {
- case "openai", "deepseek":
- if choices, ok := resp["choices"].([]interface{}); ok && len(choices) > 0 {
- if choice, ok := choices[0].(map[string]interface{}); ok {
- if message, ok := choice["message"].(map[string]interface{}); ok {
- if content, ok := message["content"].(string); ok {
- return content
- }
- }
- }
- }
- case "claude":
- if content, ok := resp["content"].([]interface{}); ok && len(content) > 0 {
- if first, ok := content[0].(map[string]interface{}); ok {
- if text, ok := first["text"].(string); ok {
- return text
- }
- }
- }
- }
- return ""
- }
-
- // createStreamReader 创建流式读取器
- func (c *SimpleHTTPClient) createStreamReader(body io.Reader) func() (string, error) {
- buf := make([]byte, 4096)
- var leftover []byte
-
- return func() (string, error) {
- // 读取数据
- n, err := body.Read(buf)
- if err != nil {
- return "", err
- }
-
- data := append(leftover, buf[:n]...)
- lines := strings.Split(string(data), "\n")
-
- // 处理完整的行
- var result strings.Builder
- for i, line := range lines {
- if i == len(lines)-1 {
- // 最后一行可能不完整,留到下次
- leftover = []byte(line)
- continue
- }
-
- line = strings.TrimSpace(line)
- if line == "" || !strings.HasPrefix(line, "data: ") {
- continue
- }
-
- // 去除"data: "前缀
- line = line[6:]
- if line == "[DONE]" {
- return "", io.EOF
- }
-
- // 解析JSON
- var chunk map[string]interface{}
- if err := json.Unmarshal([]byte(line), &chunk); err != nil {
- continue
- }
-
- // 提取内容
- content := c.extractStreamContent(chunk)
- if content != "" {
- result.WriteString(content)
- }
- }
-
- return result.String(), nil
- }
- }
-
- // extractStreamContent 提取流式响应内容
- func (c *SimpleHTTPClient) extractStreamContent(chunk map[string]interface{}) string {
- switch c.config.Provider {
- case "openai", "deepseek":
- if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
- if choice, ok := choices[0].(map[string]interface{}); ok {
- if delta, ok := choice["delta"].(map[string]interface{}); ok {
- if content, ok := delta["content"].(string); ok {
- return content
- }
- }
- }
- }
- case "claude":
- if content, ok := chunk["content"].([]interface{}); ok && len(content) > 0 {
- if first, ok := content[0].(map[string]interface{}); ok {
- if text, ok := first["text"].(string); ok {
- return text
- }
- }
- }
- }
- return ""
- }
-
- // ==================== 工厂函数 ====================
-
- // CreateClient 工厂函数:根据配置创建客户端
- func CreateClient(configKey string) (AIClient, error) {
- // 这里模拟从"数据库"读取配置
- // 实际项目中应该从数据库/配置中心读取
- config, err := GetConfigFromMockDB(configKey)
- if err != nil {
- return nil, err
- }
-
- return NewSimpleHTTPClient(config), nil
- }
-
- // GetConfigFromMockDB 模拟从数据库获取配置
- func GetConfigFromMockDB(configKey string) (*ModelConfig, error) {
- // 这里调用模拟数据库
- // 实际应该调用数据库查询
- return &ModelConfig{
- Provider: "openai",
- BaseURL: "https://api.openai.com/v1",
- APIKey: "fake-key-for-demo",
- Model: "gpt-3.5-turbo",
- MaxTokens: 1000,
- Temperature: 0.7,
- Timeout: 30,
- }, nil
- }
|