Skip to content

Commit

Permalink
chore: switch to typed atomics
Browse files Browse the repository at this point in the history
This just means we can be less concerned about struct field ordering and
other details.
  • Loading branch information
Stebalien committed Nov 11, 2023
1 parent eb03fd4 commit 4714f4b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
5 changes: 2 additions & 3 deletions flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package flow
import (
"math"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -106,7 +105,7 @@ func TestUnregister(t *testing.T) {

mockClock.Add(62 * time.Second)

if atomic.LoadUint64(&m.accumulator) != 0 {
if m.accumulator.Load() != 0 {
t.Error("expected meter to be paused")
}

Expand All @@ -131,7 +130,7 @@ func TestUnregister(t *testing.T) {
if actual.Total != 120 {
t.Errorf("expected total 120, got %d", actual.Total)
}
if atomic.LoadUint64(&m.accumulator) == 0 {
if m.accumulator.Load() == 0 {
t.Error("expected meter to be active")
}

Expand Down
6 changes: 3 additions & 3 deletions meter.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (s Snapshot) String() string {

// Meter is a meter for monitoring a flow.
type Meter struct {
accumulator uint64
accumulator atomic.Uint64

// managed by the sweeper loop.
registered bool
Expand All @@ -42,7 +42,7 @@ type Meter struct {

// Mark updates the total.
func (m *Meter) Mark(count uint64) {
if count > 0 && atomic.AddUint64(&m.accumulator, count) == count {
if count > 0 && m.accumulator.Add(count) == count {
// The accumulator is 0 so we probably need to register. We may
// already _be_ registered however, if we are, the registration
// loop will notice that `m.registered` is set and ignore us.
Expand All @@ -60,7 +60,7 @@ func (m *Meter) Snapshot() Snapshot {
// Reset sets accumulator, total and rate to zero.
func (m *Meter) Reset() {
globalSweeper.snapshotMu.Lock()
atomic.StoreUint64(&m.accumulator, 0)
m.accumulator.Store(0)
m.snapshot.Rate = 0
m.snapshot.Total = 0
globalSweeper.snapshotMu.Unlock()
Expand Down
1 change: 1 addition & 0 deletions meter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package flow
import (
"testing"
"time"
"unsafe"
)

func TestResetMeter(t *testing.T) {
Expand Down
11 changes: 5 additions & 6 deletions sweeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package flow
import (
"math"
"sync"
"sync/atomic"
"time"

"github.com/benbjohnson/clock"
Expand Down Expand Up @@ -100,7 +99,7 @@ func (sw *sweeper) update() {

// Calculate the bandwidth for all active meters.
for i, m := range sw.meters[:sw.activeMeters] {
total := atomic.LoadUint64(&m.accumulator)
total := m.accumulator.Load()
diff := total - m.snapshot.Total
instant := timeMultiplier * float64(diff)

Expand All @@ -124,7 +123,7 @@ func (sw *sweeper) update() {
// Ok, so we are idle...

// Mark this as idle by zeroing the accumulator.
swappedTotal := atomic.SwapUint64(&m.accumulator, 0)
swappedTotal := m.accumulator.Swap(0)

// So..., are we really idle?
if swappedTotal > total {
Expand All @@ -134,7 +133,7 @@ func (sw *sweeper) update() {
// First, add back what we removed. If we can do this
// fast enough, we can put it back before anyone
// notices.
currentTotal := atomic.AddUint64(&m.accumulator, swappedTotal)
currentTotal := m.accumulator.Add(swappedTotal)

// Did we make it?
if currentTotal == swappedTotal {
Expand All @@ -150,7 +149,7 @@ func (sw *sweeper) update() {
// `^uint64(total - 1)` is the two's complement of
// `total`. It's the "correct" way to subtract
// atomically in go.
atomic.AddUint64(&m.accumulator, ^uint64(m.snapshot.Total-1))
m.accumulator.Add(^uint64(m.snapshot.Total - 1))
}

// Reset the rate, keep the total.
Expand All @@ -163,7 +162,7 @@ func (sw *sweeper) update() {
// 1. We don't do this on register to avoid having to take the snapshot lock.
// 2. We skip calculating the bandwidth for this round so we get an _accurate_ bandwidth calculation.
for _, m := range sw.meters[sw.activeMeters:] {
total := atomic.AddUint64(&m.accumulator, m.snapshot.Total)
total := m.accumulator.Add(m.snapshot.Total)
if total > m.snapshot.Total {
m.snapshot.LastUpdate = now
}
Expand Down

0 comments on commit 4714f4b

Please sign in to comment.