Skip to content

Commit

Permalink
Use atomic types instead of manual calls
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Feb 12, 2024
1 parent 6893652 commit 2e5dad7
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 27 deletions.
18 changes: 9 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ type Client struct {
socketLock sync.RWMutex
socketWait chan struct{}

isLoggedIn uint32
expectedDisconnectVal uint32
isLoggedIn atomic.Bool
expectedDisconnect atomic.Bool
EnableAutoReconnect bool
LastSuccessfulConnect time.Time
AutoReconnectErrors int
// AutoReconnectHook is called when auto-reconnection fails. If the function returns false,
// the client will not attempt to reconnect. The number of retries can be read from AutoReconnectErrors.
AutoReconnectHook func(error) bool

sendActiveReceipts uint32
sendActiveReceipts atomic.Uint32

// EmitAppStateEventsOnFullSync can be set to true if you want to get app state events emitted
// even when re-syncing the whole state.
Expand All @@ -82,7 +82,7 @@ type Client struct {
appStateSyncLock sync.Mutex

historySyncNotifications chan *waProto.HistorySyncNotification
historySyncHandlerStarted uint32
historySyncHandlerStarted atomic.Bool

uploadPreKeysLock sync.Mutex
lastPreKeyUpload time.Time
Expand Down Expand Up @@ -153,7 +153,7 @@ type Client struct {
phoneLinkingCache *phoneLinkingCache

uniqueID string
idCounter uint32
idCounter atomic.Uint64

proxy socket.Proxy
http *http.Client
Expand Down Expand Up @@ -362,7 +362,7 @@ func (cli *Client) Connect() error {

// IsLoggedIn returns true after the client is successfully connected and authenticated on WhatsApp.
func (cli *Client) IsLoggedIn() bool {
return atomic.LoadUint32(&cli.isLoggedIn) == 1
return cli.isLoggedIn.Load()
}

func (cli *Client) onDisconnect(ns *socket.NoiseSocket, remote bool) {
Expand All @@ -387,15 +387,15 @@ func (cli *Client) onDisconnect(ns *socket.NoiseSocket, remote bool) {
}

func (cli *Client) expectDisconnect() {
atomic.StoreUint32(&cli.expectedDisconnectVal, 1)
cli.expectedDisconnect.Store(true)
}

func (cli *Client) resetExpectedDisconnect() {
atomic.StoreUint32(&cli.expectedDisconnectVal, 0)
cli.expectedDisconnect.Store(false)
}

func (cli *Client) isExpectedDisconnect() bool {
return atomic.LoadUint32(&cli.expectedDisconnectVal) == 1
return cli.expectedDisconnect.Load()
}

func (cli *Client) autoReconnect() {
Expand Down
5 changes: 2 additions & 3 deletions connectionevents.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
package whatsmeow

import (
"sync/atomic"
"time"

waBinary "go.mau.fi/whatsmeow/binary"
Expand All @@ -17,7 +16,7 @@ import (
)

func (cli *Client) handleStreamError(node *waBinary.Node) {
atomic.StoreUint32(&cli.isLoggedIn, 0)
cli.isLoggedIn.Store(false)
cli.clearResponseWaiters(node)
code, _ := node.Attrs["code"].(string)
conflict, _ := node.GetOptionalChildByTag("conflict")
Expand Down Expand Up @@ -148,7 +147,7 @@ func (cli *Client) handleConnectSuccess(node *waBinary.Node) {
cli.Log.Infof("Successfully authenticated")
cli.LastSuccessfulConnect = time.Now()
cli.AutoReconnectErrors = 0
atomic.StoreUint32(&cli.isLoggedIn, 1)
cli.isLoggedIn.Store(true)
go func() {
if dbCount, err := cli.Store.PreKeys.UploadedPreKeyCount(); err != nil {
cli.Log.Errorf("Failed to get number of prekeys in database: %v", err)
Expand Down
7 changes: 3 additions & 4 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"fmt"
"io"
"runtime/debug"
"sync/atomic"
"time"

"go.mau.fi/libsignal/groups"
Expand Down Expand Up @@ -352,15 +351,15 @@ func (cli *Client) handleSenderKeyDistributionMessage(chat, from types.JID, axol

func (cli *Client) handleHistorySyncNotificationLoop() {
defer func() {
atomic.StoreUint32(&cli.historySyncHandlerStarted, 0)
cli.historySyncHandlerStarted.Store(false)
err := recover()
if err != nil {
cli.Log.Errorf("History sync handler panicked: %v\n%s", err, debug.Stack())
}

// Check in case something new appeared in the channel between the loop stopping
// and the atomic variable being updated. If yes, restart the loop.
if len(cli.historySyncNotifications) > 0 && atomic.CompareAndSwapUint32(&cli.historySyncHandlerStarted, 0, 1) {
if len(cli.historySyncNotifications) > 0 && cli.historySyncHandlerStarted.CompareAndSwap(false, true) {
cli.Log.Warnf("New history sync notifications appeared after loop stopped, restarting loop...")
go cli.handleHistorySyncNotificationLoop()
}
Expand Down Expand Up @@ -453,7 +452,7 @@ func (cli *Client) handleProtocolMessage(info *types.MessageInfo, msg *waProto.M

if protoMsg.GetHistorySyncNotification() != nil && info.IsFromMe {
cli.historySyncNotifications <- protoMsg.HistorySyncNotification
if atomic.CompareAndSwapUint32(&cli.historySyncHandlerStarted, 0, 1) {
if cli.historySyncHandlerStarted.CompareAndSwap(false, true) {
go cli.handleHistorySyncNotificationLoop()
}
go cli.sendProtocolMessageReceipt(info.ID, types.ReceiptTypeHistorySync)
Expand Down
5 changes: 2 additions & 3 deletions presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package whatsmeow

import (
"fmt"
"sync/atomic"

waBinary "go.mau.fi/whatsmeow/binary"
"go.mau.fi/whatsmeow/types"
Expand Down Expand Up @@ -66,9 +65,9 @@ func (cli *Client) SendPresence(state types.Presence) error {
return ErrNoPushName
}
if state == types.PresenceAvailable {
atomic.CompareAndSwapUint32(&cli.sendActiveReceipts, 0, 1)
cli.sendActiveReceipts.CompareAndSwap(0, 1)
} else {
atomic.CompareAndSwapUint32(&cli.sendActiveReceipts, 1, 0)
cli.sendActiveReceipts.CompareAndSwap(1, 0)
}
return cli.sendNode(waBinary.Node{
Tag: "presence",
Expand Down
7 changes: 3 additions & 4 deletions receipt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package whatsmeow

import (
"fmt"
"sync/atomic"
"time"

waBinary "go.mau.fi/whatsmeow/binary"
Expand Down Expand Up @@ -190,9 +189,9 @@ func (cli *Client) MarkRead(ids []types.MessageID, timestamp time.Time, chat, se
// receipts will act like the client is offline until SendPresence is called again.
func (cli *Client) SetForceActiveDeliveryReceipts(active bool) {
if active {
atomic.StoreUint32(&cli.sendActiveReceipts, 2)
cli.sendActiveReceipts.Store(2)
} else {
atomic.StoreUint32(&cli.sendActiveReceipts, 0)
cli.sendActiveReceipts.Store(0)
}
}

Expand All @@ -202,7 +201,7 @@ func (cli *Client) sendMessageReceipt(info *types.MessageInfo) {
}
if info.IsFromMe {
attrs["type"] = string(types.ReceiptTypeSender)
} else if atomic.LoadUint32(&cli.sendActiveReceipts) == 0 {
} else if cli.sendActiveReceipts.Load() == 0 {
attrs["type"] = string(types.ReceiptTypeInactive)
}
attrs["to"] = info.Chat
Expand Down
3 changes: 1 addition & 2 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@ import (
"context"
"fmt"
"strconv"
"sync/atomic"
"time"

waBinary "go.mau.fi/whatsmeow/binary"
"go.mau.fi/whatsmeow/types"
)

func (cli *Client) generateRequestID() string {
return cli.uniqueID + strconv.FormatUint(uint64(atomic.AddUint32(&cli.idCounter, 1)), 10)
return cli.uniqueID + strconv.FormatUint(cli.idCounter.Add(1), 10)
}

var xmlStreamEndNode = &waBinary.Node{Tag: "xmlstreamend"}
Expand Down
4 changes: 2 additions & 2 deletions socket/noisesocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type NoiseSocket struct {
writeCounter uint32
readCounter uint32
writeLock sync.Mutex
destroyed uint32
destroyed atomic.Bool
stopConsumer chan struct{}
}

Expand Down Expand Up @@ -75,7 +75,7 @@ func (ns *NoiseSocket) Context() context.Context {
}

func (ns *NoiseSocket) Stop(disconnect bool) {
if atomic.CompareAndSwapUint32(&ns.destroyed, 0, 1) {
if ns.destroyed.CompareAndSwap(false, true) {
close(ns.stopConsumer)
ns.fs.OnDisconnect = nil
if disconnect {
Expand Down

0 comments on commit 2e5dad7

Please sign in to comment.