Skip to content

Commit

Permalink
fixed encryption queries (#2204)
Browse files Browse the repository at this point in the history
* fixed encryption queries

* updated logger
  • Loading branch information
mekilis authored Dec 19, 2024
1 parent ec1c3b8 commit 04c956f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 27 deletions.
34 changes: 19 additions & 15 deletions internal/pkg/keys/encrypter_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func InitEncryption(lo log.StdLogger, db database.Database, km KeyManager, encry
}

for column, cipherColumn := range columns {
if err := encryptColumn(ctx, tx, table, column, cipherColumn, encryptionKey); err != nil {
if err := encryptColumn(lo, ctx, tx, table, column, cipherColumn, encryptionKey); err != nil {
rollback(lo, tx)
lo.WithError(err).Error("failed to encrypt column")
return fmt.Errorf("failed to encrypt column %s: %w", columns, err)
Expand Down Expand Up @@ -111,46 +111,50 @@ func lockTable(ctx context.Context, tx *sqlx.Tx, table string, timeout int) erro
}

// encryptColumn encrypts the specified column in the table.
func encryptColumn(ctx context.Context, tx *sqlx.Tx, table, column, cipherColumn, encryptionKey string) error {
func encryptColumn(lo log.StdLogger, ctx context.Context, tx *sqlx.Tx, table, column, cipherColumn, encryptionKey string) error {
// Encrypt the column data and store it in the _cipher column
columnZero, err := getColumnZero(lo, ctx, tx, table, column)
if err != nil {
return err
}
encryptQuery := fmt.Sprintf(
"UPDATE convoy.%s SET %s = pgp_sym_encrypt(%s::text, $1), %s = %s WHERE %s IS NOT NULL;",
table, cipherColumn, column, column, getColumnZero(ctx, tx, table, column), column,
table, cipherColumn, column, column, columnZero, column,
)
_, err := tx.ExecContext(ctx, encryptQuery, encryptionKey)
_, err = tx.ExecContext(ctx, encryptQuery, encryptionKey)
if err != nil {
return fmt.Errorf("failed to encrypt column %s in table %s: %w", column, table, err)
}

return nil
}

func getColumnZero(ctx context.Context, tx *sqlx.Tx, table, column string) string {
query := `SELECT is_nullable, data_type FROM convoy.information_schema.columns WHERE table_name = $1 AND column_name = $2;`
func getColumnZero(lo log.StdLogger, ctx context.Context, tx *sqlx.Tx, table, column string) (string, error) {
query := `SELECT is_nullable, data_type FROM information_schema.columns WHERE table_name = $1 AND column_name = $2;`
var isNullable, columnType string
err := tx.QueryRowContext(ctx, query, table, column).Scan(&isNullable, &columnType)
if err != nil {
log.Infof("Failed to fetch column info for %s.%s: %v", table, column, err)
return NULL
lo.Errorf("Failed to fetch column info for %s.%s: %v", table, column, err)
return NULL, err
}

if isNullable == "NO" {
switch {
case strings.Contains(columnType, "json"):
return "'[]'::jsonb"
return "'[]'::jsonb", nil
case strings.Contains(columnType, "text") || strings.Contains(columnType, "char"):
return "''"
return "''", nil
case strings.Contains(columnType, "int") || strings.Contains(columnType, "numeric"):
return "0"
return "0", nil
case strings.Contains(columnType, "bool"):
return "FALSE"
return "FALSE", nil
default:
log.Warnf("Unknown type %s for %s.%s, defaulting to NULL", columnType, table, column)
return NULL
lo.Warnf("Unknown type %s for %s.%s, defaulting to NULL", columnType, table, column)
return NULL, nil
}
}

return NULL
return NULL, nil
}

// markTableEncrypted sets the `is_encrypted` column to true.
Expand Down
22 changes: 13 additions & 9 deletions internal/pkg/keys/encrypter_revert.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func RevertEncryption(lo log.StdLogger, db database.Database, encryptionKey stri
}

for column, cipherColumn := range columns {
if err := decryptAndRestoreColumn(ctx, tx, table, column, cipherColumn, encryptionKey); err != nil {
if err := decryptAndRestoreColumn(lo, ctx, tx, table, column, cipherColumn, encryptionKey); err != nil {
rollback(lo, tx)
return err
}
Expand All @@ -61,13 +61,17 @@ func RevertEncryption(lo log.StdLogger, db database.Database, encryptionKey stri
}

// decryptAndRestoreColumn decrypts the cipher column and restores the data to the plain column.
func decryptAndRestoreColumn(ctx context.Context, tx *sqlx.Tx, table, column, cipherColumn, encryptionKey string) error {
func decryptAndRestoreColumn(lo log.StdLogger, ctx context.Context, tx *sqlx.Tx, table, column, cipherColumn, encryptionKey string) error {
// Decrypt the cipher column and update the plain column, casting as needed
columnType, err := getColumnType(lo, ctx, tx, table, column)
if err != nil {
return err
}
revertQuery := fmt.Sprintf(
"UPDATE convoy.%s SET %s = pgp_sym_decrypt(%s::bytea, $1)::%s WHERE %s IS NOT NULL;",
table, column, cipherColumn, getColumnType(ctx, tx, table, column), cipherColumn,
table, column, cipherColumn, columnType, cipherColumn,
)
_, err := tx.ExecContext(ctx, revertQuery, encryptionKey)
_, err = tx.ExecContext(ctx, revertQuery, encryptionKey)
if err != nil {
return fmt.Errorf("failed to decrypt column %s in table %s: %w", cipherColumn, table, err)
}
Expand All @@ -85,15 +89,15 @@ func decryptAndRestoreColumn(ctx context.Context, tx *sqlx.Tx, table, column, ci
return nil
}

func getColumnType(ctx context.Context, tx *sqlx.Tx, table, column string) string {
query := `SELECT data_type FROM convoy.information_schema.columns WHERE table_name = $1 AND column_name = $2;`
func getColumnType(lo log.StdLogger, ctx context.Context, tx *sqlx.Tx, table, column string) (string, error) {
query := `SELECT data_type FROM information_schema.columns WHERE table_name = $1 AND column_name = $2;`
var columnType string
err := tx.GetContext(ctx, &columnType, query, table, column)
if err != nil {
log.Infof("Failed to fetch column type for %s.%s: %v", table, column, err)
return ""
lo.Errorf("Failed to fetch column type for %s.%s: %v", table, column, err)
return "", err
}
return columnType
return columnType, nil
}

// markTableDecrypted sets the `is_encrypted` column to false.
Expand Down
6 changes: 3 additions & 3 deletions internal/pkg/keys/encrypter_rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ func RotateEncryptionKey(lo log.StdLogger, db database.Database, km KeyManager,
err = lockTable(ctx, tx, table, timeout)
if err != nil {
rollback(lo, tx)
log.WithError(err).Error("failed to lock table")
lo.WithError(err).Error("failed to lock table")
return err
}

isEncrypted, err := checkEncryptionStatus(ctx, tx, table)
if err != nil {
rollback(lo, tx)
log.WithError(err).Error("failed to check encryption status")
lo.WithError(err).Error("failed to check encryption status")
return err
}

Expand All @@ -46,7 +46,7 @@ func RotateEncryptionKey(lo log.StdLogger, db database.Database, km KeyManager,
err = reEncryptColumn(ctx, tx, table, cipherColumn, oldKey, newKey)
if err != nil {
rollback(lo, tx)
log.WithError(err).Error("failed to re-encrypt column")
lo.WithError(err).Error("failed to re-encrypt column")
return err
}
}
Expand Down

0 comments on commit 04c956f

Please sign in to comment.