Skip to content

Commit

Permalink
[fakeintake] make client thread safe (#21507)
Browse files Browse the repository at this point in the history
[fakeintake] make client thread safe
  • Loading branch information
pducolin authored Dec 18, 2023
1 parent 10907e0 commit 7b3653a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 23 deletions.
56 changes: 41 additions & 15 deletions test/fakeintake/aggregator/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"compress/zlib"
"io"
"sort"
"sync"
"time"

"github.com/DataDog/datadog-agent/test/fakeintake/api"
Expand All @@ -29,6 +30,8 @@ type parseFunc[P PayloadItem] func(payload api.Payload) (items []P, err error)
type Aggregator[P PayloadItem] struct {
payloadsByName map[string][]P
parse parseFunc[P]

mutex sync.RWMutex
}

const (
Expand All @@ -43,42 +46,40 @@ func newAggregator[P PayloadItem](parse parseFunc[P]) Aggregator[P] {
return Aggregator[P]{
payloadsByName: map[string][]P{},
parse: parse,
mutex: sync.RWMutex{},
}
}

// UnmarshallPayloads aggregate the payloads
func (agg *Aggregator[P]) UnmarshallPayloads(payloads []api.Payload) error {
// reset map
agg.Reset()
// build map
// build new map
payloadsByName := map[string][]P{}
for _, p := range payloads {
payloads, err := agg.parse(p)
if err != nil {
return err
}

for _, item := range payloads {
if _, found := agg.payloadsByName[item.name()]; !found {
agg.payloadsByName[item.name()] = []P{}
if _, found := payloadsByName[item.name()]; !found {
payloadsByName[item.name()] = []P{}
}
agg.payloadsByName[item.name()] = append(agg.payloadsByName[item.name()], item)
payloadsByName[item.name()] = append(payloadsByName[item.name()], item)
}
}
agg.replace(payloadsByName)

return nil
}

// ContainsPayloadName return true if name match one of the payloads
func (agg *Aggregator[P]) ContainsPayloadName(name string) bool {
_, found := agg.payloadsByName[name]
return found
return len(agg.GetPayloadsByName(name)) != 0
}

// ContainsPayloadNameAndTags return true if the payload name exist and on of the payloads contains all the tags
func (agg *Aggregator[P]) ContainsPayloadNameAndTags(name string, tags []string) bool {
payloads, found := agg.payloadsByName[name]
if !found {
return false
}
payloads := agg.GetPayloadsByName(name)

for _, payloadItem := range payloads {
if AreTagsSubsetOfOtherTags(tags, payloadItem.GetTags()) {
Expand All @@ -91,11 +92,18 @@ func (agg *Aggregator[P]) ContainsPayloadNameAndTags(name string, tags []string)

// GetNames return the names of the payloads
func (agg *Aggregator[P]) GetNames() []string {
names := []string{}
names := agg.getNamesUnsorted()
sort.Strings(names)
return names
}

func (agg *Aggregator[P]) getNamesUnsorted() []string {
agg.mutex.RLock()
defer agg.mutex.RUnlock()
names := make([]string, 0, len(agg.payloadsByName))
for name := range agg.payloadsByName {
names = append(names, name)
}
sort.Strings(names)
return names
}

Expand Down Expand Up @@ -126,14 +134,32 @@ func getReadCloserForEncoding(payload []byte, encoding string) (rc io.ReadCloser

// GetPayloadsByName return the payloads for the resource name
func (agg *Aggregator[P]) GetPayloadsByName(name string) []P {
return agg.payloadsByName[name]
agg.mutex.RLock()
defer agg.mutex.RUnlock()
payloads := agg.payloadsByName[name]
return payloads
}

// Reset the aggregation
func (agg *Aggregator[P]) Reset() {
agg.mutex.Lock()
defer agg.mutex.Unlock()
agg.unsafeReset()
}

func (agg *Aggregator[P]) unsafeReset() {
agg.payloadsByName = map[string][]P{}
}

func (agg *Aggregator[P]) replace(payloadsByName map[string][]P) {
agg.mutex.Lock()
defer agg.mutex.Unlock()
agg.unsafeReset()
for name, payloads := range payloadsByName {
agg.payloadsByName[name] = payloads
}
}

// FilterByTags return the payloads that match all the tags
func FilterByTags[P PayloadItem](payloads []P, tags []string) []P {
ret := []P{}
Expand Down
47 changes: 39 additions & 8 deletions test/fakeintake/aggregator/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package aggregator
import (
"encoding/json"
"runtime"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -66,7 +67,7 @@ func generateTestData() (data []api.Payload, err error) {
}, nil
}

func validateCollectionTime(t *testing.T, agg Aggregator[*mockPayloadItem]) {
func validateCollectionTime(t *testing.T, agg *Aggregator[*mockPayloadItem]) {
if runtime.GOOS != "linux" {
t.Logf("validateCollectionTime test skip on %s", runtime.GOOS)
return
Expand All @@ -80,26 +81,28 @@ func validateCollectionTime(t *testing.T, agg Aggregator[*mockPayloadItem]) {

func TestCommonAggregator(t *testing.T) {
t.Run("ContainsPayloadName", func(t *testing.T) {
agg := newAggregator(parseMockPayloadItem)
assert.False(t, agg.ContainsPayloadName("totoro"))
data, err := generateTestData()
require.NoError(t, err)
agg := newAggregator(parseMockPayloadItem)
err = agg.UnmarshallPayloads(data)
assert.NoError(t, err)
assert.True(t, agg.ContainsPayloadName("totoro"))
assert.False(t, agg.ContainsPayloadName("ponyo"))
validateCollectionTime(t, agg)
validateCollectionTime(t, &agg)
})

t.Run("ContainsPayloadNameAndTags", func(t *testing.T) {
agg := newAggregator(parseMockPayloadItem)
assert.False(t, agg.ContainsPayloadNameAndTags("totoro", []string{"age:123"}))
data, err := generateTestData()
require.NoError(t, err)
agg := newAggregator(parseMockPayloadItem)
err = agg.UnmarshallPayloads(data)
assert.NoError(t, err)
assert.True(t, agg.ContainsPayloadNameAndTags("totoro", []string{"age:123"}))
assert.False(t, agg.ContainsPayloadNameAndTags("porco rosso", []string{"country:it", "role:king"}))
assert.True(t, agg.ContainsPayloadNameAndTags("porco rosso", []string{"country:it", "role:pilot"}))
validateCollectionTime(t, agg)
validateCollectionTime(t, &agg)
})

t.Run("AreTagsSubsetOfOtherTags", func(t *testing.T) {
Expand Down Expand Up @@ -127,11 +130,39 @@ func TestCommonAggregator(t *testing.T) {
})

t.Run("Reset", func(t *testing.T) {
_, err := generateTestData()
data, err := generateTestData()
require.NoError(t, err)
agg := newAggregator(parseMockPayloadItem)
err = agg.UnmarshallPayloads(data)
require.NoError(t, err)
assert.NotEmpty(t, agg.payloadsByName)
agg.Reset()
assert.Equal(t, 0, len(agg.payloadsByName))
validateCollectionTime(t, agg)
assert.Empty(t, agg.payloadsByName)
})

t.Run("Thread safe", func(t *testing.T) {
var wg sync.WaitGroup
data, err := generateTestData()
require.NoError(t, err)
agg := newAggregator(parseMockPayloadItem)
// add some data to ensure we have names
err = agg.UnmarshallPayloads(data)
assert.NoError(t, err)
wg.Add(2)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
err := agg.UnmarshallPayloads(data)
assert.NoError(t, err)
}
}()
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
names := agg.GetNames()
assert.NotEmpty(t, names)
}
}()
wg.Wait()
})
}

0 comments on commit 7b3653a

Please sign in to comment.