Skip to content

Commit

Permalink
[+] added transactions for postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
dimacgka committed Dec 5, 2024
1 parent c378583 commit 3ea1d72
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,34 +291,54 @@ func (p *Postgres) runStatement(statement []byte) error {
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
defer cancel()
}

query := string(statement)
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}

tx, err := p.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}

if _, err := tx.ExecContext(ctx, query); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("failed to rollback transaction after error: %v, rollback error: %v", err, rollbackErr)
}
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
return handleError(err, query, statement)
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}

return nil
}

func handleError(err error, query string, statement []byte) error {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}

message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
}

func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
Expand Down

0 comments on commit 3ea1d72

Please sign in to comment.