From a2c5f6fc25940512be3dda9470418bda315ff50b Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Thu, 2 Jan 2025 16:44:04 +0100 Subject: [PATCH 01/15] fetch clustering and partioning info --- pkg/bigquery/db.go | 63 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index a4cac2fdc..5f927df0e 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -198,7 +198,6 @@ func (d *Client) UpdateTableMetadataIfNotExist(ctx context.Context, asset *pipel if asset.Description == "" && (len(asset.Columns) == 0 || !anyColumnHasDescription) { return NoMetadataUpdatedError{} } - tableComponents := strings.Split(asset.Name, ".") if len(tableComponents) != 2 { return fmt.Errorf("asset name must be in schema.table format to update the metadata, '%s' given", asset.Name) @@ -228,7 +227,6 @@ func (d *Client) UpdateTableMetadataIfNotExist(ctx context.Context, asset *pipel if asset.Description != "" { update.Description = asset.Description } - primaryKeys := asset.ColumnNamesWithPrimaryKey() if len(primaryKeys) > 0 { update.TableConstraints = &bigquery.TableConstraints{ @@ -271,3 +269,64 @@ func (d *Client) Ping(ctx context.Context) error { return nil // Return nil if the query runs successfully } + +func (d *Client) CompareTableClusteringAndPartitioning(ctx context.Context, tableName string, asset *pipeline.Asset) (mismatch bool, err error) { + tableComponents := strings.Split(tableName, ".") + if len(tableComponents) != 2 { + err = fmt.Errorf("table name must be in schema.table format, '%s' given", tableName) + return + } + tableRef := d.client.Dataset(tableComponents[0]).Table(tableComponents[1]) + // Fetch table metadata + meta, err := tableRef.Metadata(ctx) + if err != nil { + err = fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err) + return + } + partitioningMismatch := comparePartitioning(meta, asset) + clusteringMismatch := compareClustering(meta, asset) + 
+ // Return whether there's a mismatch + mismatch = partitioningMismatch || clusteringMismatch + return +} + +func comparePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { + if meta.TimePartitioning != nil { + if meta.TimePartitioning.Field != asset.Materialization.PartitionBy { + fmt.Printf("Mismatch in time partitioning field: BigQuery=%s, User=%s\n", meta.TimePartitioning.Field, asset.Materialization.PartitionBy) + return true + } + } else if asset.Materialization.PartitionBy != "" { + fmt.Printf("User provided time partitioning, but none found in BigQuery.\n") + return true + } + if meta.RangePartitioning != nil { + if meta.RangePartitioning.Field != asset.Materialization.PartitionBy { + fmt.Printf("Mismatch in range partitioning field: BigQuery=%s, User=%s\n", meta.RangePartitioning.Field, asset.Materialization.PartitionBy) + return true + } + } else if asset.Materialization.PartitionBy != "" { + fmt.Printf("User provided range partitioning, but none found in BigQuery.\n") + return true + } + + return false +} + +func compareClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { + bigQueryFields := meta.Clustering.Fields + userFields := asset.Materialization.ClusterBy + if len(bigQueryFields) != len(userFields) { + fmt.Printf("Mismatch in clustering fields length: BigQuery=%v, User=%v\n", bigQueryFields, userFields) + return true + } + for i := range bigQueryFields { + if bigQueryFields[i] != userFields[i] { + fmt.Printf("Mismatch in clustering fields: BigQuery=%v, User=%v\n", bigQueryFields, userFields) + return true + } + } + + return false +} From afaa9ebc5ccdb8b7cfd665117a14ef80fa4c1857 Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 00:40:45 +0100 Subject: [PATCH 02/15] added cluster and partioning handling --- pkg/bigquery/db.go | 83 +++++++++++++++++++++++++----------- pkg/bigquery/operator.go | 12 ++++++ pkg/pipeline/materializer.go | 4 ++ 3 files changed, 74 insertions(+), 25 deletions(-) 
diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index 5f927df0e..0ea8d4ec0 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -31,6 +31,7 @@ type Selector interface { type MetadataUpdater interface { UpdateTableMetadataIfNotExist(ctx context.Context, asset *pipeline.Asset) error + DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) (err error) } type DB interface { @@ -270,7 +271,7 @@ func (d *Client) Ping(ctx context.Context) error { return nil // Return nil if the query runs successfully } -func (d *Client) CompareTableClusteringAndPartitioning(ctx context.Context, tableName string, asset *pipeline.Asset) (mismatch bool, err error) { +func (d *Client) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) (err error) { tableComponents := strings.Split(tableName, ".") if len(tableComponents) != 2 { err = fmt.Errorf("table name must be in schema.table format, '%s' given", tableName) @@ -283,50 +284,82 @@ func (d *Client) CompareTableClusteringAndPartitioning(ctx context.Context, tabl err = fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err) return } - partitioningMismatch := comparePartitioning(meta, asset) - clusteringMismatch := compareClustering(meta, asset) - - // Return whether there's a mismatch - mismatch = partitioningMismatch || clusteringMismatch + // Check if partitioning or clustering exists in metadata + hasPartitioning := meta.TimePartitioning != nil || meta.RangePartitioning != nil + hasClustering := meta.Clustering != nil && len(meta.Clustering.Fields) > 0 + // If neither partitioning nor clustering exists, do nothing + if !hasPartitioning && !hasClustering { + return + } + partitioningMismatch := false + clusteringMismatch := false + if hasPartitioning { + partitioningMismatch = !IsSamePartitioning(meta, asset) + } + if hasClustering { + clusteringMismatch = !IsSameClustering(meta, asset) + } + mismatch := 
partitioningMismatch || clusteringMismatch + if mismatch { + err = tableRef.Delete(ctx) + if err != nil { + err = fmt.Errorf("failed to delete table '%s': %w", tableName, err) + return + } + fmt.Printf("Table '%s' deleted successfully.\n", tableName) + fmt.Printf("Recreating the table with the new clustering and partitioning strategies...\n") + } return } -func comparePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { +func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { if meta.TimePartitioning != nil { if meta.TimePartitioning.Field != asset.Materialization.PartitionBy { - fmt.Printf("Mismatch in time partitioning field: BigQuery=%s, User=%s\n", meta.TimePartitioning.Field, asset.Materialization.PartitionBy) - return true + fmt.Printf( + "Mismatch detected: Your table has a time partitioning strategy with the field '%s', "+ + "but you are attempting to use the field '%s'. Your table will be dropped and recreated.\n", + meta.TimePartitioning.Field, + asset.Materialization.PartitionBy, + ) + return false } - } else if asset.Materialization.PartitionBy != "" { - fmt.Printf("User provided time partitioning, but none found in BigQuery.\n") - return true } if meta.RangePartitioning != nil { if meta.RangePartitioning.Field != asset.Materialization.PartitionBy { - fmt.Printf("Mismatch in range partitioning field: BigQuery=%s, User=%s\n", meta.RangePartitioning.Field, asset.Materialization.PartitionBy) - return true + fmt.Printf( + "Mismatch detected: Your table has a range partitioning strategy with the field '%s', "+ + "but you are attempting to use the field '%s'. 
Your table will be dropped and recreated.\n", meta.RangePartitioning.Field, + asset.Materialization.PartitionBy, + ) + return false } - } else if asset.Materialization.PartitionBy != "" { - fmt.Printf("User provided range partitioning, but none found in BigQuery.\n") - return true } - - return false + return true } -func compareClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { +func IsSameClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { bigQueryFields := meta.Clustering.Fields userFields := asset.Materialization.ClusterBy + if len(bigQueryFields) != len(userFields) { - fmt.Printf("Mismatch in clustering fields length: BigQuery=%v, User=%v\n", bigQueryFields, userFields) - return true + fmt.Printf( + "Mismatch detected: Your table has %d clustering fields (%v), but you are trying to use %d fields (%v). "+ + "Your table will be dropped and recreated.\n", + len(bigQueryFields), bigQueryFields, len(userFields), userFields, + ) + return false } + for i := range bigQueryFields { if bigQueryFields[i] != userFields[i] { - fmt.Printf("Mismatch in clustering fields: BigQuery=%v, User=%v\n", bigQueryFields, userFields) - return true + fmt.Printf( + "Mismatch detected: Your table is clustered by '%s' at position %d, "+ + "but you are trying to cluster by '%s'. 
Your table will be dropped and recreated.\n", + bigQueryFields[i], i+1, userFields[i], + ) + return false } } - return false + return true } diff --git a/pkg/bigquery/operator.go b/pkg/bigquery/operator.go index c3015a5ec..d7c3f567a 100644 --- a/pkg/bigquery/operator.go +++ b/pkg/bigquery/operator.go @@ -14,6 +14,7 @@ import ( type materializer interface { Render(task *pipeline.Asset, query string) (string, error) + IsFullRefresh() bool } type queryExtractor interface { @@ -44,6 +45,8 @@ func (o BasicOperator) Run(ctx context.Context, ti scheduler.TaskInstance) error } func (o BasicOperator) RunTask(ctx context.Context, p *pipeline.Pipeline, t *pipeline.Asset) error { + + // Step 2: Extract queries from the task's executable file queries, err := o.extractor.ExtractQueriesFromString(t.ExecutableFile.Content) if err != nil { return errors.Wrap(err, "cannot extract queries from the task file") @@ -57,6 +60,7 @@ func (o BasicOperator) RunTask(ctx context.Context, p *pipeline.Pipeline, t *pip return errors.New("cannot enable materialization for tasks with multiple queries") } + // Step 3: Render materialized query if needed q := queries[0] materialized, err := o.materializer.Render(t, q.String()) if err != nil { @@ -74,6 +78,14 @@ func (o BasicOperator) RunTask(ctx context.Context, p *pipeline.Pipeline, t *pip if err != nil { return err } + za := t.Materialization.Strategy + print(za) + if o.materializer.IsFullRefresh() { + err = conn.DeleteTableIfPartitioningOrClusteringMismatch(ctx, t.Name, t) + if err != nil { + return errors.Wrap(err, "failed to compare clustering and partitioning metadata") + } + } return conn.RunQueryWithoutResult(ctx, q) } diff --git a/pkg/pipeline/materializer.go b/pkg/pipeline/materializer.go index 4e5be4ad7..295f4632e 100644 --- a/pkg/pipeline/materializer.go +++ b/pkg/pipeline/materializer.go @@ -44,3 +44,7 @@ func removeComments(query string) string { newBytes := re.ReplaceAll(bytes, []byte("")) return string(newBytes) } + +func (m 
*Materializer) IsFullRefresh() bool { + return m.FullRefresh +} From 9ea2898831de9a076bf72f124a8e4579266e1b8a Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 08:09:55 +0100 Subject: [PATCH 03/15] add mockmoterializer function --- pkg/bigquery/operator_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/bigquery/operator_test.go b/pkg/bigquery/operator_test.go index 7b762750a..c7c007f28 100644 --- a/pkg/bigquery/operator_test.go +++ b/pkg/bigquery/operator_test.go @@ -32,6 +32,11 @@ func (m *mockMaterializer) Render(t *pipeline.Asset, query string) (string, erro return res.Get(0).(string), res.Error(1) } +func (m *mockMaterializer) IsFullRefresh() bool { + res := m.Called() + return res.Bool(0) +} + func TestBasicOperator_RunTask(t *testing.T) { t.Parallel() From adf221d8a8c2e208b623f0db17545489f2d4041d Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 08:15:45 +0100 Subject: [PATCH 04/15] add mockquerierwith result --- pkg/bigquery/checks_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pkg/bigquery/checks_test.go b/pkg/bigquery/checks_test.go index 5286fa822..c99439968 100644 --- a/pkg/bigquery/checks_test.go +++ b/pkg/bigquery/checks_test.go @@ -1,6 +1,7 @@ package bigquery import ( + "cloud.google.com/go/bigquery" "context" "fmt" "testing" @@ -58,6 +59,20 @@ func (m *mockQuerierWithResult) SelectWithSchema(ctx context.Context, q *query.Q return result, args.Error(1) } +func (m *mockQuerierWithResult) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error { + args := m.Called(ctx, tableName, asset) + return args.Error(0) +} + +func (m *mockQuerierWithResult) IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { + args := m.Called(meta, asset) + return args.Bool(0) +} + +func (m *mockQuerierWithResult) IsSameClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { + args := m.Called(meta, 
asset) + return args.Bool(0) +} type mockConnectionFetcher struct { mock.Mock From 13810803e049b5553e3b04cb6ae46c5bf4e767da Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 12:21:23 +0100 Subject: [PATCH 05/15] delete unncessary lines --- pkg/bigquery/operator.go | 2 -- pkg/bigquery/operator_test.go | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pkg/bigquery/operator.go b/pkg/bigquery/operator.go index d7c3f567a..d6e97782f 100644 --- a/pkg/bigquery/operator.go +++ b/pkg/bigquery/operator.go @@ -78,8 +78,6 @@ func (o BasicOperator) RunTask(ctx context.Context, p *pipeline.Pipeline, t *pip if err != nil { return err } - za := t.Materialization.Strategy - print(za) if o.materializer.IsFullRefresh() { err = conn.DeleteTableIfPartitioningOrClusteringMismatch(ctx, t.Name, t) if err != nil { diff --git a/pkg/bigquery/operator_test.go b/pkg/bigquery/operator_test.go index c7c007f28..e0d186752 100644 --- a/pkg/bigquery/operator_test.go +++ b/pkg/bigquery/operator_test.go @@ -176,6 +176,7 @@ func TestBasicOperator_RunTask(t *testing.T) { f.q.On("RunQueryWithoutResult", mock.Anything, &query.Query{Query: "CREATE TABLE x AS select * from users"}). 
Return(nil) + f.m.On("IsFullRefresh").Return(false) }, args: args{ t: &pipeline.Asset{ @@ -198,7 +199,6 @@ func TestBasicOperator_RunTask(t *testing.T) { mat := new(mockMaterializer) conn := new(mockConnectionFetcher) conn.On("GetBqConnection", "gcp-default").Return(client, nil) - if tt.setup != nil { tt.setup(&fields{ q: client, From 7720646952600a1762794dab5eb26b2bf6c553c5 Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 12:32:05 +0100 Subject: [PATCH 06/15] fix unit tests --- pkg/bigquery/operator_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/bigquery/operator_test.go b/pkg/bigquery/operator_test.go index e0d186752..a01da450b 100644 --- a/pkg/bigquery/operator_test.go +++ b/pkg/bigquery/operator_test.go @@ -197,6 +197,7 @@ func TestBasicOperator_RunTask(t *testing.T) { client := new(mockQuerierWithResult) extractor := new(mockExtractor) mat := new(mockMaterializer) + mat.On("IsFullRefresh").Return(false) conn := new(mockConnectionFetcher) conn.On("GetBqConnection", "gcp-default").Return(client, nil) if tt.setup != nil { From 7992f2170796ed3ff5f29f826f0330fc04ccb20a Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 13:57:01 +0100 Subject: [PATCH 07/15] more changes --- pkg/bigquery/checks_test.go | 7 ++++--- pkg/bigquery/db.go | 34 ++++++++++++++++++++++------------ pkg/bigquery/operator.go | 4 ---- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/pkg/bigquery/checks_test.go b/pkg/bigquery/checks_test.go index c99439968..dfe530ad2 100644 --- a/pkg/bigquery/checks_test.go +++ b/pkg/bigquery/checks_test.go @@ -1,17 +1,18 @@ package bigquery import ( - "cloud.google.com/go/bigquery" "context" "fmt" "testing" + "cloud.google.com/go/bigquery" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/bruin-data/bruin/pkg/ansisql" "github.com/bruin-data/bruin/pkg/pipeline" "github.com/bruin-data/bruin/pkg/query" "github.com/bruin-data/bruin/pkg/scheduler" - 
"github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) type mockQuerierWithResult struct { diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index 0ea8d4ec0..6be01b7fa 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -31,13 +31,17 @@ type Selector interface { type MetadataUpdater interface { UpdateTableMetadataIfNotExist(ctx context.Context, asset *pipeline.Asset) error - DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) (err error) +} + +type TableManager interface { + DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error } type DB interface { Querier Selector MetadataUpdater + TableManager } type Client struct { @@ -271,45 +275,51 @@ func (d *Client) Ping(ctx context.Context) error { return nil // Return nil if the query runs successfully } -func (d *Client) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) (err error) { +func (d *Client) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error { tableComponents := strings.Split(tableName, ".") if len(tableComponents) != 2 { - err = fmt.Errorf("table name must be in schema.table format, '%s' given", tableName) - return + return fmt.Errorf("table name must be in schema.table format, '%s' given", tableName) } + tableRef := d.client.Dataset(tableComponents[0]).Table(tableComponents[1]) + // Fetch table metadata meta, err := tableRef.Metadata(ctx) if err != nil { - err = fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err) - return + return fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err) } + // Check if partitioning or clustering exists in metadata hasPartitioning := meta.TimePartitioning != nil || meta.RangePartitioning != nil hasClustering := meta.Clustering != nil && len(meta.Clustering.Fields) > 0 + // 
If neither partitioning nor clustering exists, do nothing if !hasPartitioning && !hasClustering { - return + return nil } + partitioningMismatch := false clusteringMismatch := false + if hasPartitioning { partitioningMismatch = !IsSamePartitioning(meta, asset) } + if hasClustering { clusteringMismatch = !IsSameClustering(meta, asset) } + mismatch := partitioningMismatch || clusteringMismatch if mismatch { - err = tableRef.Delete(ctx) - if err != nil { - err = fmt.Errorf("failed to delete table '%s': %w", tableName, err) - return + if err := tableRef.Delete(ctx); err != nil { + return fmt.Errorf("failed to delete table '%s': %w", tableName, err) } + fmt.Printf("Table '%s' deleted successfully.\n", tableName) fmt.Printf("Recreating the table with the new clustering and partitioning strategies...\n") } - return + + return nil } func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { diff --git a/pkg/bigquery/operator.go b/pkg/bigquery/operator.go index d6e97782f..9be8ae59f 100644 --- a/pkg/bigquery/operator.go +++ b/pkg/bigquery/operator.go @@ -45,8 +45,6 @@ func (o BasicOperator) Run(ctx context.Context, ti scheduler.TaskInstance) error } func (o BasicOperator) RunTask(ctx context.Context, p *pipeline.Pipeline, t *pipeline.Asset) error { - - // Step 2: Extract queries from the task's executable file queries, err := o.extractor.ExtractQueriesFromString(t.ExecutableFile.Content) if err != nil { return errors.Wrap(err, "cannot extract queries from the task file") @@ -59,8 +57,6 @@ func (o BasicOperator) RunTask(ctx context.Context, p *pipeline.Pipeline, t *pip if len(queries) > 1 && t.Materialization.Type != pipeline.MaterializationTypeNone { return errors.New("cannot enable materialization for tasks with multiple queries") } - - // Step 3: Render materialized query if needed q := queries[0] materialized, err := o.materializer.Render(t, q.String()) if err != nil { From 0e0d564e09df33df9c05eb45ed87c84e9ab63223 Mon Sep 17 00:00:00 2001 From: 
Baris Terzioglu Date: Fri, 3 Jan 2025 14:09:00 +0100 Subject: [PATCH 08/15] fix lint --- pkg/bigquery/checks_test.go | 5 ++--- pkg/bigquery/db.go | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/bigquery/checks_test.go b/pkg/bigquery/checks_test.go index dfe530ad2..837bc113b 100644 --- a/pkg/bigquery/checks_test.go +++ b/pkg/bigquery/checks_test.go @@ -6,13 +6,12 @@ import ( "testing" "cloud.google.com/go/bigquery" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/bruin-data/bruin/pkg/ansisql" "github.com/bruin-data/bruin/pkg/pipeline" "github.com/bruin-data/bruin/pkg/query" "github.com/bruin-data/bruin/pkg/scheduler" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) type mockQuerierWithResult struct { diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index 6be01b7fa..161a829a2 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -315,7 +315,7 @@ func (d *Client) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Conte return fmt.Errorf("failed to delete table '%s': %w", tableName, err) } - fmt.Printf("Table '%s' deleted successfully.\n", tableName) + fmt.Printf("Table '%s' dropped successfully.\n", tableName) fmt.Printf("Recreating the table with the new clustering and partitioning strategies...\n") } From e519dd65c3074f1526473caece92c75404b02e7d Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 15:32:37 +0100 Subject: [PATCH 09/15] added unit tests --- pkg/bigquery/db.go | 34 +++-- pkg/bigquery/db_test.go | 276 ++++++++++++++++++++++++++-------------- 2 files changed, 204 insertions(+), 106 deletions(-) diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index 161a829a2..9db3998d4 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -190,6 +190,24 @@ func (m NoMetadataUpdatedError) Error() string { return "no metadata found for the given asset to be pushed to BigQuery" } +func (d *Client) getTableRef(tableName string) 
(*bigquery.Table, error) { + tableComponents := strings.Split(tableName, ".") + + // Check for empty components + for _, component := range tableComponents { + if component == "" { + return nil, fmt.Errorf("table name must be in dataset.table or project.dataset.table format, '%s' given", tableName) + } + } + + if len(tableComponents) == 3 { + return d.client.DatasetInProject(tableComponents[0], tableComponents[1]).Table(tableComponents[2]), nil + } else if len(tableComponents) == 2 { + return d.client.Dataset(tableComponents[0]).Table(tableComponents[1]), nil + } + return nil, fmt.Errorf("table name must be in dataset.table or project.dataset.table format, '%s' given", tableName) +} + func (d *Client) UpdateTableMetadataIfNotExist(ctx context.Context, asset *pipeline.Asset) error { anyColumnHasDescription := false colsByName := make(map[string]*pipeline.Column, len(asset.Columns)) @@ -203,12 +221,10 @@ func (d *Client) UpdateTableMetadataIfNotExist(ctx context.Context, asset *pipel if asset.Description == "" && (len(asset.Columns) == 0 || !anyColumnHasDescription) { return NoMetadataUpdatedError{} } - tableComponents := strings.Split(asset.Name, ".") - if len(tableComponents) != 2 { - return fmt.Errorf("asset name must be in schema.table format to update the metadata, '%s' given", asset.Name) + tableRef, err := d.getTableRef(asset.Name) + if err != nil { + return err } - - tableRef := d.client.Dataset(tableComponents[0]).Table(tableComponents[1]) meta, err := tableRef.Metadata(ctx) if err != nil { return err @@ -276,13 +292,11 @@ func (d *Client) Ping(ctx context.Context) error { } func (d *Client) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error { - tableComponents := strings.Split(tableName, ".") - if len(tableComponents) != 2 { - return fmt.Errorf("table name must be in schema.table format, '%s' given", tableName) + tableRef, err := d.getTableRef(tableName) + if err != nil { + return err } - 
tableRef := d.client.Dataset(tableComponents[0]).Table(tableComponents[1]) - // Fetch table metadata meta, err := tableRef.Metadata(ctx) if err != nil { diff --git a/pkg/bigquery/db_test.go b/pkg/bigquery/db_test.go index 438387bcf..5481bc296 100644 --- a/pkg/bigquery/db_test.go +++ b/pkg/bigquery/db_test.go @@ -22,6 +22,10 @@ import ( "google.golang.org/api/option" ) +const ( + testProjectID = "test-project" +) + func TestDB_IsValid(t *testing.T) { t.Parallel() @@ -150,7 +154,7 @@ func TestDB_IsValid(t *testing.T) { func TestDB_RunQueryWithoutResult(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID jobID := "test-job" tests := []struct { @@ -257,58 +261,10 @@ func TestDB_RunQueryWithoutResult(t *testing.T) { } } -type jobSubmitResponse struct { - response any - statusCode int -} - -type queryResultResponse struct { - response *bigquery2.GetQueryResultsResponse - statusCode int -} - -func mockBqHandler(t *testing.T, projectID, jobID string, jsr jobSubmitResponse, qrr queryResultResponse) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries/%s?", projectID, jobID)) { - w.WriteHeader(qrr.statusCode) - - response, err := json.Marshal(qrr.response) - if err != nil { - t.Fatal(err) - } - - _, err = w.Write(response) - if err != nil { - t.Fatal(err) - } - return - } else if r.Method == http.MethodPost && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries", projectID)) { - w.WriteHeader(jsr.statusCode) - - response, err := json.Marshal(jsr.response) - if err != nil { - t.Fatal(err) - } // Updated error handling - - _, err = w.Write(response) - if err != nil { - t.Fatal(err) - } // Updated error handling - return - } - - w.WriteHeader(http.StatusInternalServerError) - _, err := w.Write([]byte("there is no test definition found for the given request: " + r.Method + " " + r.RequestURI)) - 
if err != nil { - t.Fatal(err) - } // Updated error handling - }) -} - func TestDB_Select(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID jobID := "test-job" tests := []struct { @@ -378,50 +334,6 @@ func TestDB_Select(t *testing.T) { JobId: "job-id", }, JobComplete: true, - Schema: &bigquery2.TableSchema{ - Fields: []*bigquery2.TableFieldSchema{ - { - Name: "first_name", - Type: "STRING", - }, - { - Name: "last_name", - Type: "STRING", - }, - { - Name: "age", - Type: "INTEGER", - }, - }, - }, - Rows: []*bigquery2.TableRow{ - { - F: []*bigquery2.TableCell{ - { - V: "jane", - }, - { - V: "doe", - }, - { - V: "30", - }, - }, - }, - { - F: []*bigquery2.TableCell{ - { - V: "joe", - }, - { - V: "doe", - }, - { - V: "28", - }, - }, - }, - }, }, statusCode: http.StatusOK, }, @@ -469,7 +381,7 @@ func TestDB_Select(t *testing.T) { func TestDB_UpdateTableMetadataIfNotExists(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID schema := "myschema" table := "mytable" assetName := fmt.Sprintf("%s.%s", schema, table) @@ -651,7 +563,7 @@ func TestDB_UpdateTableMetadataIfNotExists(t *testing.T) { func TestDB_SelectWithSchema(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID jobID := "test-job" tests := []struct { @@ -810,3 +722,175 @@ func TestDB_SelectWithSchema(t *testing.T) { }) } } + +func TestClient_getTableRef(t *testing.T) { + t.Parallel() + + projectID := testProjectID + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + tests := []struct { + name string + tableName string + wantErr bool + errContains string + }{ + { + name: "valid two-part table name", + tableName: "dataset.table", + wantErr: false, + }, + { + name: "valid three-part table name", + tableName: "project.dataset.table", + wantErr: false, + }, + { + name: "invalid one-part table name", + tableName: 
"table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid four-part table name", + tableName: "a.b.c.d", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "empty table name", + tableName: "", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid trailing dot", + tableName: "dataset.table.", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid leading dot", + tableName: ".dataset.table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid consecutive dots", + tableName: "project..table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "only dots", + tableName: "..", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "three dots", + tableName: "...", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client, err := bigquery.NewClient( + context.Background(), + projectID, + option.WithEndpoint(srv.URL), + option.WithCredentials(&google.Credentials{ + ProjectID: projectID, + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "some-token", + }), + }), + ) + require.NoError(t, err) + + d := Client{client: client} + + tableRef, err := d.getTableRef(tt.tableName) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + assert.Nil(t, tableRef) + return + } + + require.NoError(t, err) + require.NotNil(t, tableRef) + + // For two-part names, verify the table and dataset + if strings.Count(tt.tableName, ".") == 1 { + parts := strings.Split(tt.tableName, 
".") + assert.Equal(t, parts[0], tableRef.DatasetID) + assert.Equal(t, parts[1], tableRef.TableID) + assert.Equal(t, projectID, tableRef.ProjectID) + } + + // For three-part names, verify project, dataset and table + if strings.Count(tt.tableName, ".") == 2 { + parts := strings.Split(tt.tableName, ".") + assert.Equal(t, parts[0], tableRef.ProjectID) + assert.Equal(t, parts[1], tableRef.DatasetID) + assert.Equal(t, parts[2], tableRef.TableID) + } + }) + } +} + +func mockBqHandler(t *testing.T, projectID, jobID string, jsr jobSubmitResponse, qrr queryResultResponse) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries/%s?", projectID, jobID)) { + w.WriteHeader(qrr.statusCode) + + response, err := json.Marshal(qrr.response) + if err != nil { + t.Fatal(err) + } + + _, err = w.Write(response) + if err != nil { + t.Fatal(err) + } + return + } else if r.Method == http.MethodPost && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries", projectID)) { + w.WriteHeader(jsr.statusCode) + + response, err := json.Marshal(jsr.response) + if err != nil { + t.Fatal(err) + } // Updated error handling + + _, err = w.Write(response) + if err != nil { + t.Fatal(err) + } // Updated error handling + return + } + + w.WriteHeader(http.StatusInternalServerError) + _, err := w.Write([]byte("there is no test definition found for the given request: " + r.Method + " " + r.RequestURI)) + if err != nil { + t.Fatal(err) + } // Updated error handling + }) +} + +type jobSubmitResponse struct { + response any + statusCode int +} + +type queryResultResponse struct { + response *bigquery2.GetQueryResultsResponse + statusCode int +} From 4e5aaf329602646a63ebcf5b13291d507c5c629e Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 15:44:46 +0100 Subject: [PATCH 10/15] Reverted pkg/bigquery/db_test.go to its state in commit 
1f648d3b9aeb0fec7ecd95083ff87a99897e6db8 --- pkg/bigquery/db_test.go | 276 ++++++++++++++-------------------------- 1 file changed, 96 insertions(+), 180 deletions(-) diff --git a/pkg/bigquery/db_test.go b/pkg/bigquery/db_test.go index 5481bc296..438387bcf 100644 --- a/pkg/bigquery/db_test.go +++ b/pkg/bigquery/db_test.go @@ -22,10 +22,6 @@ import ( "google.golang.org/api/option" ) -const ( - testProjectID = "test-project" -) - func TestDB_IsValid(t *testing.T) { t.Parallel() @@ -154,7 +150,7 @@ func TestDB_IsValid(t *testing.T) { func TestDB_RunQueryWithoutResult(t *testing.T) { t.Parallel() - projectID := testProjectID + projectID := "test-project" jobID := "test-job" tests := []struct { @@ -261,10 +257,58 @@ func TestDB_RunQueryWithoutResult(t *testing.T) { } } +type jobSubmitResponse struct { + response any + statusCode int +} + +type queryResultResponse struct { + response *bigquery2.GetQueryResultsResponse + statusCode int +} + +func mockBqHandler(t *testing.T, projectID, jobID string, jsr jobSubmitResponse, qrr queryResultResponse) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries/%s?", projectID, jobID)) { + w.WriteHeader(qrr.statusCode) + + response, err := json.Marshal(qrr.response) + if err != nil { + t.Fatal(err) + } + + _, err = w.Write(response) + if err != nil { + t.Fatal(err) + } + return + } else if r.Method == http.MethodPost && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries", projectID)) { + w.WriteHeader(jsr.statusCode) + + response, err := json.Marshal(jsr.response) + if err != nil { + t.Fatal(err) + } // Updated error handling + + _, err = w.Write(response) + if err != nil { + t.Fatal(err) + } // Updated error handling + return + } + + w.WriteHeader(http.StatusInternalServerError) + _, err := w.Write([]byte("there is no test definition found for the given request: " + r.Method + " " + 
r.RequestURI)) + if err != nil { + t.Fatal(err) + } // Updated error handling + }) +} + func TestDB_Select(t *testing.T) { t.Parallel() - projectID := testProjectID + projectID := "test-project" jobID := "test-job" tests := []struct { @@ -334,6 +378,50 @@ func TestDB_Select(t *testing.T) { JobId: "job-id", }, JobComplete: true, + Schema: &bigquery2.TableSchema{ + Fields: []*bigquery2.TableFieldSchema{ + { + Name: "first_name", + Type: "STRING", + }, + { + Name: "last_name", + Type: "STRING", + }, + { + Name: "age", + Type: "INTEGER", + }, + }, + }, + Rows: []*bigquery2.TableRow{ + { + F: []*bigquery2.TableCell{ + { + V: "jane", + }, + { + V: "doe", + }, + { + V: "30", + }, + }, + }, + { + F: []*bigquery2.TableCell{ + { + V: "joe", + }, + { + V: "doe", + }, + { + V: "28", + }, + }, + }, + }, }, statusCode: http.StatusOK, }, @@ -381,7 +469,7 @@ func TestDB_Select(t *testing.T) { func TestDB_UpdateTableMetadataIfNotExists(t *testing.T) { t.Parallel() - projectID := testProjectID + projectID := "test-project" schema := "myschema" table := "mytable" assetName := fmt.Sprintf("%s.%s", schema, table) @@ -563,7 +651,7 @@ func TestDB_UpdateTableMetadataIfNotExists(t *testing.T) { func TestDB_SelectWithSchema(t *testing.T) { t.Parallel() - projectID := testProjectID + projectID := "test-project" jobID := "test-job" tests := []struct { @@ -722,175 +810,3 @@ func TestDB_SelectWithSchema(t *testing.T) { }) } } - -func TestClient_getTableRef(t *testing.T) { - t.Parallel() - - projectID := testProjectID - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer srv.Close() - - tests := []struct { - name string - tableName string - wantErr bool - errContains string - }{ - { - name: "valid two-part table name", - tableName: "dataset.table", - wantErr: false, - }, - { - name: "valid three-part table name", - tableName: "project.dataset.table", - wantErr: false, - }, - { - name: "invalid one-part table 
name", - tableName: "table", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - { - name: "invalid four-part table name", - tableName: "a.b.c.d", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - { - name: "empty table name", - tableName: "", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - { - name: "invalid trailing dot", - tableName: "dataset.table.", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - { - name: "invalid leading dot", - tableName: ".dataset.table", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - { - name: "invalid consecutive dots", - tableName: "project..table", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - { - name: "only dots", - tableName: "..", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - { - name: "three dots", - tableName: "...", - wantErr: true, - errContains: "must be in dataset.table or project.dataset.table format", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - client, err := bigquery.NewClient( - context.Background(), - projectID, - option.WithEndpoint(srv.URL), - option.WithCredentials(&google.Credentials{ - ProjectID: projectID, - TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ - AccessToken: "some-token", - }), - }), - ) - require.NoError(t, err) - - d := Client{client: client} - - tableRef, err := d.getTableRef(tt.tableName) - if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errContains) - assert.Nil(t, tableRef) - return - } - - require.NoError(t, err) - require.NotNil(t, tableRef) - - // For two-part names, verify the table and dataset - if strings.Count(tt.tableName, ".") == 1 { - parts := 
strings.Split(tt.tableName, ".") - assert.Equal(t, parts[0], tableRef.DatasetID) - assert.Equal(t, parts[1], tableRef.TableID) - assert.Equal(t, projectID, tableRef.ProjectID) - } - - // For three-part names, verify project, dataset and table - if strings.Count(tt.tableName, ".") == 2 { - parts := strings.Split(tt.tableName, ".") - assert.Equal(t, parts[0], tableRef.ProjectID) - assert.Equal(t, parts[1], tableRef.DatasetID) - assert.Equal(t, parts[2], tableRef.TableID) - } - }) - } -} - -func mockBqHandler(t *testing.T, projectID, jobID string, jsr jobSubmitResponse, qrr queryResultResponse) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries/%s?", projectID, jobID)) { - w.WriteHeader(qrr.statusCode) - - response, err := json.Marshal(qrr.response) - if err != nil { - t.Fatal(err) - } - - _, err = w.Write(response) - if err != nil { - t.Fatal(err) - } - return - } else if r.Method == http.MethodPost && strings.HasPrefix(r.RequestURI, fmt.Sprintf("/projects/%s/queries", projectID)) { - w.WriteHeader(jsr.statusCode) - - response, err := json.Marshal(jsr.response) - if err != nil { - t.Fatal(err) - } // Updated error handling - - _, err = w.Write(response) - if err != nil { - t.Fatal(err) - } // Updated error handling - return - } - - w.WriteHeader(http.StatusInternalServerError) - _, err := w.Write([]byte("there is no test definition found for the given request: " + r.Method + " " + r.RequestURI)) - if err != nil { - t.Fatal(err) - } // Updated error handling - }) -} - -type jobSubmitResponse struct { - response any - statusCode int -} - -type queryResultResponse struct { - response *bigquery2.GetQueryResultsResponse - statusCode int -} From dc27dabb939425965c91689855173b53e223cebe Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 16:33:06 +0100 Subject: [PATCH 11/15] added unit tests --- pkg/bigquery/db.go | 2 +- 
pkg/bigquery/db_test.go | 430 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 431 insertions(+), 1 deletion(-) diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index 9db3998d4..4d5ec1d61 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -203,7 +203,7 @@ func (d *Client) getTableRef(tableName string) (*bigquery.Table, error) { if len(tableComponents) == 3 { return d.client.DatasetInProject(tableComponents[0], tableComponents[1]).Table(tableComponents[2]), nil } else if len(tableComponents) == 2 { - return d.client.Dataset(tableComponents[0]).Table(tableComponents[1]), nil + return d.client.DatasetInProject(d.config.ProjectID, tableComponents[0]).Table(tableComponents[1]), nil } return nil, fmt.Errorf("table name must be in dataset.table or project.dataset.table format, '%s' given", tableName) } diff --git a/pkg/bigquery/db_test.go b/pkg/bigquery/db_test.go index 438387bcf..36e11358c 100644 --- a/pkg/bigquery/db_test.go +++ b/pkg/bigquery/db_test.go @@ -22,6 +22,8 @@ import ( "google.golang.org/api/option" ) +const testProjectID = "test-project" + func TestDB_IsValid(t *testing.T) { t.Parallel() @@ -810,3 +812,431 @@ func TestDB_SelectWithSchema(t *testing.T) { }) } } + +func TestClient_getTableRef(t *testing.T) { + t.Parallel() + + projectID := testProjectID + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + tests := []struct { + name string + tableName string + wantErr bool + errContains string + }{ + { + name: "valid two-part table name", + tableName: "dataset.table", + wantErr: false, + }, + { + name: "valid three-part table name", + tableName: "project.dataset.table", + wantErr: false, + }, + { + name: "invalid one-part table name", + tableName: "table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid four-part table name", + tableName: "a.b.c.d", + wantErr: true, + 
errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "empty table name", + tableName: "", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid trailing dot", + tableName: "dataset.table.", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid leading dot", + tableName: ".dataset.table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid consecutive dots", + tableName: "project..table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "only dots", + tableName: "..", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "three dots", + tableName: "...", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client, err := bigquery.NewClient( + context.Background(), + projectID, + option.WithEndpoint(srv.URL), + option.WithCredentials(&google.Credentials{ + ProjectID: projectID, + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "some-token", + }), + }), + ) + require.NoError(t, err) + + d := Client{client: client} + + tableRef, err := d.getTableRef(tt.tableName) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + assert.Nil(t, tableRef) + return + } + + require.NoError(t, err) + require.NotNil(t, tableRef) + + // For two-part names, verify the table and dataset + if strings.Count(tt.tableName, ".") == 1 { + parts := strings.Split(tt.tableName, ".") + assert.Equal(t, parts[0], tableRef.DatasetID) + assert.Equal(t, parts[1], tableRef.TableID) + assert.Equal(t, projectID, tableRef.ProjectID) + } + + // For three-part names, verify 
project, dataset and table + if strings.Count(tt.tableName, ".") == 2 { + parts := strings.Split(tt.tableName, ".") + assert.Equal(t, parts[0], tableRef.ProjectID) + assert.Equal(t, parts[1], tableRef.DatasetID) + assert.Equal(t, parts[2], tableRef.TableID) + } + }) + } +} + +func TestClient_getTableRef_TableNameValidation(t *testing.T) { + t.Parallel() + + projectID := "test-project" + client := &Client{ + client: &bigquery.Client{}, + config: &Config{ + ProjectID: projectID, + }, + } + + tests := []struct { + name string + tableName string + wantErr bool + errContains string + }{ + { + name: "valid two-part table name", + tableName: "dataset.table", + wantErr: false, + }, + { + name: "valid three-part table name", + tableName: "project.dataset.table", + wantErr: false, + }, + { + name: "invalid one-part table name", + tableName: "table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid four-part table name", + tableName: "a.b.c.d", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "empty table name", + tableName: "", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid trailing dot", + tableName: "dataset.table.", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid leading dot", + tableName: ".dataset.table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "invalid consecutive dots", + tableName: "project..table", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "only dots", + tableName: "..", + wantErr: true, + errContains: "must be in dataset.table or project.dataset.table format", + }, + { + name: "three dots", + tableName: "...", + wantErr: true, + errContains: "must be in dataset.table 
or project.dataset.table format", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + tableRef, err := client.getTableRef(tt.tableName) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + assert.Nil(t, tableRef) + return + } + + require.NoError(t, err) + require.NotNil(t, tableRef) + + // For two-part names, verify the table and dataset + if strings.Count(tt.tableName, ".") == 1 { + parts := strings.Split(tt.tableName, ".") + assert.Equal(t, parts[0], tableRef.DatasetID) + assert.Equal(t, parts[1], tableRef.TableID) + assert.Equal(t, projectID, tableRef.ProjectID) + } + + // For three-part names, verify project, dataset and table + if strings.Count(tt.tableName, ".") == 2 { + parts := strings.Split(tt.tableName, ".") + assert.Equal(t, parts[0], tableRef.ProjectID) + assert.Equal(t, parts[1], tableRef.DatasetID) + assert.Equal(t, parts[2], tableRef.TableID) + } + }) + } +} + +func TestIsSamePartitioning(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + meta *bigquery.TableMetadata + asset *pipeline.Asset + expected bool + }{ + { + name: "matching time partitioning", + meta: &bigquery.TableMetadata{ + TimePartitioning: &bigquery.TimePartitioning{ + Field: "date_field", + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + PartitionBy: "date_field", + }, + }, + expected: true, + }, + { + name: "mismatched time partitioning", + meta: &bigquery.TableMetadata{ + TimePartitioning: &bigquery.TimePartitioning{ + Field: "date_field", + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + PartitionBy: "other_field", + }, + }, + expected: false, + }, + { + name: "matching range partitioning", + meta: &bigquery.TableMetadata{ + RangePartitioning: &bigquery.RangePartitioning{ + Field: "id_field", + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + PartitionBy: "id_field", + 
}, + }, + expected: true, + }, + { + name: "mismatched range partitioning", + meta: &bigquery.TableMetadata{ + RangePartitioning: &bigquery.RangePartitioning{ + Field: "id_field", + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + PartitionBy: "other_field", + }, + }, + expected: false, + }, + { + name: "no partitioning in metadata", + meta: &bigquery.TableMetadata{}, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + PartitionBy: "some_field", + }, + }, + expected: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := IsSamePartitioning(tt.meta, tt.asset) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsSameClustering(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + meta *bigquery.TableMetadata + asset *pipeline.Asset + expected bool + }{ + { + name: "matching single field clustering", + meta: &bigquery.TableMetadata{ + Clustering: &bigquery.Clustering{ + Fields: []string{"field1"}, + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + ClusterBy: []string{"field1"}, + }, + }, + expected: true, + }, + { + name: "matching multiple fields clustering", + meta: &bigquery.TableMetadata{ + Clustering: &bigquery.Clustering{ + Fields: []string{"field1", "field2", "field3"}, + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + ClusterBy: []string{"field1", "field2", "field3"}, + }, + }, + expected: true, + }, + { + name: "different number of clustering fields", + meta: &bigquery.TableMetadata{ + Clustering: &bigquery.Clustering{ + Fields: []string{"field1", "field2"}, + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + ClusterBy: []string{"field1"}, + }, + }, + expected: false, + }, + { + name: "different field order", + meta: &bigquery.TableMetadata{ + Clustering: &bigquery.Clustering{ + Fields: []string{"field1", "field2"}, + 
}, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + ClusterBy: []string{"field2", "field1"}, + }, + }, + expected: false, + }, + { + name: "different field names", + meta: &bigquery.TableMetadata{ + Clustering: &bigquery.Clustering{ + Fields: []string{"field1", "field2"}, + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + ClusterBy: []string{"field1", "field3"}, + }, + }, + expected: false, + }, + { + name: "empty clustering fields in both", + meta: &bigquery.TableMetadata{ + Clustering: &bigquery.Clustering{ + Fields: []string{}, + }, + }, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + ClusterBy: []string{}, + }, + }, + expected: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := IsSameClustering(tt.meta, tt.asset) + assert.Equal(t, tt.expected, result) + }) + } +} From 813009aa423903110ae268172802bd9e0ed60d60 Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 16:46:12 +0100 Subject: [PATCH 12/15] fix tests --- pkg/bigquery/db_test.go | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/pkg/bigquery/db_test.go b/pkg/bigquery/db_test.go index 36e11358c..225aa82b8 100644 --- a/pkg/bigquery/db_test.go +++ b/pkg/bigquery/db_test.go @@ -123,10 +123,10 @@ func TestDB_IsValid(t *testing.T) { client, err := bigquery.NewClient( context.Background(), - "some-project-id", + testProjectID, option.WithEndpoint(server.URL), option.WithCredentials(&google.Credentials{ - ProjectID: "some-project-id", + ProjectID: testProjectID, TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ AccessToken: "some-token", }), @@ -152,7 +152,7 @@ func TestDB_IsValid(t *testing.T) { func TestDB_RunQueryWithoutResult(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID jobID := "test-job" tests := []struct { @@ -310,7 +310,7 @@ func 
mockBqHandler(t *testing.T, projectID, jobID string, jsr jobSubmitResponse, func TestDB_Select(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID jobID := "test-job" tests := []struct { @@ -471,7 +471,7 @@ func TestDB_Select(t *testing.T) { func TestDB_UpdateTableMetadataIfNotExists(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID schema := "myschema" table := "mytable" assetName := fmt.Sprintf("%s.%s", schema, table) @@ -638,7 +638,12 @@ func TestDB_UpdateTableMetadataIfNotExists(t *testing.T) { require.NoError(t, err) client.Location = "US" - d := Client{client: client} + d := Client{ + client: client, + config: &Config{ + ProjectID: projectID, + }, + } err = d.UpdateTableMetadataIfNotExist(context.Background(), tt.asset) if tt.err == nil { @@ -653,7 +658,7 @@ func TestDB_UpdateTableMetadataIfNotExists(t *testing.T) { func TestDB_SelectWithSchema(t *testing.T) { t.Parallel() - projectID := "test-project" + projectID := testProjectID jobID := "test-job" tests := []struct { @@ -905,7 +910,12 @@ func TestClient_getTableRef(t *testing.T) { ) require.NoError(t, err) - d := Client{client: client} + d := Client{ + client: client, + config: &Config{ + ProjectID: projectID, + }, + } tableRef, err := d.getTableRef(tt.tableName) if tt.wantErr { @@ -1015,7 +1025,6 @@ func TestClient_getTableRef_TableNameValidation(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() @@ -1127,7 +1136,6 @@ func TestIsSamePartitioning(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() result := IsSamePartitioning(tt.meta, tt.asset) @@ -1232,7 +1240,6 @@ func TestIsSameClustering(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() result := IsSameClustering(tt.meta, tt.asset) From 55c89824ed84886e14280002aa8d5ab70c4f35f1 Mon Sep 17 00:00:00 2001 From: Baris 
Terzioglu Date: Fri, 3 Jan 2025 17:21:26 +0100 Subject: [PATCH 13/15] change error messages --- pkg/bigquery/db.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index 4d5ec1d61..d25aed516 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -351,8 +351,8 @@ func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) boo if meta.RangePartitioning != nil { if meta.RangePartitioning.Field != asset.Materialization.PartitionBy { fmt.Printf( - "Mismatch detected: Your table has a range partitioning strategy with the field '%s', "+ - "but you are attempting to use the field '%s'. Your table will be dropped and recreated.\n", meta.RangePartitioning.Field, + "Mismatch detected: Your table has a range partitioning strategy with the field '%s',"+ + "but you are attempting to use the field '%s'.Your table will be dropped and recreated.\n", meta.RangePartitioning.Field, asset.Materialization.PartitionBy, ) return false @@ -367,9 +367,9 @@ func IsSameClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool if len(bigQueryFields) != len(userFields) { fmt.Printf( - "Mismatch detected: Your table has %d clustering fields (%v), but you are trying to use %d fields (%v). "+ + "Mismatch detected: Your table has the clustering fields (%v), but you are trying to use the fields (%v). 
"+ "Your table will be dropped and recreated.\n", - len(bigQueryFields), bigQueryFields, len(userFields), userFields, + bigQueryFields, userFields, ) return false } From 212dbd518e1f80c88d1db185e1d30c2169376759 Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 18:09:45 +0100 Subject: [PATCH 14/15] handle when meta has zero clusters and partitions --- pkg/bigquery/db.go | 76 ++++++++++++++++++++++++----------------- pkg/bigquery/db_test.go | 12 ++++++- 2 files changed, 55 insertions(+), 33 deletions(-) diff --git a/pkg/bigquery/db.go b/pkg/bigquery/db.go index d25aed516..6bf4bdc58 100644 --- a/pkg/bigquery/db.go +++ b/pkg/bigquery/db.go @@ -303,45 +303,43 @@ func (d *Client) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Conte return fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err) } - // Check if partitioning or clustering exists in metadata - hasPartitioning := meta.TimePartitioning != nil || meta.RangePartitioning != nil - hasClustering := meta.Clustering != nil && len(meta.Clustering.Fields) > 0 - - // If neither partitioning nor clustering exists, do nothing - if !hasPartitioning && !hasClustering { - return nil - } - - partitioningMismatch := false - clusteringMismatch := false - - if hasPartitioning { - partitioningMismatch = !IsSamePartitioning(meta, asset) - } - - if hasClustering { - clusteringMismatch = !IsSameClustering(meta, asset) - } - - mismatch := partitioningMismatch || clusteringMismatch - if mismatch { - if err := tableRef.Delete(ctx); err != nil { - return fmt.Errorf("failed to delete table '%s': %w", tableName, err) + // Check if partitioning or clustering exists in metadata or is wanted by asset + if meta.TimePartitioning != nil || meta.RangePartitioning != nil || asset.Materialization.PartitionBy != "" || len(asset.Materialization.ClusterBy) > 0 { + if !IsSamePartitioning(meta, asset) || !IsSameClustering(meta, asset) { + if err := tableRef.Delete(ctx); err != nil { + return 
fmt.Errorf("failed to delete table '%s': %w", tableName, err) + } + fmt.Printf("Your table will be dropped and recreated:\n") + fmt.Printf("Table '%s' dropped successfully.\n", tableName) + fmt.Printf("Recreating the table with the new clustering and partitioning strategies...\n") } - - fmt.Printf("Table '%s' dropped successfully.\n", tableName) - fmt.Printf("Recreating the table with the new clustering and partitioning strategies...\n") } return nil } func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { + // If asset wants partitioning but table has none + if asset.Materialization.PartitionBy != "" && + meta.TimePartitioning == nil && + meta.RangePartitioning == nil { + fmt.Printf( + "Mismatch detected: Your table has no partitioning, but you are attempting to partition by '%s'.\n", + asset.Materialization.PartitionBy, + ) + return false + } + + // Safe to proceed only if table has any partitioning + if meta.TimePartitioning == nil && meta.RangePartitioning == nil { + return true + } + if meta.TimePartitioning != nil { if meta.TimePartitioning.Field != asset.Materialization.PartitionBy { fmt.Printf( "Mismatch detected: Your table has a time partitioning strategy with the field '%s', "+ - "but you are attempting to use the field '%s'. 
Your table will be dropped and recreated.\n", + "but you are attempting to use the field '%s'\n", meta.TimePartitioning.Field, asset.Materialization.PartitionBy, ) @@ -352,7 +350,7 @@ func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) boo if meta.RangePartitioning.Field != asset.Materialization.PartitionBy { fmt.Printf( "Mismatch detected: Your table has a range partitioning strategy with the field '%s',"+ - "but you are attempting to use the field '%s'.Your table will be dropped and recreated.\n", meta.RangePartitioning.Field, + "but you are attempting to use the field '%s'.\n", meta.RangePartitioning.Field, asset.Materialization.PartitionBy, ) return false @@ -362,13 +360,27 @@ func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) boo } func IsSameClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool { + // If asset wants clustering but table has none + if len(asset.Materialization.ClusterBy) > 0 && + (meta.Clustering == nil || len(meta.Clustering.Fields) == 0) { + fmt.Printf( + "Mismatch detected: Your table has no clustering, but you are attempting to cluster by %v.\n", + asset.Materialization.ClusterBy, + ) + return false + } + + // Safe to proceed only if table has clustering + if meta.Clustering == nil { + return true + } + bigQueryFields := meta.Clustering.Fields userFields := asset.Materialization.ClusterBy if len(bigQueryFields) != len(userFields) { fmt.Printf( - "Mismatch detected: Your table has the clustering fields (%v), but you are trying to use the fields (%v). 
"+ - "Your table will be dropped and recreated.\n", + "Mismatch detected: Your table has the clustering fields (%v), but you are trying to use the fields (%v).\n", bigQueryFields, userFields, ) return false @@ -378,7 +390,7 @@ func IsSameClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool if bigQueryFields[i] != userFields[i] { fmt.Printf( "Mismatch detected: Your table is clustered by '%s' at position %d, "+ - "but you are trying to cluster by '%s'. Your table will be dropped and recreated.\n", + "but you are trying to cluster by '%s'.\n", bigQueryFields[i], i+1, userFields[i], ) return false diff --git a/pkg/bigquery/db_test.go b/pkg/bigquery/db_test.go index 225aa82b8..78dfd7054 100644 --- a/pkg/bigquery/db_test.go +++ b/pkg/bigquery/db_test.go @@ -1124,13 +1124,23 @@ func TestIsSamePartitioning(t *testing.T) { expected: false, }, { - name: "no partitioning in metadata", + name: "no partitioning in metadata but asset wants it", meta: &bigquery.TableMetadata{}, asset: &pipeline.Asset{ Materialization: pipeline.Materialization{ PartitionBy: "some_field", }, }, + expected: false, + }, + { + name: "no partitioning in metadata and asset doesn't want it", + meta: &bigquery.TableMetadata{}, + asset: &pipeline.Asset{ + Materialization: pipeline.Materialization{ + PartitionBy: "", + }, + }, expected: true, }, } From 5e202abd9d21c42035d2085db0872bf3e46834fd Mon Sep 17 00:00:00 2001 From: Baris Terzioglu Date: Fri, 3 Jan 2025 18:34:15 +0100 Subject: [PATCH 15/15] add gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2af1a944f..276373e71 100644 --- a/.gitignore +++ b/.gitignore @@ -146,4 +146,5 @@ integration-tests/logs integration-tests/bruin !integration-tests/.bruin.yml -venv \ No newline at end of file +venv +logs/runs \ No newline at end of file