feat(openai): completion::stream, response::stream and basic model
src: gitea.starryskymeow.cn/xkm/translate
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/.idea
|
||||
45
model.go
Normal file
45
model.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package llm_api
|
||||
|
||||
// OpenaiChatMessage is one chat turn. It is used both as an element of
// OpenaiChatCompletionReq.Messages and of OpenaiChatResponseReq.Input.
type OpenaiChatMessage struct {
	// Role is the speaker, e.g. "system", "user" or "assistant".
	Role string `json:"role"`
	// Content is the plain-text body of the turn.
	Content string `json:"content"`
}

// OpenaiChatCompletionReq is the request body sent to the
// OpenAI-compatible POST /chat/completions endpoint.
type OpenaiChatCompletionReq struct {
	Model string `json:"model"`
	Messages []OpenaiChatMessage `json:"messages"`
	// Temperature is a pointer so callers can distinguish "send 0"
	// from "omit the field" (nil).
	Temperature *float64 `json:"temperature,omitempty"`
	// ReasoningEffort is omitted when empty, for models without reasoning.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`
	// Stream is always set to true by the streaming helpers in this package.
	Stream bool `json:"stream"`
}

// OpenaiResponseReasoning configures reasoning for the Responses API.
type OpenaiResponseReasoning struct {
	Effort string `json:"effort"`
	Summary string `json:"summary,omitempty"` // auto, concise, detailed
}

// OpenaiChatResponseReq is the request body sent to the
// OpenAI-compatible POST /responses endpoint.
type OpenaiChatResponseReq struct {
	Model string `json:"model"`
	Input []OpenaiChatMessage `json:"input"`
	// Temperature: pointer so nil omits the field; see OpenaiChatCompletionReq.
	Temperature *float64 `json:"temperature,omitempty"`
	// NOTE(review): `omitempty` has no effect on a non-pointer struct field —
	// a zero-value Reasoning is still serialized as {"effort":""}. Consider
	// *OpenaiResponseReasoning (or `omitzero`, Go 1.24+) if it must be omitted.
	Reasoning OpenaiResponseReasoning `json:"reasoning,omitempty"`
	Stream bool `json:"stream"`
}

// OpenaiResponseStreamEvent is one decoded SSE payload from the Responses
// API stream; only "response.output_text.delta" events carry a Delta.
type OpenaiResponseStreamEvent struct {
	Type string `json:"type"`
	Delta string `json:"delta"`
}

// OpenaiCompletionStreamEvent is one decoded SSE payload from the Chat
// Completions stream. Only the first choice's delta content is consumed
// by OpenaiStreamChatCompletions.
type OpenaiCompletionStreamEvent struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`

	// Error is non-nil when the API reports an in-stream error.
	Error *struct {
		Message string `json:"message"`
		Type string `json:"type"`
	} `json:"error,omitempty"`
}
|
||||
147
openai-completion.go
Normal file
147
openai-completion.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package llm_api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func OpenaiStreamChatCompletions(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
baseURL string,
|
||||
apiKey string,
|
||||
model string,
|
||||
reasoningEffort string,
|
||||
temperature *float64,
|
||||
msgs []OpenaiChatMessage,
|
||||
) (io.ReadCloser, error) {
|
||||
if msgs == nil || len(msgs) == 0 {
|
||||
return nil, errors.New("missing messages")
|
||||
}
|
||||
endpoint := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
body := OpenaiChatCompletionReq{
|
||||
Model: model,
|
||||
Messages: msgs,
|
||||
Temperature: temperature,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
defer func(pw *io.PipeWriter) {
|
||||
err := pw.Close()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}(pw)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}(resp.Body)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
_ = pw.CloseWithError(fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(b))))
|
||||
return
|
||||
}
|
||||
|
||||
sc := bufio.NewScanner(resp.Body)
|
||||
sc.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
|
||||
var dataLines []string
|
||||
flushEvent := func() bool {
|
||||
if len(dataLines) == 0 {
|
||||
_ = pw.Close()
|
||||
return true
|
||||
}
|
||||
data := strings.Join(dataLines, "\n")
|
||||
dataLines = dataLines[:0]
|
||||
if data == "[DONE]" {
|
||||
_ = pw.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
var evt OpenaiCompletionStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &evt); err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("failed to unmarshal event: %w, data=%q", err, data))
|
||||
return false
|
||||
}
|
||||
if evt.Error != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("api error: %+v", evt.Error))
|
||||
return false
|
||||
}
|
||||
if len(evt.Choices) > 0 {
|
||||
chunk := evt.Choices[0].Delta.Content
|
||||
if chunk != "" {
|
||||
if _, err := io.WriteString(pw, chunk); err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
|
||||
if line == "" {
|
||||
if ok := flushEvent(); !ok {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
v := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
dataLines = append(dataLines, v)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dataLines) > 0 {
|
||||
_ = flushEvent()
|
||||
return
|
||||
}
|
||||
|
||||
if err := sc.Err(); err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return pr, nil
|
||||
}
|
||||
139
openai-response.go
Normal file
139
openai-response.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package llm_api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func OpenaiStreamChatResponses(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
baseURL string,
|
||||
apiKey string,
|
||||
model string,
|
||||
reasoningEffort string,
|
||||
temperature *float64,
|
||||
msgs []OpenaiChatMessage,
|
||||
) (io.ReadCloser, error) {
|
||||
if msgs == nil || len(msgs) == 0 {
|
||||
return nil, errors.New("missing messages")
|
||||
}
|
||||
endpoint := strings.TrimRight(baseURL, "/") + "/responses"
|
||||
|
||||
body := OpenaiChatResponseReq{
|
||||
Model: model,
|
||||
Input: msgs,
|
||||
Temperature: temperature,
|
||||
Reasoning: OpenaiResponseReasoning{Effort: reasoningEffort},
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
defer func(pw *io.PipeWriter) {
|
||||
err := pw.Close()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}(pw)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}(resp.Body)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
_ = pw.CloseWithError(fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(b))))
|
||||
return
|
||||
}
|
||||
|
||||
sc := bufio.NewScanner(resp.Body)
|
||||
sc.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
|
||||
var dataLines []string
|
||||
flushEvent := func() bool {
|
||||
if len(dataLines) == 0 {
|
||||
_ = pw.Close()
|
||||
return true
|
||||
}
|
||||
data := strings.Join(dataLines, "\n")
|
||||
dataLines = dataLines[:0]
|
||||
|
||||
var evt OpenaiResponseStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &evt); err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("failed to unmarshal event: %w, data=%q", err, data))
|
||||
return false
|
||||
}
|
||||
if evt.Type == "response.output_text.delta" {
|
||||
if _, err := io.WriteString(pw, evt.Delta); err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
|
||||
if line == "" {
|
||||
if ok := flushEvent(); !ok {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
v := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
dataLines = append(dataLines, v)
|
||||
}
|
||||
if strings.HasPrefix(line, "event: response.completed") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(dataLines) > 0 {
|
||||
_ = flushEvent()
|
||||
return
|
||||
}
|
||||
|
||||
if err := sc.Err(); err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return pr, nil
|
||||
}
|
||||
Reference in New Issue
Block a user