Skip to content

Commit da9fedf

Browse files
authored
Merge pull request #83 from AlexStocks/fix/ws-concurrent-read
fix:add read mutex in gettyWSConn(websocket) struct to prevent data race in ReadMessage()
2 parents 87ba8ee + a48ffa3 commit da9fedf

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

transport/connection.go

+18-6
Original file line numberDiff line numberDiff line change
@@ -493,8 +493,9 @@ func (u *gettyUDPConn) CloseConn(_ int) {
493493

494494
type gettyWSConn struct {
495495
gettyConn
496-
conn *websocket.Conn
497-
lock sync.Mutex
496+
writeLock sync.Mutex
497+
readLock sync.Mutex
498+
conn *websocket.Conn
498499
}
499500

500501
// create websocket connection
@@ -569,7 +570,7 @@ func (w *gettyWSConn) handlePong(string) error {
569570
func (w *gettyWSConn) recv() ([]byte, error) {
570571
// Pls do not set read deadline when using ReadMessage. AlexStocks 20180310
571572
// gorilla/websocket/conn.go:NextReader will always fail when got a timeout error.
572-
_, b, e := w.conn.ReadMessage() // the first return value is message type.
573+
_, b, e := w.threadSafeReadMessage() // the first return value is message type.
573574
if e == nil {
574575
w.readBytes.Add((uint32)(len(b)))
575576
} else {
@@ -643,12 +644,23 @@ func (w *gettyWSConn) CloseConn(waitSec int) {
643644
w.conn.Close()
644645
}
645646

646-
// uses a mutex to ensure that only one thread can send a message at a time, preventing race conditions.
647+
// uses a mutex(writeLock) to ensure that only one thread can send a message at a time, preventing race conditions.
647648
func (w *gettyWSConn) threadSafeWriteMessage(messageType int, data []byte) error {
648-
w.lock.Lock()
649-
defer w.lock.Unlock()
649+
w.writeLock.Lock()
650+
defer w.writeLock.Unlock()
650651
if err := w.conn.WriteMessage(messageType, data); err != nil {
651652
return err
652653
}
653654
return nil
654655
}
656+
657+
// uses a mutex(readLock) to ensure that only one thread can read a message at a time, preventing race conditions.
658+
func (w *gettyWSConn) threadSafeReadMessage() (int, []byte, error) {
659+
w.readLock.Lock()
660+
defer w.readLock.Unlock()
661+
messageType, readBytes, err := w.conn.ReadMessage()
662+
if err != nil {
663+
return messageType, nil, err
664+
}
665+
return messageType, readBytes, nil
666+
}

0 commit comments

Comments
 (0)