Skip to content

Commit

Permalink
respond to feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
vinay-gopalan committed Nov 2, 2023
1 parent 13778ae commit 86d2276
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 33 deletions.
26 changes: 13 additions & 13 deletions api/applications.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ type ApplicationsClient interface {
RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) error
}

var _ ApplicationsClient = (*AppClient)(nil)
var _ GroupsClient = (*AppClient)(nil)
var _ ServicePrincipalClient = (*AppClient)(nil)
var _ ApplicationsClient = (*MSGraphClient)(nil)
var _ GroupsClient = (*MSGraphClient)(nil)
var _ ServicePrincipalClient = (*MSGraphClient)(nil)

type AppClient struct {
type MSGraphClient struct {
client *msgraphsdkgo.GraphServiceClient
}

Expand All @@ -50,11 +50,11 @@ type PasswordCredential struct {
SecretText string
}

// NewMSGraphApplicationClient returns a new AppClient configured to interact with
// NewMSGraphApplicationClient returns a new MSGraphClient configured to interact with
// the Microsoft Graph API. It can be configured to target alternative national cloud
// deployments via graphURI. For details on the client configuration see
// https://learn.microsoft.com/en-us/graph/sdks/national-clouds
func NewMSGraphApplicationClient(graphURI string, creds azcore.TokenCredential) (*AppClient, error) {
func NewMSGraphApplicationClient(graphURI string, creds azcore.TokenCredential) (*MSGraphClient, error) {
scopes := []string{
fmt.Sprintf("%s/.default", graphURI),
}
Expand All @@ -72,13 +72,13 @@ func NewMSGraphApplicationClient(graphURI string, creds azcore.TokenCredential)
adapter.SetBaseUrl(fmt.Sprintf("%s/v1.0", graphURI))
client := msgraphsdkgo.NewGraphServiceClient(adapter)

ac := &AppClient{
ac := &MSGraphClient{
client: client,
}
return ac, nil
}

func (c *AppClient) GetApplication(ctx context.Context, clientID string) (Application, error) {
func (c *MSGraphClient) GetApplication(ctx context.Context, clientID string) (Application, error) {
filter := fmt.Sprintf("appId eq '%s'", clientID)
req := applications.ApplicationsRequestBuilderGetRequestConfiguration{
QueryParameters: &applications.ApplicationsRequestBuilderGetQueryParameters{
Expand Down Expand Up @@ -113,7 +113,7 @@ func (c *AppClient) GetApplication(ctx context.Context, clientID string) (Applic
return application, nil
}

func (c *AppClient) ListApplications(ctx context.Context, filter string) ([]Application, error) {
func (c *MSGraphClient) ListApplications(ctx context.Context, filter string) ([]Application, error) {

req := &applications.ApplicationsRequestBuilderGetQueryParameters{
Filter: &filter,
Expand Down Expand Up @@ -142,7 +142,7 @@ func (c *AppClient) ListApplications(ctx context.Context, filter string) ([]Appl
}

// CreateApplication create a new Azure application object.
func (c *AppClient) CreateApplication(ctx context.Context, displayName string) (Application, error) {
func (c *MSGraphClient) CreateApplication(ctx context.Context, displayName string) (Application, error) {
requestBody := models.NewApplication()
requestBody.SetDisplayName(&displayName)

Expand All @@ -163,7 +163,7 @@ func (c *AppClient) CreateApplication(ctx context.Context, displayName string) (

// DeleteApplication deletes an Azure application object.
// This will in turn remove the service principal (but not the role assignments).
func (c *AppClient) DeleteApplication(ctx context.Context, applicationObjectID string, permanentlyDelete bool) error {
func (c *MSGraphClient) DeleteApplication(ctx context.Context, applicationObjectID string, permanentlyDelete bool) error {
err := c.client.Applications().ByApplicationId(applicationObjectID).Delete(ctx, nil)

if permanentlyDelete {
Expand All @@ -175,7 +175,7 @@ func (c *AppClient) DeleteApplication(ctx context.Context, applicationObjectID s
return err
}

func (c *AppClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime time.Time) (PasswordCredential, error) {
func (c *MSGraphClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime time.Time) (PasswordCredential, error) {
requestBody := applications.NewItemAddPasswordPostRequestBody()
passwordCredential := models.NewPasswordCredential()
passwordCredential.SetDisplayName(&displayName)
Expand All @@ -195,7 +195,7 @@ func (c *AppClient) AddApplicationPassword(ctx context.Context, applicationObjec
}, nil
}

func (c *AppClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) error {
func (c *MSGraphClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) error {
requestBody := applications.NewItemRemovePasswordPostRequestBody()
kid, err := uuid.Parse(keyID)
if err != nil {
Expand Down
11 changes: 6 additions & 5 deletions api/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package api

import (
"context"
"fmt"

"github.com/microsoftgraph/msgraph-sdk-go/groups"
"github.com/microsoftgraph/msgraph-sdk-go/models"
Expand All @@ -22,19 +23,19 @@ type Group struct {
DisplayName string
}

func (c *AppClient) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error {
func (c *MSGraphClient) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error {
req := models.NewReferenceCreate()
odataId := "https://graph.microsoft.com/v1.0/directoryObjects/{id}"
odataId := fmt.Sprintf("https://graph.microsoft.com/v1.0/directoryObjects/%s", memberObjectID)
req.SetOdataId(&odataId)

return c.client.Groups().ByGroupId(groupObjectID).Members().Ref().Post(ctx, req, nil)
}

func (c *AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error {
func (c *MSGraphClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error {
return c.client.Groups().ByGroupId(groupObjectID).Members().ByDirectoryObjectId(memberObjectID).Ref().Delete(ctx, nil)
}

func (c *AppClient) GetGroup(ctx context.Context, groupID string) (Group, error) {
func (c *MSGraphClient) GetGroup(ctx context.Context, groupID string) (Group, error) {
resp, err := c.client.Groups().ByGroupId(groupID).Get(ctx, nil)
if err != nil {
return Group{}, err
Expand All @@ -46,7 +47,7 @@ func (c *AppClient) GetGroup(ctx context.Context, groupID string) (Group, error)
}, nil
}

func (c *AppClient) ListGroups(ctx context.Context, filter string) ([]Group, error) {
func (c *MSGraphClient) ListGroups(ctx context.Context, filter string) ([]Group, error) {
req := &groups.GroupsRequestBuilderGetQueryParameters{
Filter: &filter,
}
Expand Down
10 changes: 5 additions & 5 deletions api/service_principals.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type ServicePrincipalClient interface {
DeleteServicePrincipal(ctx context.Context, spObjectID string, permanentlyDelete bool) error
}

func (c *AppClient) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (string, string, error) {
func (c *MSGraphClient) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (string, string, error) {
spReq := models.NewServicePrincipal()
spReq.SetAppId(&appID)

Expand All @@ -38,7 +38,7 @@ func (c *AppClient) CreateServicePrincipal(ctx context.Context, appID string, st

passwordReq.SetPasswordCredential(passwordCredential)

password, err := c.client.ServicePrincipals().ByServicePrincipalId(*spID).AddPassword().Post(context.Background(), passwordReq, nil)
password, err := c.client.ServicePrincipals().ByServicePrincipalId(*spID).AddPassword().Post(ctx, passwordReq, nil)

if err != nil {
e := c.DeleteServicePrincipal(ctx, *spID, false)
Expand All @@ -48,7 +48,7 @@ func (c *AppClient) CreateServicePrincipal(ctx context.Context, appID string, st
return *spID, *password.GetSecretText(), nil
}

func (c *AppClient) DeleteServicePrincipal(ctx context.Context, spObjectID string, permanentlyDelete bool) error {
func (c *MSGraphClient) DeleteServicePrincipal(ctx context.Context, spObjectID string, permanentlyDelete bool) error {
err := c.client.ServicePrincipals().ByServicePrincipalId(spObjectID).Delete(ctx, nil)

if permanentlyDelete {
Expand All @@ -60,7 +60,7 @@ func (c *AppClient) DeleteServicePrincipal(ctx context.Context, spObjectID strin
return err
}

func (c *AppClient) ListServicePrincipals(ctx context.Context, spObjectID string) ([]models.ServicePrincipalable, error) {
func (c *MSGraphClient) ListServicePrincipals(ctx context.Context, spObjectID string) ([]models.ServicePrincipalable, error) {
filter := fmt.Sprintf("appId eq '%s'", spObjectID)
requestParameters := &serviceprincipals.ServicePrincipalsRequestBuilderGetQueryParameters{
Filter: &filter,
Expand All @@ -81,6 +81,6 @@ func (c *AppClient) ListServicePrincipals(ctx context.Context, spObjectID string
return spList.GetValue(), nil
}

func (c *AppClient) GetServicePrincipalByID(ctx context.Context, spObjectID string) (models.ServicePrincipalable, error) {
func (c *MSGraphClient) GetServicePrincipalByID(ctx context.Context, spObjectID string) (models.ServicePrincipalable, error) {
return c.client.ServicePrincipals().ByServicePrincipalId(spObjectID).Get(ctx, nil)
}
11 changes: 3 additions & 8 deletions path_service_principal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestSP_WAL_Cleanup(t *testing.T) {
b, s := getTestBackendMocked(t, true)

mp := newMockProvider()
mp.(*mockProvider).ctxTimeout = 5
mp.(*mockProvider).ctxTimeout = 5 * time.Second
b.getProvider = func(s *clientSettings, p api.Passwords) (AzureProvider, error) {
return mp, nil
}
Expand Down Expand Up @@ -1044,16 +1044,11 @@ func assertClientSecret(tb testing.TB, data map[string]interface{}) {
}
}

type servicePrincipalResp struct {
AppID string `json:"appId"`
ID string `json:"id"`
}

func findServicePrincipalID(t *testing.T, client api.ServicePrincipalClient, appID string) (spID string) {
t.Helper()

switch spClient := client.(type) {
case *api.AppClient:
case *api.MSGraphClient:
pathVals := &url.Values{}
pathVals.Set("$filter", fmt.Sprintf("appId eq '%s'", appID))

Expand Down Expand Up @@ -1081,7 +1076,7 @@ func assertServicePrincipalExists(t *testing.T, client api.ServicePrincipalClien
t.Helper()

switch spClient := client.(type) {
case *api.AppClient:
case *api.MSGraphClient:
sp, err := spClient.GetServicePrincipalByID(context.Background(), spID)
assertErrorIsNil(t, err)

Expand Down
4 changes: 2 additions & 2 deletions provider_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type mockProvider struct {
deletedObjects map[string]bool
passwords map[string]string
failNextCreateApplication bool
ctxTimeout int
ctxTimeout time.Duration
lock sync.Mutex
}

Expand Down Expand Up @@ -118,7 +118,7 @@ func (m *mockProvider) CreateServicePrincipal(_ context.Context, _ string, _ tim
func (m *mockProvider) CreateApplication(_ context.Context, _ string) (api.Application, error) {
if m.ctxTimeout != 0 {
// simulate a context deadline error by sleeping for timeout period
time.Sleep(time.Duration(m.ctxTimeout) * time.Second)
time.Sleep(m.ctxTimeout)
}

if m.failNextCreateApplication {
Expand Down

0 comments on commit 86d2276

Please sign in to comment.