Skip to content

Commit

Permalink
Add Foreign Key Constraints to Sessions (#4)
Browse files Browse the repository at this point in the history
* Updating required schema

* Updating schema for storage to be more streamlined

* Fixing minor bugs and renaming functions/interfaces

* Making manager more streamlined by only exposing required functions

* Fixing and streamlining bugs around flow verification

* Updating generated code

* Updating generated code
  • Loading branch information
ShivanshVij authored Jan 12, 2023
1 parent 704c144 commit 7223f7c
Show file tree
Hide file tree
Showing 24 changed files with 882 additions and 517 deletions.
30 changes: 15 additions & 15 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,30 @@
package auth

const (
APIKeyPrefixString = "AK-"
ServiceKeyPrefixString = "SK-"
ServiceKeySessionPrefixString = "SS-"
APIKeyPrefixString = "AK-"
ServiceKeyPrefixString = "SK-"
ServiceSessionPrefixString = "SS-"
)

var (
APIKeyPrefix = []byte(APIKeyPrefixString)
ServiceKeySessionPrefix = []byte(ServiceKeySessionPrefixString)
APIKeyPrefix = []byte(APIKeyPrefixString)
ServiceKeyPrefix = []byte(ServiceKeyPrefixString)
ServiceSessionPrefix = []byte(ServiceSessionPrefixString)
)

const (
SessionContextKey = "session"
APIKeyContextKey = "apikey"
ServiceKeySessionContextKey = "service"
UserContextKey = "user"
OrganizationContextKey = "organization"
SessionContextKey = "session"
APIKeyContextKey = "apikey"
ServiceSessionContextKey = "service"
UserContextKey = "user"
OrganizationContextKey = "organization"
KindContextKey = "kind"
)

type Kind string

const (
KindContextKey Kind = "kind"

KindSession Kind = "session"
KindAPIKey Kind = "api"
KindServiceKey Kind = "service"
KindSession Kind = "session"
KindAPIKey Kind = "api"
KindServiceSession Kind = "service"
)
34 changes: 33 additions & 1 deletion pkg/api/v1/docs/api_docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ const docTemplateapi = `{
},
{
"type": "string",
"description": "Device Code Identifier",
"description": "Device Flow Identifier",
"name": "identifier",
"in": "query"
}
Expand Down Expand Up @@ -440,6 +440,12 @@ const docTemplateapi = `{
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/models.ServiceKeyLoginResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"type": "string"
}
Expand Down Expand Up @@ -490,6 +496,32 @@ const docTemplateapi = `{
"type": "string"
}
}
},
"models.ServiceKeyLoginResponse": {
"type": "object",
"properties": {
"organization": {
"type": "string"
},
"resource_id": {
"type": "string"
},
"resource_type": {
"type": "string"
},
"service_key_id": {
"type": "string"
},
"service_session_id": {
"type": "string"
},
"service_session_secret": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
}
}
}`
Expand Down
34 changes: 33 additions & 1 deletion pkg/api/v1/docs/api_swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@
},
{
"type": "string",
"description": "Device Code Identifier",
"description": "Device Flow Identifier",
"name": "identifier",
"in": "query"
}
Expand Down Expand Up @@ -420,6 +420,12 @@
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/models.ServiceKeyLoginResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"type": "string"
}
Expand Down Expand Up @@ -470,6 +476,32 @@
"type": "string"
}
}
},
"models.ServiceKeyLoginResponse": {
"type": "object",
"properties": {
"organization": {
"type": "string"
},
"resource_id": {
"type": "string"
},
"resource_type": {
"type": "string"
},
"service_key_id": {
"type": "string"
},
"service_session_id": {
"type": "string"
},
"service_session_secret": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
}
}
}
23 changes: 22 additions & 1 deletion pkg/api/v1/docs/api_swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ definitions:
user_code:
type: string
type: object
models.ServiceKeyLoginResponse:
properties:
organization:
type: string
resource_id:
type: string
resource_type:
type: string
service_key_id:
type: string
service_session_id:
type: string
service_session_secret:
type: string
user_id:
type: string
type: object
host: localhost:8080
info:
contact:
Expand Down Expand Up @@ -201,7 +218,7 @@ paths:
in: query
name: organization
type: string
- description: Device Code Identifier
- description: Device Flow Identifier
in: query
name: identifier
type: string
Expand Down Expand Up @@ -296,6 +313,10 @@ paths:
responses:
"200":
description: OK
schema:
$ref: '#/definitions/models.ServiceKeyLoginResponse'
"400":
description: Bad Request
schema:
type: string
"401":
Expand Down
21 changes: 19 additions & 2 deletions pkg/api/v1/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (a *Github) App() *fiber.App {
// @Produce json
// @Param next query string false "Next Redirect URL"
// @Param organization query string false "Organization"
// @Param identifier query string false "Device Code Identifier"
// @Param identifier query string false "Device Flow Identifier"
// @Success 307
// @Header 307 {string} Location "Redirects to Github"
// @Failure 401 {string} string
Expand All @@ -74,7 +74,24 @@ func (a *Github) GithubLogin(ctx *fiber.Ctx) error {
return ctx.Status(fiber.StatusUnauthorized).SendString("github provider is not enabled")
}

redirect, err := a.options.Github().StartFlow(ctx.Context(), ctx.Query("next", a.options.NextURL()), ctx.Query("organization"), ctx.Query("identifier"))
identifier := ctx.Query("identifier")
if identifier != "" {
if a.options.Device() == nil {
return ctx.Status(fiber.StatusUnauthorized).SendString("device provider is not enabled")
}

exists, err := a.options.Device().FlowExists(ctx.Context(), identifier)
if err != nil {
a.logger.Error().Err(err).Msg("failed to check if flow exists")
return ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if flow exists")
}

if !exists {
return ctx.Status(fiber.StatusUnauthorized).SendString("invalid device flow identifier")
}
}

redirect, err := a.options.Github().StartFlow(ctx.Context(), ctx.Query("next", a.options.NextURL()), ctx.Query("organization"), identifier)
if err != nil {
a.logger.Error().Err(err).Msg("failed to get redirect")
return ctx.Status(fiber.StatusInternalServerError).SendString("failed to get redirect")
Expand Down
14 changes: 7 additions & 7 deletions pkg/api/v1/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ type DeviceCallbackResponse struct {
}

type ServiceKeyLoginResponse struct {
ServiceKeySessionID string `json:"service_key_session_id"`
ServiceKeySessionSecret string `json:"service_key_session_secret"`
ServiceKeyID string `json:"service_key_id"`
UserID string `json:"user_id"`
Organization string `json:"organization"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id"`
ServiceSessionID string `json:"service_session_id"`
ServiceSessionSecret string `json:"service_session_secret"`
ServiceKeyID string `json:"service_key_id"`
UserID string `json:"user_id"`
Organization string `json:"organization"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id"`
}
23 changes: 10 additions & 13 deletions pkg/api/v1/servicekey/servicekey.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ func (a *ServiceKey) App() *fiber.App {
// @Accept json
// @Produce json
// @Param servicekey query string true "Service Key"
// @Success 200 {string} string
// @Success 200 {object} models.ServiceKeyLoginResponse
// @Failure 400 {string} string
// @Failure 401 {string} string
// @Failure 500 {string} string
// @Router /servicekey/login [post]
Expand All @@ -87,22 +88,18 @@ func (a *ServiceKey) ServiceKeyLogin(ctx *fiber.Ctx) error {
keySecret := []byte(keySplit[1])

a.logger.Debug().Msgf("logging in user with service key ID %s", keyID)
sess, secret, err := a.options.Manager().CreateServiceKeySession(ctx, keyID, keySecret)
sess, secret, err := a.options.Manager().CreateServiceSession(ctx, keyID, keySecret)
if sess == nil || secret == nil {
return err
}

return ctx.JSON(&models.ServiceKeyLoginResponse{
ServiceKeySessionID: sess.ID,
ServiceKeySessionSecret: string(secret),
ServiceKeyID: sess.ServiceKeyID,
UserID: sess.UserID,
Organization: sess.Organization,
ResourceType: sess.ResourceType,
ResourceID: sess.ResourceID,
ServiceSessionID: sess.ID,
ServiceSessionSecret: string(secret),
ServiceKeyID: sess.ServiceKeyID,
UserID: sess.UserID,
Organization: sess.Organization,
ResourceType: sess.ResourceType,
ResourceID: sess.ResourceID,
})
}

func (a *ServiceKey) ServiceKeyLogout(ctx *fiber.Ctx) error {
return nil
}
2 changes: 1 addition & 1 deletion pkg/api/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (v *V1) Logout(ctx *fiber.Ctx) error {
return err
}

err = v.options.Manager().LogoutServiceKeySession(ctx)
err = v.options.Manager().LogoutServiceSession(ctx)
if err != nil {
return err
}
Expand Down
74 changes: 2 additions & 72 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,12 @@ package database

import (
"context"
"github.com/loopholelabs/auth/internal/ent"
"github.com/loopholelabs/auth/internal/ent/deviceflow"
"github.com/loopholelabs/auth/internal/ent/githubflow"
"github.com/loopholelabs/auth/pkg/provider/github"
"github.com/rs/zerolog"
"time"

_ "github.com/lib/pq"
"github.com/loopholelabs/auth/internal/ent"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog"
)

var _ github.Database = (*Database)(nil)

type Database struct {
logger *zerolog.Logger
client *ent.Client
Expand Down Expand Up @@ -76,66 +69,3 @@ func (d *Database) Shutdown() error {
}
return nil
}

func (d *Database) SetGithubFlow(ctx context.Context, state string, verifier string, challenge string, nextURL string, organization string, deviceIdentifier string) error {
d.logger.Debug().Msgf("setting github flow for %s", state)
_, err := d.client.GithubFlow.Create().SetState(state).SetVerifier(verifier).SetChallenge(challenge).SetNextURL(nextURL).SetOrganization(organization).SetDeviceIdentifier(deviceIdentifier).Save(ctx)
return err
}

func (d *Database) GetGithubFlow(ctx context.Context, state string) (*ent.GithubFlow, error) {
d.logger.Debug().Msgf("getting github flow for %s", state)
return d.client.GithubFlow.Query().Where(githubflow.State(state)).Only(ctx)
}

func (d *Database) DeleteGithubFlow(ctx context.Context, state string) error {
d.logger.Debug().Msgf("deleting github flow for %s", state)
_, err := d.client.GithubFlow.Delete().Where(githubflow.State(state)).Exec(ctx)
return err
}

func (d *Database) GCGithubFlow(ctx context.Context, expiry time.Duration) (int, error) {
d.logger.Debug().Msgf("running github flow gc")
return d.client.GithubFlow.Delete().Where(githubflow.CreatedAtLT(time.Now().Add(expiry))).Exec(ctx)
}

func (d *Database) SetDeviceFlow(ctx context.Context, identifier string, deviceCode string, userCode string) error {
d.logger.Debug().Msgf("setting device flow for %s (device code %s, user code %s)", identifier, deviceCode, userCode)
_, err := d.client.DeviceFlow.Create().SetIdentifier(identifier).SetDeviceCode(deviceCode).SetUserCode(userCode).Save(ctx)
return err
}

func (d *Database) GetDeviceFlow(ctx context.Context, deviceCode string) (*ent.DeviceFlow, error) {
d.logger.Debug().Msgf("getting device flow for device code %s", deviceCode)
return d.client.DeviceFlow.Query().Where(deviceflow.DeviceCode(deviceCode)).Only(ctx)
}

func (d *Database) UpdateDeviceFlow(ctx context.Context, identifier string, session string, expiry time.Time) error {
d.logger.Debug().Msgf("updating device flow for %s (expiry %s)", identifier, expiry)
_, err := d.client.DeviceFlow.Update().Where(deviceflow.Identifier(identifier)).SetSession(session).SetExpiresAt(expiry).Save(ctx)
return err
}

func (d *Database) GetDeviceFlowUserCode(ctx context.Context, userCode string) (*ent.DeviceFlow, error) {
d.logger.Debug().Msgf("getting device flow for user code %s", userCode)
flow, err := d.client.DeviceFlow.Query().Where(deviceflow.UserCode(userCode)).Only(ctx)
if err != nil {
return nil, err
}
_, err = flow.Update().SetLastPoll(time.Now()).Save(ctx)
if err != nil {
return nil, err
}
return flow, nil
}

func (d *Database) DeleteDeviceFlow(ctx context.Context, deviceCode string) error {
d.logger.Debug().Msgf("deleting device flow for device code %s", deviceCode)
_, err := d.client.DeviceFlow.Delete().Where(deviceflow.DeviceCode(deviceCode)).Exec(ctx)
return err
}

func (d *Database) GCDeviceFlow(ctx context.Context, expiry time.Duration) (int, error) {
d.logger.Debug().Msgf("running device flow gc")
return d.client.DeviceFlow.Delete().Where(deviceflow.CreatedAtLT(time.Now().Add(expiry))).Exec(ctx)
}
Loading

0 comments on commit 7223f7c

Please sign in to comment.