Skip to content

Commit

Permalink
exp: rs append directly to the main table
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr committed Dec 23, 2024
1 parent a82aa47 commit 4bc5c4e
Show file tree
Hide file tree
Showing 3 changed files with 427 additions and 362 deletions.
131 changes: 106 additions & 25 deletions warehouse/integrations/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,13 +427,13 @@ func (rs *Redshift) generateManifest(ctx context.Context, tableName string) (str
}

func (rs *Redshift) dropStagingTables(ctx context.Context, stagingTableNames []string) {
for _, stagingTableName := range stagingTableNames {
rs.logger.Infof("WH: dropping table %+v\n", stagingTableName)
_, err := rs.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, rs.Namespace, stagingTableName))
if err != nil {
rs.logger.Errorf("WH: RS: Error dropping staging tables in redshift: %v", err)
}
}
//for _, stagingTableName := range stagingTableNames {
// rs.logger.Infof("WH: dropping table %+v\n", stagingTableName)
// _, err := rs.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, rs.Namespace, stagingTableName))
// if err != nil {
// rs.logger.Errorf("WH: RS: Error dropping staging tables in redshift: %v", err)
// }
//}
}

func (rs *Redshift) loadTable(
Expand Down Expand Up @@ -461,6 +461,30 @@ func (rs *Redshift) loadTable(
}
log.Debugw("generated manifest", "manifestLocation", manifestLocation)

strKeys := warehouseutils.GetColumnsFromTableSchema(tableSchemaInUpload)
sort.Strings(strKeys)

if !rs.ShouldMerge(tableName) {
log.Infow("copying data into main table")

result, err := rs.copyInto(
ctx, rs.DB, tableName,
manifestLocation, strKeys,
)
if err != nil {
return nil, "", fmt.Errorf("loading data into staging table: %w", err)
}

rowsInserted, err := result.RowsAffected()
if err != nil {
return nil, "", fmt.Errorf("getting rows affected: %w", err)
}

return &types.LoadTableStats{
RowsInserted: rowsInserted,
}, "", nil
}

stagingTableName := warehouseutils.StagingTableName(
provider,
tableName,
Expand Down Expand Up @@ -493,9 +517,6 @@ func (rs *Redshift) loadTable(
}
}()

strKeys := warehouseutils.GetColumnsFromTableSchema(tableSchemaInUpload)
sort.Strings(strKeys)

log.Infow("loading data into staging table")
err = rs.copyIntoLoadTable(
ctx, txn, stagingTableName,
Expand All @@ -509,15 +530,14 @@ func (rs *Redshift) loadTable(
rowsDeletedResult, rowsInsertedResult sql.Result
rowsDeleted, rowsInserted int64
)
if rs.ShouldMerge(tableName) {
log.Infow("deleting from load table")
rowsDeletedResult, err = rs.deleteFromLoadTable(
ctx, txn, tableName,
stagingTableName, tableSchemaAfterUpload,
)
if err != nil {
return nil, "", fmt.Errorf("delete from load table: %w", err)
}

log.Infow("deleting from load table")
rowsDeletedResult, err = rs.deleteFromLoadTable(
ctx, txn, tableName,
stagingTableName, tableSchemaAfterUpload,
)
if err != nil {
return nil, "", fmt.Errorf("delete from load table: %w", err)
}

log.Infow("inserting into load table")
Expand Down Expand Up @@ -555,6 +575,67 @@ func (rs *Redshift) loadTable(
}, stagingTableName, nil
}

func (rs *Redshift) copyInto(ctx context.Context, db *sqlmiddleware.DB, stagingTableName, manifestLocation string, strKeys []string) (sql.Result, error) {
tempAccessKeyId, tempSecretAccessKey, token, err := warehouseutils.GetTemporaryS3Cred(&rs.Warehouse.Destination)
if err != nil {
return nil, fmt.Errorf("getting temporary s3 credentials: %w", err)
}

manifestS3Location, region := warehouseutils.GetS3Location(manifestLocation)
if region == "" {
region = "us-east-1"
}

sortedColumnNames := warehouseutils.JoinWithFormatting(strKeys, func(_ int, name string) string {
return fmt.Sprintf(`%q`, name)
}, ",")

var copyStmt string
if rs.Uploader.GetLoadFileType() == warehouseutils.LoadFileTypeParquet {
copyStmt = fmt.Sprintf(
`COPY %s
FROM '%s'
ACCESS_KEY_ID '%s'
SECRET_ACCESS_KEY '%s'
SESSION_TOKEN '%s'
MANIFEST FORMAT PARQUET;`,
fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName),
manifestS3Location,
tempAccessKeyId,
tempSecretAccessKey,
token,
)
} else {
copyStmt = fmt.Sprintf(
`COPY %s(%s)
FROM '%s'
CSV GZIP
ACCESS_KEY_ID '%s'
SECRET_ACCESS_KEY '%s'
SESSION_TOKEN '%s'
REGION '%s'
DATEFORMAT 'auto'
TIMEFORMAT 'auto'
MANIFEST TRUNCATECOLUMNS EMPTYASNULL BLANKSASNULL FILLRECORD ACCEPTANYDATE TRIMBLANKS ACCEPTINVCHARS
COMPUPDATE OFF
STATUPDATE OFF;`,
fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName),
sortedColumnNames,
manifestS3Location,
tempAccessKeyId,
tempSecretAccessKey,
token,
region,
)
}

result, err := db.ExecContext(ctx, copyStmt)
if err != nil {
return nil, fmt.Errorf("running copy command: %w", normalizeError(err))
}
return result, nil
}

func (rs *Redshift) copyIntoLoadTable(
ctx context.Context,
txn *sqlmiddleware.Tx,
Expand Down Expand Up @@ -1288,12 +1369,12 @@ func (rs *Redshift) TestConnection(ctx context.Context, _ model.Warehouse) error

func (rs *Redshift) Cleanup(ctx context.Context) {
if rs.DB != nil {
err := rs.dropDanglingStagingTables(ctx)
if err != nil {
rs.logger.Errorw("Error dropping dangling staging tables",
logfield.Error, err.Error(),
)
}
//err := rs.dropDanglingStagingTables(ctx)
//if err != nil {
// rs.logger.Errorw("Error dropping dangling staging tables",
// logfield.Error, err.Error(),
// )
//}
_ = rs.DB.Close()
}
}
Expand Down
Loading

0 comments on commit 4bc5c4e

Please sign in to comment.