Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat (ldap): add worker pool for LDAP token group lookups #98

Merged
merged 2 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

Canonical reference for changes, improvements, and bugfixes for cap.

## Next

* LDAP
* Add worker pool for LDAP token group lookups ([**PR**](https://github.com/hashicorp/cap/pull/98))

## 0.3.4

### Bug fixes
Expand Down
77 changes: 53 additions & 24 deletions ldap/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net"
"net/url"
"strings"
"sync"
"text/template"
"time"

Expand Down Expand Up @@ -458,35 +459,63 @@ func (c *Client) tokenGroupsSearch(userDN string) ([]*ldap.Entry, []Warning, err

userEntry := result.Entries[0]
groupAttrValues := userEntry.GetRawAttributeValues("tokenGroups")

groupEntries := make([]*ldap.Entry, 0, len(groupAttrValues))
for _, sidBytes := range groupAttrValues {
sidString, err := sidBytesToString(sidBytes)
if err != nil {
warnings = append(warnings, fmtWarning("%s: unable to read sid: %s", op, err.Error()))
continue
}

groupResult, err := c.conn.Search(&ldap.SearchRequest{
BaseDN: fmt.Sprintf("<SID=%s>", sidString),
Scope: ldap.ScopeBaseObject,
DerefAliases: derefAliasMap[c.conf.DerefAliases],
Filter: "(objectClass=*)",
Attributes: []string{
"1.1", // RFC no attributes
},
SizeLimit: 1,
})
if err != nil {
warnings = append(warnings, fmtWarning("%s: unable to read the group sid (baseDN: %q / filter: %q): %s", op, fmt.Sprintf("<SID=%s>", sidString), "(objectClass=*)", sidString))
continue
{
// we're using worker pool to make looking up token groups more
// performant. token groups have to be looked up individually, so if a
// user is a member of MANY groups it can be helpful to do these lookups
// concurrently vs serially. This is based on benchmarks and a
// subsequent implementation within vault's codebase for looking up token
// groups. See: https://github.com/hashicorp/vault/pull/22659
const maxWorkers = 10
var wg sync.WaitGroup
var lock sync.Mutex
taskChan := make(chan string) // intentionally an unbuffered chan so we can iterate (range) over it before it's closed.
for i := 0; i < maxWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for sidString := range taskChan {
groupResult, err := c.conn.Search(&ldap.SearchRequest{
BaseDN: fmt.Sprintf("<SID=%s>", sidString),
Scope: ldap.ScopeBaseObject,
DerefAliases: derefAliasMap[c.conf.DerefAliases],
Filter: "(objectClass=*)",
Attributes: []string{
"1.1", // RFC no attributes
},
SizeLimit: 1,
})
if err != nil {
warnings = append(warnings, fmtWarning("%s: unable to read the group sid (baseDN: %q / filter: %q): %s", op, fmt.Sprintf("<SID=%s>", sidString), "(objectClass=*)", sidString))
continue
}
if len(groupResult.Entries) == 0 {
warnings = append(warnings, fmtWarning("%s: unable to find the group sid (baseDN: %q / filter: %q): %s", op, fmt.Sprintf("<SID=%s>", sidString), "(objectClass=*)", sidString))
continue
}
lock.Lock()
groupEntries = append(groupEntries, groupResult.Entries[0])
lock.Unlock()
}
}()
}
if len(groupResult.Entries) == 0 {
warnings = append(warnings, fmtWarning("%s: unable to find the group sid (baseDN: %q / filter: %q): %s", op, fmt.Sprintf("<SID=%s>", sidString), "(objectClass=*)", sidString))
continue
for _, sidBytes := range groupAttrValues {
sidString, err := sidBytesToString(sidBytes)
if err != nil {
warnings = append(warnings, fmtWarning("%s: unable to read sid: %s", op, err.Error()))
continue
}
taskChan <- sidString
}
// closing the taskChan will allow the workers to start iterating
// (range) - this unblocks them
close(taskChan)

groupEntries = append(groupEntries, groupResult.Entries[0])
// wait for all the workers to finish up the token group lookups and
// adding all the groups to the slice of group entries
wg.Wait()
}

return groupEntries, warnings, nil
Expand Down
Loading