Skip to content

Commit 3882036

Browse files
committed
Start cleaning up config to fix race conditions
1 parent f79e692 commit 3882036

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

server.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,17 @@ type Server struct {
5858
RequestHandlers map[string]RequestHandler
5959

6060
listenerWg sync.WaitGroup
61-
mu sync.Mutex
61+
mu sync.RWMutex
6262
listeners map[net.Listener]struct{}
6363
conns map[*gossh.ServerConn]struct{}
6464
connWg sync.WaitGroup
6565
doneChan chan struct{}
6666
}
6767

6868
func (srv *Server) ensureHostSigner() error {
69+
srv.mu.Lock()
70+
defer srv.mu.Unlock()
71+
6972
if len(srv.HostSigners) == 0 {
7073
signer, err := generateSigner()
7174
if err != nil {
@@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error {
7982
func (srv *Server) ensureHandlers() {
8083
srv.mu.Lock()
8184
defer srv.mu.Unlock()
85+
8286
if srv.RequestHandlers == nil {
8387
srv.RequestHandlers = map[string]RequestHandler{}
8488
for k, v := range DefaultRequestHandlers {
@@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() {
9498
}
9599

96100
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
101+
srv.mu.RLock()
102+
defer srv.mu.RUnlock()
103+
97104
var config *gossh.ServerConfig
98105
if srv.ServerConfigCallback == nil {
99106
config = &gossh.ServerConfig{}
@@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {
142149

143150
// Handle sets the Handler for the server.
144151
func (srv *Server) Handle(fn Handler) {
152+
srv.mu.Lock()
153+
defer srv.mu.Unlock()
154+
145155
srv.Handler = fn
146156
}
147157

@@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) {
153163
func (srv *Server) Close() error {
154164
srv.mu.Lock()
155165
defer srv.mu.Unlock()
166+
156167
srv.closeDoneChanLocked()
157168
err := srv.closeListenersLocked()
158169
for c := range srv.conns {
@@ -313,6 +324,9 @@ func (srv *Server) ListenAndServe() error {
313324
// with the same algorithm, it is overwritten. Each server config must have at
314325
// least one host key.
315326
func (srv *Server) AddHostKey(key Signer) {
327+
srv.mu.Lock()
328+
defer srv.mu.Unlock()
329+
316330
// these are later added via AddHostKey on ServerConfig, which performs the
317331
// check for one of every algorithm.
318332

@@ -332,12 +346,20 @@ func (srv *Server) AddHostKey(key Signer) {
332346

333347
// SetOption runs a functional option against the server.
334348
func (srv *Server) SetOption(option Option) error {
349+
// NOTE: there is a potential race here for any option that doesn't call an
350+
// internal method. We can't actually lock here because if something calls
351+
// (as an example) AddHostKey, it will deadlock.
352+
353+
//srv.mu.Lock()
354+
//defer srv.mu.Unlock()
355+
335356
return option(srv)
336357
}
337358

338359
func (srv *Server) getDoneChan() <-chan struct{} {
339360
srv.mu.Lock()
340361
defer srv.mu.Unlock()
362+
341363
return srv.getDoneChanLocked()
342364
}
343365

@@ -374,6 +396,7 @@ func (srv *Server) closeListenersLocked() error {
374396
func (srv *Server) trackListener(ln net.Listener, add bool) {
375397
srv.mu.Lock()
376398
defer srv.mu.Unlock()
399+
377400
if srv.listeners == nil {
378401
srv.listeners = make(map[net.Listener]struct{})
379402
}
@@ -394,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
394417
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
395418
srv.mu.Lock()
396419
defer srv.mu.Unlock()
420+
397421
if srv.conns == nil {
398422
srv.conns = make(map[*gossh.ServerConn]struct{})
399423
}

0 commit comments

Comments
 (0)