fix(go/adbc/driver/snowflake): workaround snowflake metadata-only limitations (apache#1790)

Workaround to fix apache#1454 until snowflake addresses
snowflakedb/gosnowflake#1110 with a better
solution (hopefully by having the server actually return Arrow...)
zeroshade authored and cocoa-xu committed May 8, 2024
1 parent a0fe48f commit 32a27f1
Showing 2 changed files with 108 additions and 9 deletions.
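Before the per-file diffs, a minimal sketch of the core idea behind the workaround: rows decoded from Snowflake's JSON chunks arrive as [][]*string (nil marking SQL NULL) and are appended into an Arrow RecordBuilder. The string-only schema, the buildRecord helper name, and the Arrow module version (v16) are assumptions for illustration; the driver derives the real schema from Snowflake's row-type metadata and handles many more column types.

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/array"
	"github.com/apache/arrow/go/v16/arrow/memory"
)

// buildRecord is a hypothetical helper: it converts JSON-decoded rows
// (nil == SQL NULL) into a single Arrow record, assuming every column is a
// string. The real driver maps each Snowflake row type to the matching
// Arrow type and builder.
func buildRecord(schema *arrow.Schema, rows [][]*string) arrow.Record {
	bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
	defer bldr.Release()

	for _, row := range rows {
		for i, col := range row {
			fb := bldr.Field(i).(*array.StringBuilder)
			if col == nil {
				fb.AppendNull()
				continue
			}
			fb.Append(*col)
		}
	}
	return bldr.NewRecord()
}

func main() {
	schema := arrow.NewSchema([]arrow.Field{
		{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
	}, nil)

	v := "MY_FUNC"
	rec := buildRecord(schema, [][]*string{{&v}, {nil}})
	defer rec.Release()
	fmt.Println(rec.NumRows(), "rows,", rec.NumCols(), "column")
}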
25 changes: 25 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
@@ -2006,3 +2006,28 @@ func (suite *SnowflakeTests) TestJwtPrivateKey() {
defer os.Remove(binKey)
verifyKey(binKey)
}

func (suite *SnowflakeTests) TestMetadataOnlyQuery() {
// force more than one chunk for `SHOW FUNCTIONS`, which will return
// JSON data instead of Arrow even though we ask for Arrow
suite.Require().NoError(suite.stmt.SetSqlQuery(`ALTER SESSION SET CLIENT_RESULT_CHUNK_SIZE = 50`))
_, err := suite.stmt.ExecuteUpdate(suite.ctx)
suite.Require().NoError(err)

// since we lowered the CLIENT_RESULT_CHUNK_SIZE this will return at least
// 1 chunk in addition to the first one. Metadata queries will return JSON
// no matter what currently.
suite.Require().NoError(suite.stmt.SetSqlQuery(`SHOW FUNCTIONS`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()

recv := int64(0)
for rdr.Next() {
recv += rdr.Record().NumRows()
}

// verify that we got the expected number of rows if we sum up
// all the rows from each record in the stream.
suite.Equal(n, recv)
}
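For context on what the test above exercises, here is a hedged sketch of hitting the same code path from an application through the ADBC Go Snowflake driver. The SNOWFLAKE_URI environment variable and the Arrow module version (v16) are assumptions; error handling is abbreviated.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/apache/arrow-adbc/go/adbc"
	"github.com/apache/arrow-adbc/go/adbc/driver/snowflake"
	"github.com/apache/arrow/go/v16/arrow/memory"
)

func main() {
	ctx := context.Background()
	drv := snowflake.NewDriver(memory.DefaultAllocator)

	// assumed: connection details supplied via a URI in the environment
	db, err := drv.NewDatabase(map[string]string{
		adbc.OptionKeyURI: os.Getenv("SNOWFLAKE_URI"),
	})
	if err != nil {
		panic(err)
	}

	cnxn, err := db.Open(ctx)
	if err != nil {
		panic(err)
	}
	defer cnxn.Close()

	stmt, err := cnxn.NewStatement()
	if err != nil {
		panic(err)
	}
	defer stmt.Close()

	// shrink the chunk size so SHOW FUNCTIONS spans multiple chunks,
	// mirroring the test above
	if err := stmt.SetSqlQuery("ALTER SESSION SET CLIENT_RESULT_CHUNK_SIZE = 50"); err != nil {
		panic(err)
	}
	if _, err := stmt.ExecuteUpdate(ctx); err != nil {
		panic(err)
	}

	// SHOW FUNCTIONS is metadata-only: Snowflake serves it as JSON and the
	// driver converts the JSON chunks into Arrow records
	if err := stmt.SetSqlQuery("SHOW FUNCTIONS"); err != nil {
		panic(err)
	}
	rdr, n, err := stmt.ExecuteQuery(ctx)
	if err != nil {
		panic(err)
	}
	defer rdr.Release()

	var rows int64
	for rdr.Next() {
		rows += rdr.Record().NumRows()
	}
	fmt.Printf("reported %d rows, read %d rows\n", n, rows)
}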
92 changes: 83 additions & 9 deletions go/adbc/driver/snowflake/record_reader.go
@@ -18,9 +18,12 @@
package snowflake

import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math"
"strconv"
"strings"
@@ -300,7 +303,7 @@ func integerToDecimal128(ctx context.Context, a arrow.Array, dt *arrow.Decimal12
return result, err
}

func rowTypesToArrowSchema(ctx context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) {
func rowTypesToArrowSchema(_ context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) {
var loc *time.Location

metadata := ld.RowTypes()
@@ -360,8 +363,7 @@ func extractTimestamp(src *string) (sec, nsec int64, err error) {
return
}

func jsonDataToArrow(ctx context.Context, bldr *array.RecordBuilder, ld gosnowflake.ArrowStreamLoader) (arrow.Record, error) {
rawData := ld.JSONData()
func jsonDataToArrow(_ context.Context, bldr *array.RecordBuilder, rawData [][]*string) (arrow.Record, error) {
fieldBuilders := bldr.Fields()
for _, rec := range rawData {
for i, col := range rec {
@@ -471,7 +473,12 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
return nil, errToAdbcErr(adbc.StatusInternal, err)
}

if len(batches) == 0 {
// if the first chunk was JSON, that means this was a metadata query which
// is only returning JSON data rather than Arrow
rawData := ld.JSONData()
if len(rawData) > 0 {
// construct an Arrow schema based on reading the JSON metadata description of the
// result type schema
schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
if err != nil {
return nil, adbc.Error{
@@ -480,20 +487,87 @@
}
}

if ld.TotalRows() == 0 {
return array.NewRecordReader(schema, []arrow.Record{})
}

bldr := array.NewRecordBuilder(alloc, schema)
defer bldr.Release()

rec, err := jsonDataToArrow(ctx, bldr, ld)
rec, err := jsonDataToArrow(ctx, bldr, rawData)
if err != nil {
return nil, err
}
defer rec.Release()

if ld.TotalRows() != 0 {
return array.NewRecordReader(schema, []arrow.Record{rec})
} else {
return array.NewRecordReader(schema, []arrow.Record{})
results := []arrow.Record{rec}
for _, b := range batches {
rdr, err := b.GetStream(ctx)
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}
defer rdr.Close()

// the "JSON" data returned isn't valid JSON. Instead it is a list of
// comma-delimited JSON lists containing every value as a string, except
// for a JSON null to represent nulls. Thus we can't just use the existing
// JSON parsing code in Arrow.
data, err := io.ReadAll(rdr)
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}

if cap(rawData) >= int(b.NumRows()) {
rawData = rawData[:b.NumRows()]
} else {
rawData = make([][]*string, b.NumRows())
}
bldr.Reserve(int(b.NumRows()))

// we grab the entire JSON message and create a bytes reader
offset, buf := int64(0), bytes.NewReader(data)
for i := 0; i < int(b.NumRows()); i++ {
// we construct a decoder from the bytes.Reader to read the next JSON list
// of columns (one row) from the input
dec := json.NewDecoder(buf)
if err = dec.Decode(&rawData[i]); err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}

// dec.InputOffset() now represents the index of the ',' so we skip the comma
offset += dec.InputOffset() + 1
// then seek the buffer to that spot. we have to seek based on the start
// because json.Decoder can read from the buffer more than is necessary to
// process the JSON data.
if _, err = buf.Seek(offset, 0); err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}
}

// now that we have our [][]*string of JSON data, we can convert it to an
// Arrow record batch and append it to our slice of results
rec, err := jsonDataToArrow(ctx, bldr, rawData)
if err != nil {
return nil, err
}
defer rec.Release()

results = append(results, rec)
}

return array.NewRecordReader(schema, results)
}

ch := make(chan arrow.Record, bufferSize)
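The chunk-parsing loop above depends on json.Decoder decoding exactly one row (one JSON list) at a time and then seeking past the separating comma, because the decoder may buffer more input than it consumes. A standalone sketch of that technique on a fabricated chunk payload (the payload bytes are invented for this sketch):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
)

func main() {
	// the chunk is not one valid JSON document: it is a comma-delimited
	// sequence of JSON lists, one per row, with values as strings and
	// JSON null for SQL NULL
	data := []byte(`["1","a"],["2",null],["3","c"]`)

	var rows [][]*string
	offset, buf := int64(0), bytes.NewReader(data)
	for buf.Len() > 0 {
		// decode the next JSON list (one row) starting at the current offset
		dec := json.NewDecoder(buf)
		var row []*string
		if err := dec.Decode(&row); err != nil {
			panic(err)
		}
		rows = append(rows, row)

		// InputOffset is relative to where this decoder started reading and
		// now points at the trailing ','; skip it, then seek from the start
		// since the decoder may have buffered past what it consumed
		offset += dec.InputOffset() + 1
		if _, err := buf.Seek(offset, io.SeekStart); err != nil {
			panic(err)
		}
	}

	for _, row := range rows {
		for _, v := range row {
			if v == nil {
				fmt.Print("NULL ")
			} else {
				fmt.Print(*v, " ")
			}
		}
		fmt.Println()
	}
}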
