Skip to content

Commit

Permalink
FORK: fix bucket root query strings (upstream kubeflow#10319)
Browse files Browse the repository at this point in the history
Signed-off-by: Mathew Wicks <5735406+thesuperzapper@users.noreply.github.com>
  • Loading branch information
thesuperzapper committed Apr 14, 2024
1 parent 3963f34 commit 3d95f8d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 11 deletions.
12 changes: 3 additions & 9 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ import (
"context"
"encoding/json"
"fmt"
"path"
"strconv"
"strings"
"time"

"github.com/golang/glog"
Expand Down Expand Up @@ -1062,7 +1060,9 @@ func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.C
outputs.Artifacts[name] = &pipelinespec.ArtifactList{
Artifacts: []*pipelinespec.RuntimeArtifact{
{
Uri: generateOutputURI(pipelineRoot, name, taskName),
// Do not preserve the query string for output artifacts, as otherwise
// they'd appear in file and artifact names.
Uri: metadata.GenerateOutputURI(pipelineRoot, []string{taskName, name}, false),
Type: artifact.GetArtifactType(),
Metadata: artifact.GetMetadata(),
},
Expand All @@ -1078,12 +1078,6 @@ func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.C
return outputs
}

func generateOutputURI(root, artifactName string, taskName string) string {
// we cannot path.Join(root, taskName, artifactName), because root
// contains scheme like gs:// and path.Join cleans up scheme to gs:/
return fmt.Sprintf("%s/%s", strings.TrimRight(root, "/"), path.Join(taskName, artifactName))
}

var accessModeMap = map[string]k8score.PersistentVolumeAccessMode{
"ReadWriteOnce": k8score.ReadWriteOnce,
"ReadOnlyMany": k8score.ReadOnlyMany,
Expand Down
22 changes: 21 additions & 1 deletion backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,26 @@ func (e *Execution) FingerPrint() string {
return e.execution.GetCustomProperties()[keyCacheFingerPrint].GetStringValue()
}

// GenerateOutputURI appends the specified paths to the pipeline root.
// It may be configured to preserve the query part of the pipeline root
// by splitting it off and appending it back to the full URI.
func GenerateOutputURI(pipelineRoot string, paths []string, preserveQueryString bool) string {
querySplit := strings.Split(pipelineRoot, "?")
query := ""
if len(querySplit) == 2 {
pipelineRoot = querySplit[0]
if preserveQueryString {
query = "?" + querySplit[1]
}
} else if len(querySplit) > 2 {
// this should never happen, but just in case.
glog.Warningf("Unexpected pipeline root: %v", pipelineRoot)
}
// we cannot path.Join(root, taskName, artifactName), because root
// contains scheme like gs:// and path.Join cleans up scheme to gs:/
return fmt.Sprintf("%s/%s%s", strings.TrimRight(pipelineRoot, "/"), path.Join(paths...), query)
}

// GetPipeline returns the current pipeline represented by the specified
// pipeline name and run ID.
func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot string) (*Pipeline, error) {
Expand All @@ -272,7 +292,7 @@ func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace
keyNamespace: stringValue(namespace),
keyResourceName: stringValue(runResource),
// pipeline root of this run
keyPipelineRoot: stringValue(strings.TrimRight(pipelineRoot, "/") + "/" + path.Join(pipelineName, runID)),
keyPipelineRoot: stringValue(GenerateOutputURI(pipelineRoot, []string{pipelineName, runID}, true)),
}
runContext, err := c.getOrInsertContext(ctx, runID, pipelineRunContextType, metadata)
glog.Infof("Pipeline Run Context: %+v", runContext)
Expand Down
56 changes: 55 additions & 1 deletion backend/src/v2/metadata/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func Test_GetPipeline_Twice(t *testing.T) {
// The second call to GetPipeline won't fail because it avoid inserting to MLMD again.
samePipeline, err := client.GetPipeline(ctx, "get-pipeline-test", runId, namespace, runResource, pipelineRoot)
fatalIf(err)
if (pipeline.GetCtxID() != samePipeline.GetCtxID()) {
if pipeline.GetCtxID() != samePipeline.GetCtxID() {
t.Errorf("Expect pipeline context ID %d, actual is %d", pipeline.GetCtxID(), samePipeline.GetCtxID())
}
}
Expand Down Expand Up @@ -214,6 +214,60 @@ func Test_GetPipelineConcurrently(t *testing.T) {
wg.Wait()
}

func Test_GenerateOutputURI(t *testing.T) {
// Const define the artifact name
const (
pipelineName = "my-pipeline-name"
runID = "my-run-id"
pipelineRoot = "minio://mlpipeline/v2/artifacts"
pipelineRootQuery = "?query=string&another=query"
)
tests := []struct {
name string
queryString string
paths []string
preserveQueryString bool
want string
}{
{
name: "plain pipeline root without preserveQueryString",
queryString: "",
paths: []string{pipelineName, runID},
preserveQueryString: false,
want: fmt.Sprintf("%s/%s/%s", pipelineRoot, pipelineName, runID),
},
{
name: "plain pipeline root with preserveQueryString",
queryString: "",
paths: []string{pipelineName, runID},
preserveQueryString: true,
want: fmt.Sprintf("%s/%s/%s", pipelineRoot, pipelineName, runID),
},
{
name: "pipeline root with query string without preserveQueryString",
queryString: pipelineRootQuery,
paths: []string{pipelineName, runID},
preserveQueryString: false,
want: fmt.Sprintf("%s/%s/%s", pipelineRoot, pipelineName, runID),
},
{
name: "pipeline root with query string with preserveQueryString",
queryString: pipelineRootQuery,
paths: []string{pipelineName, runID},
preserveQueryString: true,
want: fmt.Sprintf("%s/%s/%s%s", pipelineRoot, pipelineName, runID, pipelineRootQuery),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := metadata.GenerateOutputURI(fmt.Sprintf("%s%s", pipelineRoot, tt.queryString), tt.paths, tt.preserveQueryString)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("GenerateOutputURI() = %v, want %v\nDiff (-want, +got)\n%s", got, tt.want, diff)
}
})
}
}

func Test_DAG(t *testing.T) {
t.Skip("Temporarily disable the test that requires cluster connection.")

Expand Down

0 comments on commit 3d95f8d

Please sign in to comment.