Skip to content

Commit 73a523a

Browse files
📝 Add docstrings to main
Docstrings generation was requested by @KimmyXYC. * #8 (comment) The following files were modified: * `KimmyXYC/cmd/server/main.go` * `KimmyXYC/internal/db/db.go` * `KimmyXYC/internal/httpserver/router.go` * `KimmyXYC/internal/provider/openai.go` * `KimmyXYC/internal/provider/provider.go` * `KimmyXYC/internal/services/auth_service.go` * `KimmyXYC/internal/services/chat_service.go` * `KimmyXYC/pkg/auth/token.go` * `KimmyXYC/pkg/middleware/auth.go` * `KimmyXYC/web/js/api.js` * `KimmyXYC/web/js/auth.js` * `KimmyXYC/web/js/chat.js` * `KimmyXYC/web/js/main.js` * `KimmyXYC/web/js/state.js`
1 parent 5852d28 commit 73a523a

File tree

14 files changed

+1386
-0
lines changed

14 files changed

+1386
-0
lines changed

KimmyXYC/cmd/server/main.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package main
2+
3+
import (
4+
"log"
5+
"os"
6+
7+
"github.com/joho/godotenv"
8+
9+
"AIBackend/internal/db"
10+
"AIBackend/internal/httpserver"
11+
"AIBackend/internal/provider"
12+
)
13+
14+
// main 是程序的入口点。它可选加载 `.env`,使用 `DATABASE_URL` 建立并迁移数据库,
15+
// 从环境创建 LLM 提供者,并使用 `ADDR`(默认 ":8080")启动 HTTP 服务器;在连接、迁移或启动失败时记录致命错误,在缺少 `DATABASE_URL` 时记录警告。
16+
func main() {
17+
// Load .env if present (dev convenience)
18+
_ = godotenv.Load()
19+
20+
// Initialize DB
21+
pgURL := os.Getenv("DATABASE_URL")
22+
if pgURL == "" {
23+
log.Println("WARNING: DATABASE_URL is not set. The server may fail to start when DB is required.")
24+
}
25+
gormDB, err := db.Connect(pgURL)
26+
if err != nil {
27+
log.Fatalf("failed to connect database: %v", err)
28+
}
29+
if err := db.AutoMigrate(gormDB); err != nil {
30+
log.Fatalf("failed to migrate database: %v", err)
31+
}
32+
33+
// Initialize LLM provider (Mock by default)
34+
llm := provider.NewProviderFromEnv()
35+
36+
// Start HTTP server
37+
r := httpserver.NewRouter(gormDB, llm)
38+
addr := os.Getenv("ADDR")
39+
if addr == "" {
40+
addr = ":8080"
41+
}
42+
log.Printf("Server listening on %s", addr)
43+
if err := r.Run(addr); err != nil {
44+
log.Fatalf("server error: %v", err)
45+
}
46+
}

KimmyXYC/internal/db/db.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package db
2+
3+
import (
4+
"fmt"
5+
6+
"gorm.io/driver/postgres"
7+
"gorm.io/gorm"
8+
9+
"AIBackend/internal/models"
10+
)
11+
12+
// Connect 使用提供的数据库 URL 打开一个 PostgreSQL 连接。
13+
// 如果传入的 databaseURL 为空,则使用本地默认 DSN:
14+
// postgres://postgres:postgres@localhost:5432/aibackend?sslmode=disable。
15+
// 返回已打开的 *gorm.DB;在无法建立连接时返回非 nil 错误。
16+
func Connect(databaseURL string) (*gorm.DB, error) {
17+
if databaseURL == "" {
18+
// Provide a friendly default to help first run; it will still fail if DB not available.
19+
databaseURL = "postgres://postgres:postgres@localhost:5432/aibackend?sslmode=disable"
20+
}
21+
dsn := databaseURL
22+
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
23+
if err != nil {
24+
return nil, fmt.Errorf("connect postgres: %w", err)
25+
}
26+
return db, nil
27+
}
28+
29+
// AutoMigrate 在数据库上应用 User、Conversation 和 Message 模型的自动迁移。
30+
// 如果迁移失败,返回相应的错误。
31+
func AutoMigrate(db *gorm.DB) error {
32+
return db.AutoMigrate(
33+
&models.User{},
34+
&models.Conversation{},
35+
&models.Message{},
36+
)
37+
}
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
package httpserver
2+
3+
import (
4+
"net/http"
5+
"strconv"
6+
"time"
7+
8+
"github.com/gin-gonic/gin"
9+
"gorm.io/gorm"
10+
11+
"AIBackend/internal/provider"
12+
"AIBackend/internal/services"
13+
"AIBackend/pkg/middleware"
14+
)
15+
16+
type Server struct {
17+
Auth *services.AuthService
18+
Chat *services.ChatService
19+
}
20+
21+
// NewRouter 创建并返回已配置的 Gin 引擎,注册健康检查、认证相关路由、带鉴权的会话与聊天 API(包含可选的 SSE 流式聊天)并提供前端静态文件服务。
22+
func NewRouter(db *gorm.DB, llm provider.LLMProvider) *gin.Engine {
23+
g := gin.Default()
24+
25+
server := &Server{
26+
Auth: services.NewAuthService(db),
27+
Chat: services.NewChatService(db, llm),
28+
}
29+
30+
g.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })
31+
32+
api := g.Group("/api")
33+
{
34+
auth := api.Group("/auth")
35+
auth.POST("/register", server.handleRegister)
36+
auth.POST("/login", server.handleLogin)
37+
}
38+
39+
protected := api.Group("")
40+
protected.Use(middleware.AuthRequired())
41+
{
42+
protected.GET("/me", server.handleMe)
43+
protected.GET("/conversations", server.handleListConversations)
44+
protected.GET("/conversations/:id/messages", server.handleGetMessages)
45+
protected.POST("/chat", middleware.ModelAccess(), server.handleChat)
46+
}
47+
48+
// Serve static frontend files without conflicting wildcard
49+
g.StaticFile("/", "./web/index.html")
50+
g.Static("/css", "./web/css")
51+
g.Static("/js", "./web/js")
52+
53+
return g
54+
}
55+
56+
type registerReq struct {
57+
Email string `json:"email" binding:"required"`
58+
Password string `json:"password" binding:"required"`
59+
Role string `json:"role"`
60+
}
61+
62+
type loginReq struct {
63+
Email string `json:"email" binding:"required"`
64+
Password string `json:"password" binding:"required"`
65+
}
66+
67+
func (s *Server) handleRegister(c *gin.Context) {
68+
var req registerReq
69+
if err := c.ShouldBindJSON(&req); err != nil {
70+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
71+
return
72+
}
73+
user, token, err := s.Auth.Register(req.Email, req.Password, req.Role)
74+
if err != nil {
75+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
76+
return
77+
}
78+
c.JSON(http.StatusOK, gin.H{"user": user, "token": token})
79+
}
80+
81+
func (s *Server) handleLogin(c *gin.Context) {
82+
var req loginReq
83+
if err := c.ShouldBindJSON(&req); err != nil {
84+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
85+
return
86+
}
87+
user, token, err := s.Auth.Login(req.Email, req.Password)
88+
if err != nil {
89+
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
90+
return
91+
}
92+
c.JSON(http.StatusOK, gin.H{"user": user, "token": token})
93+
}
94+
95+
func (s *Server) handleMe(c *gin.Context) {
96+
c.JSON(http.StatusOK, gin.H{
97+
"user_id": c.GetUint("user_id"),
98+
"user_email": c.GetString("user_email"),
99+
"user_role": c.GetString("user_role"),
100+
})
101+
}
102+
103+
func (s *Server) handleListConversations(c *gin.Context) {
104+
uid := c.GetUint("user_id")
105+
convs, err := s.Chat.ListConversations(uid)
106+
if err != nil {
107+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
108+
return
109+
}
110+
c.JSON(http.StatusOK, gin.H{"conversations": convs})
111+
}
112+
113+
func (s *Server) handleGetMessages(c *gin.Context) {
114+
uid := c.GetUint("user_id")
115+
idStr := c.Param("id")
116+
id64, _ := strconv.ParseUint(idStr, 10, 64)
117+
msgs, err := s.Chat.GetMessages(uid, uint(id64))
118+
if err != nil {
119+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
120+
return
121+
}
122+
c.JSON(http.StatusOK, gin.H{"messages": msgs})
123+
}
124+
125+
type chatReq struct {
126+
ConversationID uint `json:"conversation_id"`
127+
Model string `json:"model"`
128+
Message string `json:"message" binding:"required"`
129+
Stream *bool `json:"stream"`
130+
}
131+
132+
func (s *Server) handleChat(c *gin.Context) {
133+
uid := c.GetUint("user_id")
134+
var req chatReq
135+
if err := c.ShouldBindJSON(&req); err != nil {
136+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
137+
return
138+
}
139+
// Fallback to query param model for middleware check compatibility
140+
if req.Model == "" {
141+
req.Model = c.Query("model")
142+
}
143+
// Enforce model access if provided in body
144+
role := c.GetString("user_role")
145+
if !middleware.CheckModelAccess(role, req.Model) {
146+
c.JSON(http.StatusForbidden, gin.H{"error": "model access denied for role"})
147+
return
148+
}
149+
streaming := false
150+
if req.Stream != nil {
151+
streaming = *req.Stream
152+
}
153+
if c.Query("stream") == "1" || c.Query("stream") == "true" {
154+
streaming = true
155+
}
156+
if !streaming {
157+
convID, reply, err := s.Chat.SendMessage(c.Request.Context(), uid, req.ConversationID, req.Model, req.Message, nil)
158+
if err != nil {
159+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
160+
return
161+
}
162+
c.JSON(http.StatusOK, gin.H{"conversation_id": convID, "reply": reply})
163+
return
164+
}
165+
// Streaming via SSE
166+
w := c.Writer
167+
c.Header("Content-Type", "text/event-stream")
168+
c.Header("Cache-Control", "no-cache")
169+
c.Header("Connection", "keep-alive")
170+
c.Status(http.StatusOK)
171+
flusher, _ := w.(http.Flusher)
172+
sentAny := false
173+
convID, _, err := s.Chat.SendMessage(c.Request.Context(), uid, req.ConversationID, req.Model, req.Message, func(chunk string) error {
174+
sentAny = true
175+
_, err := w.Write([]byte("data: " + chunk + "\n\n"))
176+
if err == nil && flusher != nil {
177+
flusher.Flush()
178+
}
179+
return err
180+
})
181+
if err != nil {
182+
// send error as SSE comment and 0-length event end
183+
_, _ = w.Write([]byte(": error: " + err.Error() + "\n\n"))
184+
if flusher != nil {
185+
flusher.Flush()
186+
}
187+
return
188+
}
189+
if !sentAny {
190+
// send at least one empty event to keep clients happy
191+
_, _ = w.Write([]byte("data: \n\n"))
192+
}
193+
// end event
194+
_, _ = w.Write([]byte("event: done\n" + "data: {\"conversation_id\": " + strconv.FormatUint(uint64(convID), 10) + "}\n\n"))
195+
if flusher != nil {
196+
flusher.Flush()
197+
}
198+
// allow connection to close shortly after
199+
time.Sleep(50 * time.Millisecond)
200+
}

0 commit comments

Comments
 (0)