Açıklama Yok
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

factory.go 7.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. package factory
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. )
  11. // 极简HTTP客户端实现
  12. type SimpleHTTPClient struct {
  13. config *ModelConfig
  14. client *http.Client
  15. }
  16. // NewSimpleHTTPClient 创建客户端(每次请求创建新的)
  17. func NewSimpleHTTPClient(config *ModelConfig) *SimpleHTTPClient {
  18. return &SimpleHTTPClient{
  19. config: config,
  20. client: &http.Client{
  21. Timeout: time.Duration(config.Timeout) * time.Second,
  22. },
  23. }
  24. }
  25. // Chat 一次性返回(阻塞式)
  26. func (c *SimpleHTTPClient) Chat(ctx context.Context, messages []ChatMessage) (string, error) {
  27. // 构建请求
  28. reqBody := map[string]interface{}{
  29. "model": c.config.Model,
  30. "messages": messages,
  31. "max_tokens": c.config.MaxTokens,
  32. "temperature": c.config.Temperature,
  33. "stream": false,
  34. }
  35. // 根据不同提供商调整请求格式
  36. reqBody = c.adjustRequestForProvider(reqBody)
  37. // 发送请求
  38. resp, err := c.doRequest(ctx, reqBody)
  39. if err != nil {
  40. return "", err
  41. }
  42. defer resp.Body.Close()
  43. // 解析响应
  44. var result map[string]interface{}
  45. if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
  46. return "", fmt.Errorf("解析响应失败: %v", err)
  47. }
  48. // 提取回复内容
  49. return c.extractContent(result), nil
  50. }
  51. // ChatStream 流式返回
  52. func (c *SimpleHTTPClient) ChatStream(ctx context.Context, messages []ChatMessage) (<-chan string, error) {
  53. ch := make(chan string)
  54. go func() {
  55. defer close(ch)
  56. // 构建请求
  57. reqBody := map[string]interface{}{
  58. "model": c.config.Model,
  59. "messages": messages,
  60. "max_tokens": c.config.MaxTokens,
  61. "temperature": c.config.Temperature,
  62. "stream": true,
  63. }
  64. reqBody = c.adjustRequestForProvider(reqBody)
  65. // 发送请求
  66. resp, err := c.doRequest(ctx, reqBody)
  67. if err != nil {
  68. ch <- fmt.Sprintf("错误: %v", err)
  69. return
  70. }
  71. defer resp.Body.Close()
  72. // 流式读取
  73. reader := c.createStreamReader(resp.Body)
  74. for {
  75. select {
  76. case <-ctx.Done():
  77. return
  78. default:
  79. line, err := reader()
  80. if err != nil {
  81. if err != io.EOF {
  82. ch <- fmt.Sprintf("读取错误: %v", err)
  83. }
  84. return
  85. }
  86. if line != "" {
  87. ch <- line
  88. }
  89. }
  90. }
  91. }()
  92. return ch, nil
  93. }
  94. // doRequest 执行HTTP请求
  95. func (c *SimpleHTTPClient) doRequest(ctx context.Context, body map[string]interface{}) (*http.Response, error) {
  96. // 序列化请求体
  97. jsonData, err := json.Marshal(body)
  98. if err != nil {
  99. return nil, fmt.Errorf("序列化请求失败: %v", err)
  100. }
  101. // 确定端点URL
  102. endpoint := c.getEndpointURL()
  103. // 创建请求
  104. req, err := http.NewRequestWithContext(ctx, "POST", endpoint, strings.NewReader(string(jsonData)))
  105. if err != nil {
  106. return nil, fmt.Errorf("创建请求失败: %v", err)
  107. }
  108. // 设置请求头
  109. c.setHeaders(req)
  110. // 发送请求
  111. resp, err := c.client.Do(req)
  112. if err != nil {
  113. return nil, fmt.Errorf("请求失败: %v", err)
  114. }
  115. // 检查状态码
  116. if resp.StatusCode != http.StatusOK {
  117. body, _ := io.ReadAll(resp.Body)
  118. resp.Body.Close()
  119. return nil, fmt.Errorf("API错误: %s, 响应: %s", resp.Status, string(body))
  120. }
  121. return resp, nil
  122. }
  123. // getEndpointURL 获取API端点URL
  124. func (c *SimpleHTTPClient) getEndpointURL() string {
  125. switch c.config.Provider {
  126. case "openai":
  127. return c.config.BaseURL + "/chat/completions"
  128. case "deepseek":
  129. return c.config.BaseURL + "/chat/completions"
  130. case "claude":
  131. return c.config.BaseURL + "/messages"
  132. default:
  133. return c.config.BaseURL + "/chat/completions"
  134. }
  135. }
  136. // setHeaders 设置请求头
  137. func (c *SimpleHTTPClient) setHeaders(req *http.Request) {
  138. req.Header.Set("Content-Type", "application/json")
  139. req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
  140. // 提供商特定头
  141. switch c.config.Provider {
  142. case "claude":
  143. req.Header.Set("x-api-key", c.config.APIKey)
  144. req.Header.Set("anthropic-version", "2023-06-01")
  145. case "openai", "deepseek":
  146. // 标准头
  147. }
  148. }
  149. // adjustRequestForProvider 调整请求格式
  150. func (c *SimpleHTTPClient) adjustRequestForProvider(reqBody map[string]interface{}) map[string]interface{} {
  151. switch c.config.Provider {
  152. case "claude":
  153. // Claude使用不同的格式
  154. return map[string]interface{}{
  155. "model": c.config.Model,
  156. "messages": reqBody["messages"],
  157. "max_tokens": c.config.MaxTokens,
  158. }
  159. default:
  160. return reqBody
  161. }
  162. }
  163. // extractContent 从响应中提取内容
  164. func (c *SimpleHTTPClient) extractContent(resp map[string]interface{}) string {
  165. switch c.config.Provider {
  166. case "openai", "deepseek":
  167. if choices, ok := resp["choices"].([]interface{}); ok && len(choices) > 0 {
  168. if choice, ok := choices[0].(map[string]interface{}); ok {
  169. if message, ok := choice["message"].(map[string]interface{}); ok {
  170. if content, ok := message["content"].(string); ok {
  171. return content
  172. }
  173. }
  174. }
  175. }
  176. case "claude":
  177. if content, ok := resp["content"].([]interface{}); ok && len(content) > 0 {
  178. if first, ok := content[0].(map[string]interface{}); ok {
  179. if text, ok := first["text"].(string); ok {
  180. return text
  181. }
  182. }
  183. }
  184. }
  185. return ""
  186. }
  187. // createStreamReader 创建流式读取器
  188. func (c *SimpleHTTPClient) createStreamReader(body io.Reader) func() (string, error) {
  189. buf := make([]byte, 4096)
  190. var leftover []byte
  191. return func() (string, error) {
  192. // 读取数据
  193. n, err := body.Read(buf)
  194. if err != nil {
  195. return "", err
  196. }
  197. data := append(leftover, buf[:n]...)
  198. lines := strings.Split(string(data), "\n")
  199. // 处理完整的行
  200. var result strings.Builder
  201. for i, line := range lines {
  202. if i == len(lines)-1 {
  203. // 最后一行可能不完整,留到下次
  204. leftover = []byte(line)
  205. continue
  206. }
  207. line = strings.TrimSpace(line)
  208. if line == "" || !strings.HasPrefix(line, "data: ") {
  209. continue
  210. }
  211. // 去除"data: "前缀
  212. line = line[6:]
  213. if line == "[DONE]" {
  214. return "", io.EOF
  215. }
  216. // 解析JSON
  217. var chunk map[string]interface{}
  218. if err := json.Unmarshal([]byte(line), &chunk); err != nil {
  219. continue
  220. }
  221. // 提取内容
  222. content := c.extractStreamContent(chunk)
  223. if content != "" {
  224. result.WriteString(content)
  225. }
  226. }
  227. return result.String(), nil
  228. }
  229. }
  230. // extractStreamContent 提取流式响应内容
  231. func (c *SimpleHTTPClient) extractStreamContent(chunk map[string]interface{}) string {
  232. switch c.config.Provider {
  233. case "openai", "deepseek":
  234. if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
  235. if choice, ok := choices[0].(map[string]interface{}); ok {
  236. if delta, ok := choice["delta"].(map[string]interface{}); ok {
  237. if content, ok := delta["content"].(string); ok {
  238. return content
  239. }
  240. }
  241. }
  242. }
  243. case "claude":
  244. if content, ok := chunk["content"].([]interface{}); ok && len(content) > 0 {
  245. if first, ok := content[0].(map[string]interface{}); ok {
  246. if text, ok := first["text"].(string); ok {
  247. return text
  248. }
  249. }
  250. }
  251. }
  252. return ""
  253. }
  254. // ==================== 工厂函数 ====================
  255. // CreateClient 工厂函数:根据配置创建客户端
  256. func CreateClient(configKey string) (AIClient, error) {
  257. // 这里模拟从"数据库"读取配置
  258. // 实际项目中应该从数据库/配置中心读取
  259. config, err := GetConfigFromMockDB(configKey)
  260. if err != nil {
  261. return nil, err
  262. }
  263. return NewSimpleHTTPClient(config), nil
  264. }
  265. // GetConfigFromMockDB 模拟从数据库获取配置
  266. func GetConfigFromMockDB(configKey string) (*ModelConfig, error) {
  267. // 这里调用模拟数据库
  268. // 实际应该调用数据库查询
  269. return &ModelConfig{
  270. Provider: "openai",
  271. BaseURL: "https://api.openai.com/v1",
  272. APIKey: "fake-key-for-demo",
  273. Model: "gpt-3.5-turbo",
  274. MaxTokens: 1000,
  275. Temperature: 0.7,
  276. Timeout: 30,
  277. }, nil
  278. }