Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

添加 billing 开关/ 添加 审计功能 #1530

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ build
logs
data
/web/node_modules
cmd.md
cmd.md
vendor/*
27 changes: 27 additions & 0 deletions common/audit/audit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package audit

import (
"github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
)

var (
loger *lumberjack.Logger
logger *logrus.Logger
)

func init() {
loger = &lumberjack.Logger{
Filename: "logs/audit.log",
MaxSize: 50, // megabytes
MaxBackups: 300,
MaxAge: 90, // days
}
logger = logrus.New()
logger.SetOutput(loger)
logger.SetFormatter(&logrus.JSONFormatter{})
}

func Logger() *logrus.Logger {
return logger
}
79 changes: 79 additions & 0 deletions common/audit/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package audit

import (
"bytes"
"encoding/base64"
"io"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)

type AuditLogger struct {
gin.ResponseWriter
buf *bytes.Buffer
}

func (l *AuditLogger) Write(p []byte) (int, error) {
l.buf.Write(p)
return l.ResponseWriter.Write(p)
}

func CaptureResponseBody(c *gin.Context) *bytes.Buffer {
al := &AuditLogger{
ResponseWriter: c.Writer,
buf: &bytes.Buffer{},
}
c.Writer = al
return al.buf
}

func B64encode(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}

type AuditReadCloser struct {
Reader io.Reader
Closer io.Closer
Buffer *bytes.Buffer
}

func (arc *AuditReadCloser) Read(p []byte) (int, error) {
n, err := arc.Reader.Read(p)
if n > 0 {
arc.Buffer.Write(p[:n])
}
return n, err
}

func (arc *AuditReadCloser) Close() error {
return arc.Closer.Close()
}

func CaptureHTTPResponseBody(resp *http.Response) *bytes.Buffer {
buf := &bytes.Buffer{}
arc := &AuditReadCloser{
Reader: resp.Body,
Closer: resp.Body,
Buffer: buf,
}
resp.Body = arc
return buf
}

func ParseOPENAIStreamResponse(buf *bytes.Buffer) string {
lines := strings.Split(buf.String(), "\n")
bts := []string{}
for _, line := range lines {
line = strings.TrimSpace(line)
line = strings.Trim(line, "\n")
if strings.HasPrefix(string(line), "data:") {
line = line[5:]
}
content := gjson.Get(line, "choices.0.delta.content").String()
bts = append(bts, content)
}
return strings.Join(bts, "")
}
6 changes: 5 additions & 1 deletion common/config/config.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package config

import (
"github.com/songquanpeng/one-api/common/env"
"os"
"strconv"
"strings"
"sync"
"time"

"github.com/songquanpeng/one-api/common/env"

"github.com/google/uuid"
)

Expand Down Expand Up @@ -55,6 +56,8 @@ var EmailDomainWhitelist = []string{
var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true"
var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true"
var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true"
var ClientAuditEnabled = env.Bool("CLIENT_AUDIT_ENABLED", false)
var UpstreamAuditEnabled = env.Bool("UPSTREAM_AUDIT_ENABLED", false)

var LogConsumeEnabled = true

Expand Down Expand Up @@ -135,6 +138,7 @@ var (

var RateLimitKeyExpirationDuration = 20 * time.Minute

var EnableBilling = env.Bool("ENABLE_BILLING", true)
var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
Expand Down
87 changes: 72 additions & 15 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/audit"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
Expand All @@ -17,48 +18,97 @@ import (
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)

// https://platform.openai.com/docs/api-reference/chat

func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
type Options struct {
Debug bool
EnableMonitor bool
EnableBilling bool
}

type RelayController struct {
opts Options
controller.RelayInstance
monitor.MonitorInstance
}

func NewRelayController(opts Options) *RelayController {
ctrl := &RelayController{
opts: opts,
}
ctrl.RelayInstance = controller.NewRelayInstance(controller.Options{
EnableBilling: opts.EnableBilling,
})
if opts.EnableMonitor {
ctrl.MonitorInstance = monitor.NewMonitorInstance()
}
return ctrl
}

func (ctrl *RelayController) relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
if config.ClientAuditEnabled {
buf := audit.CaptureResponseBody(c)
m := meta.GetByContext(c)
defer func() {
audit.Logger().
WithField("raw", audit.B64encode(buf.Bytes())).
WithField("parsed", audit.ParseOPENAIStreamResponse(buf)).
WithField("requestid", c.GetString(helper.RequestIdKey)).
WithFields(m.ToLogrusFields()).
Info("client response")
}()
}
var err *model.ErrorWithStatusCode
switch relayMode {
case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
err = ctrl.RelayImageHelper(c, relayMode)
case relaymode.AudioSpeech:
fallthrough
case relaymode.AudioTranslation:
fallthrough
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
err = ctrl.RelayAudioHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
err = ctrl.RelayTextHelper(c)
}
return err
}

func Relay(c *gin.Context) {
func (ctrl *RelayController) Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
if config.ClientAuditEnabled {
requestBody, _ := common.GetRequestBody(c)
m := meta.GetByContext(c)
audit.Logger().
WithField("raw", audit.B64encode(requestBody)).
WithField("requestid", c.GetString(helper.RequestIdKey)).
WithFields(m.ToLogrusFields()).
Info("client request")
}
channelId := c.GetInt(ctxkey.ChannelId)
userId := c.GetInt("id")
bizErr := relayHelper(c, relayMode)
bizErr := ctrl.relayHelper(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
if ctrl.MonitorInstance != nil {
ctrl.Emit(channelId, true)
}
return
}
lastFailedChannelId := channelId
channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
userId := c.GetInt(ctxkey.Id)
go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
Expand All @@ -77,15 +127,19 @@ func Relay(c *gin.Context) {
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
if err != nil {
logger.Errorf(ctx, "GetRequestBody failed: %+v", err)
break
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relayHelper(c, relayMode)
bizErr = ctrl.relayHelper(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
Expand Down Expand Up @@ -117,13 +171,16 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
return true
}

func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
func (ctrl *RelayController) processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
if ctrl.MonitorInstance == nil {
return
}
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
if ctrl.ShouldDisableChannel(&err.Error, err.StatusCode) {
ctrl.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)
ctrl.Emit(channelId, false)
}
}

Expand Down
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ require (
github.com/jinzhu/copier v0.4.0
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.7
github.com/sirupsen/logrus v1.8.1
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.1
golang.org/x/crypto v0.23.0
golang.org/x/image v0.16.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5
Expand Down Expand Up @@ -73,6 +76,8 @@ require (
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/smarty/assertions v1.15.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
Expand Down
Loading