@@ -37,8 +37,15 @@ type Server struct {
3737 IdleTimeout time.Duration // connection timeout when no activity, none if empty
3838 MaxTimeout time.Duration // absolute connection timeout, none if empty
3939
40- channelHandlers map [string ]channelHandler
41- requestHandlers map [string ]RequestHandler
40+ // ChannelHandlers allow overriding the built-in session handlers or provide
41+ // extensions to the protocol, such as tcpip forwarding. By default only the
42+ // "session" handler is enabled.
43+ ChannelHandlers map [string ]ChannelHandler
44+
45+ // RequestHandlers allow overriding the server-level request handlers or
46+ // provide extensions to the protocol, such as tcpip forwarding. By default
47+ // no handlers are enabled.
48+ RequestHandlers map [string ]RequestHandler
4249
4350 listenerWg sync.WaitGroup
4451 mu sync.Mutex
@@ -47,12 +54,32 @@ type Server struct {
4754 connWg sync.WaitGroup
4855 doneChan chan struct {}
4956}
57+
5058type RequestHandler interface {
51- HandleRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
59+ HandleSSHRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
60+ }
61+
62+ type RequestHandlerFunc func (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
63+
64+ func (f RequestHandlerFunc ) HandleSSHRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte ) {
65+ return f (ctx , srv , req )
66+ }
67+
68+ var DefaultRequestHandlers = map [string ]RequestHandler {}
69+
70+ type ChannelHandler interface {
71+ HandleSSHChannel (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
72+ }
73+
74+ type ChannelHandlerFunc func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
75+
76+ func (f ChannelHandlerFunc ) HandleSSHChannel (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context ) {
77+ f (srv , conn , newChan , ctx )
5278}
5379
54- // internal for now
55- type channelHandler func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
80+ var DefaultChannelHandlers = map [string ]ChannelHandler {
81+ "session" : ChannelHandlerFunc (DefaultSessionHandler ),
82+ }
5683
5784func (srv * Server ) ensureHostSigner () error {
5885 if len (srv .HostSigners ) == 0 {
@@ -68,13 +95,17 @@ func (srv *Server) ensureHostSigner() error {
6895func (srv * Server ) ensureHandlers () {
6996 srv .mu .Lock ()
7097 defer srv .mu .Unlock ()
71- srv .requestHandlers = map [string ]RequestHandler {
72- "tcpip-forward" : forwardedTCPHandler {},
73- "cancel-tcpip-forward" : forwardedTCPHandler {},
98+ if srv .RequestHandlers == nil {
99+ srv .RequestHandlers = map [string ]RequestHandler {}
100+ for k , v := range DefaultRequestHandlers {
101+ srv .RequestHandlers [k ] = v
102+ }
74103 }
75- srv .channelHandlers = map [string ]channelHandler {
76- "session" : sessionHandler ,
77- "direct-tcpip" : directTcpipHandler ,
104+ if srv .ChannelHandlers == nil {
105+ srv .ChannelHandlers = map [string ]ChannelHandler {}
106+ for k , v := range DefaultChannelHandlers {
107+ srv .ChannelHandlers [k ] = v
108+ }
78109 }
79110}
80111
@@ -186,12 +217,6 @@ func (srv *Server) Serve(l net.Listener) error {
186217 if srv .Handler == nil {
187218 srv .Handler = DefaultHandler
188219 }
189- if srv .channelHandlers == nil {
190- srv .channelHandlers = map [string ]channelHandler {
191- "session" : sessionHandler ,
192- "direct-tcpip" : directTcpipHandler ,
193- }
194- }
195220 var tempDelay time.Duration
196221
197222 srv .trackListener (l , true )
@@ -255,30 +280,32 @@ func (srv *Server) handleConn(newConn net.Conn) {
255280 //go gossh.DiscardRequests(reqs)
256281 go srv .handleRequests (ctx , reqs )
257282 for ch := range chans {
258- handler , found := srv .channelHandlers [ch .ChannelType ()]
259- if ! found {
283+ handler := srv .ChannelHandlers [ch .ChannelType ()]
284+ if handler == nil {
285+ handler = srv .ChannelHandlers ["default" ]
286+ }
287+ if handler == nil {
260288 ch .Reject (gossh .UnknownChannelType , "unsupported channel type" )
261289 continue
262290 }
263- go handler (srv , sshConn , ch , ctx )
291+ go handler . HandleSSHChannel (srv , sshConn , ch , ctx )
264292 }
265293}
266294
267295func (srv * Server ) handleRequests (ctx Context , in <- chan * gossh.Request ) {
268296 for req := range in {
269- handler , found := srv .requestHandlers [req .Type ]
270- if ! found {
271- if req .WantReply {
272- req .Reply (false , nil )
273- }
297+ handler := srv .RequestHandlers [req .Type ]
298+ if handler == nil {
299+ handler = srv .RequestHandlers ["default" ]
300+ }
301+ if handler == nil {
302+ req .Reply (false , nil )
274303 continue
275304 }
276305 /*reqCtx, cancel := context.WithCancel(ctx)
277306 defer cancel() */
278- ret , payload := handler .HandleRequest (ctx , srv , req )
279- if req .WantReply {
280- req .Reply (ret , payload )
281- }
307+ ret , payload := handler .HandleSSHRequest (ctx , srv , req )
308+ req .Reply (ret , payload )
282309 }
283310}
284311
0 commit comments