Skip to content

Commit 70b73d9

Browse files
liuzenghwineguo
andauthored
admin: sync internal bugfixes and enhancements (#206)
Co-authored-by: wineguo <[email protected]>
1 parent d23a9e1 commit 70b73d9

File tree

4 files changed

+109
-172
lines changed

4 files changed

+109
-172
lines changed

.github/workflows/prc.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@ jobs:
5151
go-apidiff:
5252
if: github.event_name == 'pull_request'
5353
runs-on: ubuntu-latest
54+
permissions:
55+
contents: read
56+
pull-requests: write # Required for commenting on PRs
5457
steps:
55-
- uses: actions/checkout@v3
58+
- uses: actions/checkout@v4
5659
with:
5760
fetch-depth: 0
58-
- uses: actions/setup-go@v4
61+
- uses: actions/setup-go@v5
5962
with:
60-
go-version: 1.19
61-
- uses: joelanford/go-apidiff@main
63+
go-version: 'stable'
64+
- uses: imjasonh/apidiff[email protected]

admin/admin.go

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,29 @@ import (
3939
"trpc.group/trpc-go/trpc-go/transport"
4040
)
4141

42+
func init() {
43+
// The pprof functionality supported by the admin package relies on the imported net/http/pprof package.
44+
// However, the imported net/http/pprof package implicitly registers HTTP handlers for
45+
// "/debug/pprof/", "/debug/pprof/cmdline", "/debug/pprof/profile", "/debug/pprof/symbol", "/debug/pprof/trace"
46+
// in http.DefaultServeMux in its init function. This implicit behavior is too subtle and may contribute to people
47+
// inadvertently leaving such endpoints open, and may cause security problems:https://github.com/golang/go/issues/22085
48+
// if people use http.DefaultServeMux. So we decide to reset default serve mux to remove pprof registration.
49+
// This requires making sure that people are not using http.DefaultServeMux before we reset it.
50+
// In most cases, this works, which is guaranteed by the execution order of the init function.
51+
// If you need to enable pprof on http.DefaultServeMux you need to
52+
// register it explicitly after importing the admin package:
53+
//
54+
// http.DefaultServeMux.HandleFunc("/debug/pprof/", pprof.Index)
55+
// http.DefaultServeMux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
56+
// http.DefaultServeMux.HandleFunc("/debug/pprof/profile", pprof.Profile)
57+
// http.DefaultServeMux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
58+
// http.DefaultServeMux.HandleFunc("/debug/pprof/trace", pprof.Trace)
59+
//
60+
// Simply importing the net/http/pprof package anonymously will not work.
61+
// More details see: https://git.woa.com/trpc-go/trpc-go/issues/912, and https://github.com/golang/go/issues/42834.
62+
http.DefaultServeMux = http.NewServeMux()
63+
}
64+
4265
// ServiceName is the service name of admin service.
4366
const ServiceName = "admin"
4467

@@ -122,21 +145,6 @@ func (s *Server) configRouter(r *router) *router {
122145
for pattern, handler := range pattern2Handler {
123146
r.add(pattern, handler)
124147
}
125-
126-
// Delete the router registered with http.DefaultServeMux.
127-
// Avoid causing security problems: https://github.com/golang/go/issues/22085.
128-
err := unregisterHandlers(
129-
[]string{
130-
pprofPprof,
131-
pprofCmdline,
132-
pprofProfile,
133-
pprofSymbol,
134-
pprofTrace,
135-
},
136-
)
137-
if err != nil {
138-
log.Errorf("failed to unregister pprof handlers from http.DefaultServeMux, err: %+v", err)
139-
}
140148
return r
141149
}
142150

@@ -173,13 +181,18 @@ func (s *Server) Serve() error {
173181
return err
174182
}
175183

184+
log.Infof("admin service launch success, %s:%s, serving ...", ln.Addr().Network(), ln.Addr().String())
185+
176186
s.server = &http.Server{
177187
Addr: ln.Addr().String(),
178188
ReadTimeout: cfg.readTimeout,
179189
WriteTimeout: cfg.writeTimeout,
180190
Handler: s.router,
181191
}
182-
if err := s.server.Serve(ln); err != nil && err != http.ErrServerClosed {
192+
// Restricted access to the internal/poll.ErrNetClosing type necessitates comparing a string literal.
193+
const closeError = "use of closed network connection"
194+
if err := s.server.Serve(ln); err != nil &&
195+
err != http.ErrServerClosed && !strings.Contains(err.Error(), closeError) {
183196
return err
184197
}
185198
return nil

admin/admin_test.go

Lines changed: 73 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"io"
2222
"net"
2323
"net/http"
24+
"net/http/pprof"
2425
"os"
2526
"reflect"
2627
"strings"
@@ -574,10 +575,10 @@ func TestOptionsConfig(t *testing.T) {
574575

575576
func httpRequest(method string, url string, body string) ([]byte, error) {
576577
request, err := http.NewRequest(method, url, strings.NewReader(body))
577-
request.Header.Set("content-type", "application/x-www-form-urlencoded")
578578
if err != nil {
579579
return nil, err
580580
}
581+
request.Header.Set("content-type", "application/x-www-form-urlencoded")
581582

582583
response, err := http.DefaultClient.Do(request)
583584
if err != nil {
@@ -599,72 +600,84 @@ func panicHandle(w http.ResponseWriter, r *http.Request) {
599600
panic("panic error handle")
600601
}
601602

602-
func TestUnregisterHandlers(t *testing.T) {
603-
_ = newDefaultAdminServer()
604-
mux, err := extractServeMuxData()
605-
require.Nil(t, err)
606-
require.Len(t, mux.m, 0)
607-
require.Len(t, mux.es, 0)
608-
require.False(t, mux.hosts)
609-
610-
http.HandleFunc("/usercmd", userCmd)
611-
http.HandleFunc("/errout", errOutput)
612-
http.HandleFunc("/panicHandle", panicHandle)
613-
http.HandleFunc("www.qq.com/", userCmd)
614-
http.HandleFunc("anything/", userCmd)
603+
func Test_init(t *testing.T) {
604+
t.Run("reset default serve mux to remove pprof registration at admin init func", func(t *testing.T) {
605+
l, err := net.Listen("tcp", "127.0.0.1:0")
606+
require.Nil(t, err)
607+
go func() {
608+
server := &http.Server{
609+
Handler: nil,
610+
ReadTimeout: 15 * time.Second,
611+
WriteTimeout: 15 * time.Second,
612+
IdleTimeout: 60 * time.Second,
613+
}
614+
615+
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
616+
t.Logf("http serving: %v", err)
617+
}
618+
}()
619+
time.Sleep(200 * time.Millisecond)
620+
621+
r, err := http.Get(fmt.Sprintf("http://%s/debug/pprof/", l.Addr().String()))
622+
require.Nil(t, err)
623+
require.Equal(t, http.StatusNotFound, r.StatusCode)
615624

616-
l := mustListenTCP(t)
617-
go func() {
618-
if err := http.Serve(l, nil); err != nil {
619-
t.Log(err)
620-
}
621-
}()
622-
time.Sleep(200 * time.Millisecond)
625+
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/cmdline", l.Addr().String()))
626+
require.Nil(t, err)
627+
require.Equal(t, http.StatusNotFound, r.StatusCode)
623628

624-
mux, err = extractServeMuxData()
625-
require.Nil(t, err)
626-
require.Equal(t, 5, len(mux.m))
627-
require.Equal(t, 2, len(mux.es))
628-
require.Equal(t, true, mux.hosts)
629-
630-
err = unregisterHandlers(
631-
[]string{
632-
"/usercmd",
633-
"/errout",
634-
"/panicHandle",
635-
"www.qq.com/",
636-
"anything/",
637-
},
638-
)
639-
require.Nil(t, err)
629+
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/profile", l.Addr().String()))
630+
require.Nil(t, err)
631+
require.Equal(t, http.StatusNotFound, r.StatusCode)
640632

641-
mux, err = extractServeMuxData()
642-
require.Nil(t, err)
643-
require.Len(t, mux.m, 0)
644-
require.Len(t, mux.es, 0)
645-
require.False(t, mux.hosts)
633+
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/symbol", l.Addr().String()))
634+
require.Nil(t, err)
635+
require.Equal(t, http.StatusNotFound, r.StatusCode)
646636

647-
resp1, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr()))
648-
require.Nil(t, err)
649-
defer resp1.Body.Close()
650-
require.Equal(t, http.StatusNotFound, resp1.StatusCode)
637+
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/trace", l.Addr().String()))
638+
require.Nil(t, err)
639+
require.Equal(t, http.StatusNotFound, r.StatusCode)
640+
})
641+
t.Run("register pprof handler explicitly after importing the admin package", func(t *testing.T) {
642+
http.DefaultServeMux.HandleFunc("/debug/pprof/", pprof.Index)
643+
http.DefaultServeMux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
644+
http.DefaultServeMux.HandleFunc("/debug/pprof/profile", pprof.Profile)
645+
http.DefaultServeMux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
646+
http.DefaultServeMux.HandleFunc("/debug/pprof/trace", pprof.Trace)
647+
t.Cleanup(func() {
648+
http.DefaultServeMux = http.NewServeMux()
649+
})
650+
l, err := net.Listen("tcp", "127.0.0.1:0")
651+
require.Nil(t, err)
652+
go func() {
653+
server := &http.Server{
654+
Handler: nil,
655+
ReadTimeout: 15 * time.Second,
656+
WriteTimeout: 15 * time.Second,
657+
IdleTimeout: 60 * time.Second,
658+
}
659+
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
660+
t.Logf("http serving: %v", err)
661+
}
662+
}()
663+
time.Sleep(200 * time.Millisecond)
664+
665+
r, err := http.Get(fmt.Sprintf("http://%s/debug/pprof/", l.Addr().String()))
666+
require.Nil(t, err)
667+
require.Equal(t, http.StatusOK, r.StatusCode)
651668

652-
http.HandleFunc("/usercmd", userCmd)
653-
http.HandleFunc("/errout", errOutput)
654-
http.HandleFunc("/panicHandle", panicHandle)
669+
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/cmdline", l.Addr().String()))
670+
require.Nil(t, err)
671+
require.Equal(t, http.StatusOK, r.StatusCode)
655672

656-
mux, err = extractServeMuxData()
657-
require.Nil(t, err)
658-
require.Len(t, mux.m, 3)
659-
require.Len(t, mux.es, 0)
660-
require.False(t, mux.hosts)
673+
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/symbol", l.Addr().String()))
674+
require.Nil(t, err)
675+
require.Equal(t, http.StatusOK, r.StatusCode)
661676

662-
resp2, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr()))
663-
require.Nil(t, err)
664-
defer resp2.Body.Close()
665-
respBody, err := io.ReadAll(resp2.Body)
666-
require.Nil(t, err)
667-
require.Equal(t, []byte("usercmd"), respBody)
677+
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/trace", l.Addr().String()))
678+
require.Nil(t, err)
679+
require.Equal(t, http.StatusOK, r.StatusCode)
680+
})
668681
}
669682
func mustListenTCP(t *testing.T) *net.TCPListener {
670683
l, err := net.Listen("tcp", testAddress)

admin/mux.go

Lines changed: 0 additions & 92 deletions
This file was deleted.

0 commit comments

Comments
 (0)