From ba64bd3866023c36c93cf2b46a066a1b008bb6c4 Mon Sep 17 00:00:00 2001 From: Ajay Gopinathan Date: Wed, 6 Mar 2019 09:16:51 -0800 Subject: [PATCH] Record TFX output artifacts in Metadata store (#884) * WIP: ML Metadata in KFP * Move metadata tracking to its own package. * Clean up * Address review comments, update travis.yml * Add dependencies for building in Dockerfile * Log errors but continue to update run when metadata storing fails. * Update workspace to get latest ml-metadata version. * Update errors --- .travis.yml | 10 +- BUILD.bazel | 4 + WORKSPACE | 46 ++++++- backend/Dockerfile | 31 +++-- backend/src/apiserver/BUILD.bazel | 5 + backend/src/apiserver/client_manager.go | 38 +++++- backend/src/apiserver/main.go | 5 +- backend/src/apiserver/metadata/BUILD.bazel | 27 ++++ .../src/apiserver/metadata/metadata_store.go | 127 ++++++++++++++++++ .../apiserver/metadata/metadata_store_test.go | 122 +++++++++++++++++ .../apiserver/resource/client_manager_fake.go | 3 +- .../apiserver/resource/resource_manager.go | 1 + backend/src/apiserver/storage/BUILD.bazel | 4 + backend/src/apiserver/storage/run_store.go | 50 ++++++- .../src/apiserver/storage/run_store_test.go | 2 +- go.mod | 5 +- go.sum | 2 + 17 files changed, 453 insertions(+), 29 deletions(-) create mode 100644 backend/src/apiserver/metadata/BUILD.bazel create mode 100644 backend/src/apiserver/metadata/metadata_store.go create mode 100644 backend/src/apiserver/metadata/metadata_store_test.go diff --git a/.travis.yml b/.travis.yml index b0b55b80995..f5d369f5053 100644 --- a/.travis.yml +++ b/.travis.yml @@ -31,8 +31,14 @@ matrix: - cd $TRAVIS_BUILD_DIR/backend/src - gimme -f 1.11.4 - source ~/.gimme/envs/go1.11.4.env - - go vet -all -shadow ./... - - go test ./... + - go vet -all -shadow ./agent/... + - go vet -all -shadow ./cmd/... + - go vet -all -shadow ./common/... + - go vet -all -shadow ./crd/... + - go test ./agent/... + - go test ./cmd/... + - go test ./common/... + - go test ./crd/... - language: python python: "2.7" env: TOXENV=py27 diff --git a/BUILD.bazel b/BUILD.bazel index 83604ad683a..0e9906756fd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4,4 +4,8 @@ load("@bazel_gazelle//:def.bzl", "gazelle") # gazelle:resolve proto protoc-gen-swagger/options/annotations.proto @com_github_grpc_ecosystem_grpc_gateway//protoc-gen-swagger/options:options_proto # gazelle:resolve proto go protoc-gen-swagger/options/annotations.proto @com_github_grpc_ecosystem_grpc_gateway//protoc-gen-swagger/options:go_default_library # gazelle:resolve go github.com/kubeflow/pipelines/backend/api/go_client //backend/api:go_default_library +# gazelle:resolve go ml_metadata/metadata_store/mlmetadata @google_ml_metadata//ml_metadata/metadata_store:metadata_store_go +# gazelle:resolve go ml_metadata/proto/metadata_store_go_proto @google_ml_metadata//ml_metadata/proto:metadata_store_go_proto +# gazelle:resolve go ml_metadata/proto/metadata_store_service_go_proto @google_ml_metadata//ml_metadata/proto:metadata_store_service_go_proto +# gazelle:exclude vendor/ gazelle(name = "gazelle") diff --git a/WORKSPACE b/WORKSPACE index 004d355cda7..9b87ac3db09 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,12 +1,13 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( name = "io_bazel_rules_go", - sha256 = "7be7dc01f1e0afdba6c8eb2b43d2fa01c743be1b9273ab1eaf6c233df078d705", - urls = ["https://github.com/bazelbuild/rules_go/releases/download/0.16.5/rules_go-0.16.5.tar.gz"], + sha256 = "492c3ac68ed9dcf527a07e6a1b2dcbf199c6bf8b35517951467ac32e421c06c1", + urls = ["https://github.com/bazelbuild/rules_go/releases/download/0.17.0/rules_go-0.17.0.tar.gz"], ) -load("@io_bazel_rules_go//go:def.bzl", "go_register_toolchains", "go_rules_dependencies") +load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") go_rules_dependencies() @@ -22,6 +23,45 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") gazelle_dependencies() +http_archive( + name = "org_tensorflow", + sha256 = "24570d860d87dcfb936f53fb8dd30302452d0aa6b8b8537e4555c1bf839121a6", + strip_prefix = "tensorflow-1.13.0-rc0", + urls = [ + "https://github.com/tensorflow/tensorflow/archive/v1.13.0-rc0.tar.gz", + ], +) + +http_archive( + name = "io_bazel_rules_closure", + sha256 = "43c9b882fa921923bcba764453f4058d102bece35a37c9f6383c713004aacff1", + strip_prefix = "rules_closure-9889e2348259a5aad7e805547c1a0cf311cfcd91", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", # 2018-12-21 + ], +) + +load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") + +tf_workspace() + +load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") + +go_repository( + name = "google_ml_metadata", + commit = "becc26ab61f82bfe7c812894f56f597949ce0fdc", + importpath = "github.com/google/ml-metadata", +) + +new_git_repository( + name = "libmysqlclient", + build_file = "@google_ml_metadata//ml_metadata:libmysqlclient.BUILD", + remote = "https://github.com/MariaDB/mariadb-connector-c.git", + tag = "v3.0.8-release", + workspace_file = "@google_ml_metadata//ml_metadata:libmysqlclient.WORKSPACE", +) + go_repository( name = "io_k8s_client_go", build_file_proto_mode = "disable_global", diff --git a/backend/Dockerfile b/backend/Dockerfile index 8c3fe14e54f..a29bd044ceb 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,25 +1,29 @@ -FROM golang:1.11-alpine3.7 as builder +FROM l.gcr.io/google/bazel:0.21.0 as builder +RUN apt-get update && \ + apt-get install -y cmake clang musl-dev openssl WORKDIR /go/src/github.com/kubeflow/pipelines -COPY . . -# Needed musl-dev for github.com/mattn/go-sqlite3 -RUN apk update && apk upgrade && \ - apk add --no-cache bash git openssh gcc musl-dev +COPY WORKSPACE WORKSPACE +COPY backend/src backend/src +COPY backend/api backend/api -RUN GO111MODULE=on go build -o /bin/apiserver backend/src/apiserver/*.go +RUN bazel build -c opt --action_env=PATH --define=grpc_no_ares=true backend/src/apiserver:apiserver + +# Compile FROM python:3.5 as compiler RUN apt-get update -y && \ - apt-get install --no-install-recommends -y -q default-jdk wget - + apt-get install --no-install-recommends -y -q default-jdk wget RUN pip3 install setuptools==40.5.0 RUN wget http://central.maven.org/maven2/io/swagger/swagger-codegen-cli/2.4.1/swagger-codegen-cli-2.4.1.jar -O /tmp/swagger-codegen-cli.jar +# WORKDIR /go/src/github.com/kubeflow/pipelines WORKDIR /go/src/github.com/kubeflow/pipelines -COPY . . +COPY backend/api backend/api +COPY sdk sdk WORKDIR /go/src/github.com/kubeflow/pipelines/sdk/python RUN ./build.sh /kfp.tar.gz RUN pip3 install /kfp.tar.gz @@ -36,22 +40,21 @@ COPY ./samples . #The "for" loop breaks on all whitespace, so we either need to override IFS or use the "read" command instead. RUN find . -maxdepth 2 -name '*.py' -type f | while read pipeline; do dsl-compile --py "$pipeline" --output "$pipeline.tar.gz"; done - -FROM alpine:3.8 +FROM debian:stretch ARG COMMIT_SHA=unknown ENV COMMIT_SHA=${COMMIT_SHA} WORKDIR /bin -COPY --from=builder /bin/apiserver /bin/apiserver -COPY --from=builder /go/src/github.com/kubeflow/pipelines/third_party/license.txt /bin/license.txt +COPY third_party/license.txt /bin/license.txt +COPY --from=builder /go/src/github.com/kubeflow/pipelines/bazel-bin/backend/src/apiserver/linux_amd64_stripped/apiserver /bin/apiserver COPY backend/src/apiserver/config/ /config COPY --from=compiler /samples/ /samples/ # Adding CA certificate so API server can download pipeline through URL -RUN apk add ca-certificates +RUN apt-get update && apt-get install -y ca-certificates # Expose apiserver port EXPOSE 8888 diff --git a/backend/src/apiserver/BUILD.bazel b/backend/src/apiserver/BUILD.bazel index d2796b180d9..6cf5f8bd900 100644 --- a/backend/src/apiserver/BUILD.bazel +++ b/backend/src/apiserver/BUILD.bazel @@ -13,6 +13,7 @@ go_library( deps = [ "//backend/api:go_default_library", "//backend/src/apiserver/client:go_default_library", + "//backend/src/apiserver/metadata:go_default_library", "//backend/src/apiserver/model:go_default_library", "//backend/src/apiserver/resource:go_default_library", "//backend/src/apiserver/server:go_default_library", @@ -23,12 +24,16 @@ go_library( "@com_github_cenkalti_backoff//:go_default_library", "@com_github_fsnotify_fsnotify//:go_default_library", "@com_github_golang_glog//:go_default_library", + "@com_github_golang_protobuf//proto:go_default_library", "@com_github_grpc_ecosystem_grpc_gateway//runtime:go_default_library", "@com_github_jinzhu_gorm//:go_default_library", "@com_github_jinzhu_gorm//dialects/sqlite:go_default_library", "@com_github_minio_minio_go//:go_default_library", "@com_github_pkg_errors//:go_default_library", "@com_github_spf13_viper//:go_default_library", + "@google_ml_metadata//ml_metadata/metadata_store:metadata_store_go", # keep + "@google_ml_metadata//ml_metadata/proto:metadata_store_go_proto", # keep + "@google_ml_metadata//ml_metadata/proto:metadata_store_service_go_proto", # keep "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//reflection:go_default_library", ], diff --git a/backend/src/apiserver/client_manager.go b/backend/src/apiserver/client_manager.go index b77eb1abdad..1c932661e88 100644 --- a/backend/src/apiserver/client_manager.go +++ b/backend/src/apiserver/client_manager.go @@ -17,19 +17,25 @@ package main import ( "database/sql" "fmt" + "strconv" "time" workflowclient "github.com/argoproj/argo/pkg/client/clientset/versioned/typed/workflow/v1alpha1" "github.com/cenkalti/backoff" "github.com/golang/glog" + "github.com/golang/protobuf/proto" "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/sqlite" "github.com/kubeflow/pipelines/backend/src/apiserver/client" + "github.com/kubeflow/pipelines/backend/src/apiserver/metadata" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/apiserver/storage" "github.com/kubeflow/pipelines/backend/src/common/util" scheduledworkflowclient "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/clientset/versioned/typed/scheduledworkflow/v1beta1" minio "github.com/minio/minio-go" + + "ml_metadata/metadata_store/mlmetadata" + mlpb "ml_metadata/proto/metadata_store_go_proto" ) const ( @@ -56,6 +62,8 @@ type ClientManager struct { swfClient scheduledworkflowclient.ScheduledWorkflowInterface time util.TimeInterface uuid util.UUIDGeneratorInterface + + MetadataStore *mlmetadata.Store } func (c *ClientManager) ExperimentStore() storage.ExperimentStoreInterface { @@ -117,7 +125,6 @@ func (c *ClientManager) init() { c.experimentStore = storage.NewExperimentStore(db, c.time, c.uuid) c.pipelineStore = storage.NewPipelineStore(db, c.time, c.uuid) c.jobStore = storage.NewJobStore(db, c.time) - c.runStore = storage.NewRunStore(db, c.time) c.resourceReferenceStore = storage.NewResourceReferenceStore(db) c.dBStatusStore = storage.NewDBStatusStore(db) c.objectStore = initMinioClient(getDurationConfig(initConnectionTimeout)) @@ -127,6 +134,11 @@ func (c *ClientManager) init() { c.swfClient = client.CreateScheduledWorkflowClientOrFatal( getStringConfig(podNamespace), getDurationConfig(initConnectionTimeout)) + + metadataStore := initMetadataStore() + runStore := storage.NewRunStore(db, c.time, metadataStore) + c.runStore = runStore + glog.Infof("Client manager initialized successfully") } @@ -134,6 +146,30 @@ func (c *ClientManager) Close() { c.db.Close() } +func initMetadataStore() *metadata.Store { + port, err := strconv.Atoi(getStringConfig(mysqlServicePort)) + if err != nil { + glog.Fatalf("Failed to parse valid MySQL service port from %q: %v", getStringConfig(mysqlServicePort), err) + } + + cfg := &mlpb.ConnectionConfig{ + Config: &mlpb.ConnectionConfig_Mysql{ + &mlpb.MySQLDatabaseConfig{ + Host: proto.String(getStringConfig(mysqlServiceHost)), + Port: proto.Uint32(uint32(port)), + Database: proto.String("mlmetadata"), + User: proto.String("root"), + }, + }, + } + + mlmdStore, err := mlmetadata.NewStore(cfg) + if err != nil { + glog.Fatalf("Failed to create ML Metadata store: %v", err) + } + return metadata.NewStore(mlmdStore) +} + func initDBClient(initConnectionTimeout time.Duration) *storage.DB { driverName := getStringConfig("DBConfig.DriverName") var arg string diff --git a/backend/src/apiserver/main.go b/backend/src/apiserver/main.go index b1568d837ad..283cf08dd58 100644 --- a/backend/src/apiserver/main.go +++ b/backend/src/apiserver/main.go @@ -35,6 +35,10 @@ import ( "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/reflection" + + _ "ml_metadata/metadata_store/mlmetadata" + _ "ml_metadata/proto/metadata_store_go_proto" + _ "ml_metadata/proto/metadata_store_service_go_proto" ) var ( @@ -48,7 +52,6 @@ type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.Serv func main() { flag.Parse() - glog.Infof("starting API server") initConfig() clientManager := newClientManager() diff --git a/backend/src/apiserver/metadata/BUILD.bazel b/backend/src/apiserver/metadata/BUILD.bazel new file mode 100644 index 00000000000..c106b18b6f4 --- /dev/null +++ b/backend/src/apiserver/metadata/BUILD.bazel @@ -0,0 +1,27 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = ["metadata_store.go"], + importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/metadata", + visibility = ["//visibility:public"], + deps = [ + "//backend/src/common/util:go_default_library", + "@com_github_argoproj_argo//pkg/apis/workflow/v1alpha1:go_default_library", + "@com_github_golang_protobuf//jsonpb:go_default_library_gen", + "@com_github_golang_protobuf//proto:go_default_library", + "@google_ml_metadata//ml_metadata/metadata_store:metadata_store_go", + "@google_ml_metadata//ml_metadata/proto:metadata_store_go_proto", + ], +) + +go_test( + name = "go_default_test", + srcs = ["metadata_store_test.go"], + embed = [":go_default_library"], + deps = [ + "@com_github_golang_protobuf//proto:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@google_ml_metadata//ml_metadata/proto:metadata_store_go_proto", + ], +) diff --git a/backend/src/apiserver/metadata/metadata_store.go b/backend/src/apiserver/metadata/metadata_store.go new file mode 100644 index 00000000000..a457c7540e6 --- /dev/null +++ b/backend/src/apiserver/metadata/metadata_store.go @@ -0,0 +1,127 @@ +package metadata + +import ( + "encoding/json" + "ml_metadata/metadata_store/mlmetadata" + mlpb "ml_metadata/proto/metadata_store_go_proto" + "strings" + + argoWorkflow "github.com/argoproj/argo/pkg/apis/workflow/v1alpha1" + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/kubeflow/pipelines/backend/src/common/util" +) + +// Store encapsulates a ML Metadata store. +type Store struct { + mlmdStore *mlmetadata.Store +} + +// NewStore creates a new Store, using mlmdStore as the backing ML Metadata +// store. +func NewStore(mlmdStore *mlmetadata.Store) *Store { + return &Store{mlmdStore: mlmdStore} +} + +// RecordOutputArtifacts records metadata on artifacts as parsed from the Argo +// output parameters in currentManifest. storedManifest represents the currently +// stored manifest for the run with id runID, and is used to ensure metadata is +// recorded at most once per artifact. +func (s *Store) RecordOutputArtifacts(runID, storedManifest, currentManifest string) error { + storedWorkflow := &argoWorkflow.Workflow{} + if err := json.Unmarshal([]byte(storedManifest), storedWorkflow); err != nil { + return util.NewInternalServerError(err, "unmarshaling workflow failed") + } + + currentWorkflow := &argoWorkflow.Workflow{} + if err := json.Unmarshal([]byte(currentManifest), currentWorkflow); err != nil { + return util.NewInternalServerError(err, "unmarshaling workflow failed") + } + + completed := make(map[string]bool) + for _, n := range storedWorkflow.Status.Nodes { + if n.Completed() { + completed[n.ID] = true + } + } + + for _, n := range currentWorkflow.Status.Nodes { + if n.Completed() && !completed[n.ID] { + // Newly completed node. Record output ml-metadata artifacts. + if n.Outputs != nil { + for _, output := range n.Outputs.Parameters { + if !strings.HasPrefix(output.ValueFrom.Path, "/output/ml_metadata/") { + continue + } + + artifacts, err := parseTFXMetadata(*output.Value) + if err != nil { + return util.NewInvalidInputError("metadata parsing failure: %v", err) + } + + if err := s.storeArtifacts(artifacts); err != nil { + return util.NewInvalidInputError("artifact storing failure: %v", err) + } + } + } + } + } + + return nil +} + +func (s *Store) storeArtifacts(artifacts artifactStructs) error { + for _, a := range artifacts { + id, err := s.mlmdStore.PutArtifactType( + a.ArtifactType, &mlmetadata.PutTypeOptions{AllFieldsMustMatch: true}) + if err != nil { + return util.NewInternalServerError(err, "failed to register artifact type") + } + a.Artifact.TypeId = proto.Int64(int64(id)) + _, err = s.mlmdStore.PutArtifacts([]*mlpb.Artifact{a.Artifact}) + if err != nil { + return util.NewInternalServerError(err, "failed to record artifact") + } + } + return nil +} + +type artifactStruct struct { + ArtifactType *mlpb.ArtifactType `json:"artifact_type"` + Artifact *mlpb.Artifact `json:"artifact"` +} + +func (a *artifactStruct) UnmarshalJSON(b []byte) error { + errorF := func(err error) error { + return util.NewInvalidInputError("JSON Unmarshal failure: %v", err) + } + + jsonMap := make(map[string]*json.RawMessage) + if err := json.Unmarshal(b, &jsonMap); err != nil { + return errorF(err) + } + + a.ArtifactType = &mlpb.ArtifactType{} + a.Artifact = &mlpb.Artifact{} + + if err := jsonpb.UnmarshalString(string(*jsonMap["artifact_type"]), a.ArtifactType); err != nil { + return errorF(err) + } + + if err := jsonpb.UnmarshalString(string(*jsonMap["artifact"]), a.Artifact); err != nil { + return errorF(err) + } + + return nil +} + +type artifactStructs []*artifactStruct + +func parseTFXMetadata(value string) (artifactStructs, error) { + var tfxArtifacts artifactStructs + + if err := json.Unmarshal([]byte(value), &tfxArtifacts); err != nil { + return nil, util.NewInternalServerError(err, "parse TFX metadata failure") + } + return tfxArtifacts, nil +} diff --git a/backend/src/apiserver/metadata/metadata_store_test.go b/backend/src/apiserver/metadata/metadata_store_test.go new file mode 100644 index 00000000000..1ad0895a9d1 --- /dev/null +++ b/backend/src/apiserver/metadata/metadata_store_test.go @@ -0,0 +1,122 @@ +package metadata + +import ( + "fmt" + mlpb "ml_metadata/proto/metadata_store_go_proto" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" +) + +func (as artifactStructs) String() string { + var s string + for _, a := range as { + s += fmt.Sprintf("%+v\n", a) + } + return s +} + +func TestParseValidTFXMetadata(t *testing.T) { + tests := []struct { + input string + want artifactStructs + }{ + { + `[{ + "artifact_type": { + "name": "Artifact", + "properties": {"state": "STRING", "span": "INT" } }, + "artifact": { + "uri": "/location", + "properties": { + "state": {"stringValue": "complete"}, + "span": {"intValue": 10} } + } + }]`, + []*artifactStruct{ + &artifactStruct{ + ArtifactType: &mlpb.ArtifactType{ + Name: proto.String("Artifact"), + Properties: map[string]mlpb.PropertyType{ + "state": mlpb.PropertyType_STRING, + "span": mlpb.PropertyType_INT, + }, + }, + Artifact: &mlpb.Artifact{ + Uri: proto.String("/location"), + Properties: map[string]*mlpb.Value{ + "state": &mlpb.Value{Value: &mlpb.Value_StringValue{"complete"}}, + "span": &mlpb.Value{Value: &mlpb.Value_IntValue{10}}, + }, + }, + }, + }, + }, + { + `[{ + "artifact_type": { + "name": "Artifact 1", + "properties": {"state": "STRING", "span": "INT" } }, + "artifact": { + "uri": "/location 1", + "properties": { + "state": {"stringValue": "complete"}, + "span": {"intValue": 10} } + } + }, + { + "artifact_type": { + "name": "Artifact 2", + "properties": {"state": "STRING", "span": "INT" } }, + "artifact": { + "uri": "/location 2", + "properties": { + "state": {"stringValue": "complete"}, + "span": {"intValue": 10} } + } + }]`, + []*artifactStruct{ + &artifactStruct{ + ArtifactType: &mlpb.ArtifactType{ + Name: proto.String("Artifact 1"), + Properties: map[string]mlpb.PropertyType{ + "state": mlpb.PropertyType_STRING, + "span": mlpb.PropertyType_INT, + }, + }, + Artifact: &mlpb.Artifact{ + Uri: proto.String("/location 1"), + Properties: map[string]*mlpb.Value{ + "state": &mlpb.Value{Value: &mlpb.Value_StringValue{"complete"}}, + "span": &mlpb.Value{Value: &mlpb.Value_IntValue{10}}, + }, + }, + }, + &artifactStruct{ + ArtifactType: &mlpb.ArtifactType{ + Name: proto.String("Artifact 2"), + Properties: map[string]mlpb.PropertyType{ + "state": mlpb.PropertyType_STRING, + "span": mlpb.PropertyType_INT, + }, + }, + Artifact: &mlpb.Artifact{ + Uri: proto.String("/location 2"), + Properties: map[string]*mlpb.Value{ + "state": &mlpb.Value{Value: &mlpb.Value_StringValue{"complete"}}, + "span": &mlpb.Value{Value: &mlpb.Value_IntValue{10}}, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + got, err := parseTFXMetadata(test.input) + if err != nil || !cmp.Equal(got, test.want) { + t.Errorf("parseTFXMetadata(%q)\nGot:\n<%+v, %+v>\nWant:\n%+v, nil error\nDiff:\n%s", test.input, got, err, test.want, cmp.Diff(test.want, got)) + } + } +} diff --git a/backend/src/apiserver/resource/client_manager_fake.go b/backend/src/apiserver/resource/client_manager_fake.go index 6eb857b131a..71a9ce8468e 100644 --- a/backend/src/apiserver/resource/client_manager_fake.go +++ b/backend/src/apiserver/resource/client_manager_fake.go @@ -58,12 +58,13 @@ func NewFakeClientManager(time util.TimeInterface, uuid util.UUIDGeneratorInterf return nil, err } + // TODO(neuromage): Pass in metadata.Store instance for tests as well. return &FakeClientManager{ db: db, experimentStore: storage.NewExperimentStore(db, time, uuid), pipelineStore: storage.NewPipelineStore(db, time, uuid), jobStore: storage.NewJobStore(db, time), - runStore: storage.NewRunStore(db, time), + runStore: storage.NewRunStore(db, time, nil), workflowClientFake: NewWorkflowClientFake(), resourceReferenceStore: storage.NewResourceReferenceStore(db), dBStatusStore: storage.NewDBStatusStore(db), diff --git a/backend/src/apiserver/resource/resource_manager.go b/backend/src/apiserver/resource/resource_manager.go index c171607d8bc..bdd62110b97 100644 --- a/backend/src/apiserver/resource/resource_manager.go +++ b/backend/src/apiserver/resource/resource_manager.go @@ -361,6 +361,7 @@ func (r *ResourceManager) DeleteJob(jobID string) error { func (r *ResourceManager) ReportWorkflowResource(workflow *util.Workflow) error { runId := string(workflow.UID) jobId := workflow.ScheduledWorkflowUUIDAsStringOrEmpty() + if jobId == "" { // If a run doesn't have owner UID, it's a one-time run created by Pipeline API server. // In this case the DB entry should already been created when argo workflow CRD is created. diff --git a/backend/src/apiserver/storage/BUILD.bazel b/backend/src/apiserver/storage/BUILD.bazel index 2d825d6c935..15faaed1f14 100644 --- a/backend/src/apiserver/storage/BUILD.bazel +++ b/backend/src/apiserver/storage/BUILD.bazel @@ -24,6 +24,7 @@ go_library( "//backend/api:go_default_library", "//backend/src/apiserver/common:go_default_library", "//backend/src/apiserver/list:go_default_library", + "//backend/src/apiserver/metadata:go_default_library", "//backend/src/apiserver/model:go_default_library", "//backend/src/common/util:go_default_library", "@com_github_ghodss_yaml//:go_default_library", @@ -35,6 +36,9 @@ go_library( "@com_github_minio_minio_go//:go_default_library", "@com_github_pkg_errors//:go_default_library", "@com_github_vividcortex_mysqlerr//:go_default_library", + "@google_ml_metadata//ml_metadata/metadata_store:metadata_store_go", # keep + "@google_ml_metadata//ml_metadata/proto:metadata_store_go_proto", # keep + "@google_ml_metadata//ml_metadata/proto:metadata_store_service_go_proto", # keep "@io_k8s_apimachinery//pkg/util/json:go_default_library", ], ) diff --git a/backend/src/apiserver/storage/run_store.go b/backend/src/apiserver/storage/run_store.go index e980ae2252e..48e4aa1a207 100644 --- a/backend/src/apiserver/storage/run_store.go +++ b/backend/src/apiserver/storage/run_store.go @@ -20,9 +20,11 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/golang/glog" + api "github.com/kubeflow/pipelines/backend/api/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/list" + "github.com/kubeflow/pipelines/backend/src/apiserver/metadata" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" "k8s.io/apimachinery/pkg/util/json" @@ -59,6 +61,7 @@ type RunStore struct { db *DB resourceReferenceStore *ResourceReferenceStore time util.TimeInterface + metadataStore *metadata.Store } // Runs two SQL queries in a transaction to return a list of matching runs, as well as their @@ -350,6 +353,33 @@ func (s *RunStore) CreateRun(r *model.RunDetail) (*model.RunDetail, error) { } func (s *RunStore) UpdateRun(runID string, condition string, workflowRuntimeManifest string) (err error) { + tx, err := s.db.DB.Begin() + if err != nil { + return util.NewInternalServerError(err, "transaction creation failed") + } + + // Lock the row for update, so we ensure no other update of the same run + // happens while we're parsing it for metadata. We rely on per-row updates + // being synchronous, so metadata can be recorded at most once. Right now, + // persistence agent will call UpdateRun all the time, even if there is nothing + // new in the status of an Argo manifest. This means we need to keep track + // manually here on what the previously updated state of the run is, to ensure + // we do not add duplicate metadata. Hence the locking below. + row := tx.QueryRow("SELECT WorkflowRuntimeManifest FROM run_details WHERE UUID = ? FOR UPDATE", runID) + var storedManifest string + if err := row.Scan(&storedManifest); err != nil { + tx.Rollback() + return util.NewInternalServerError(err, "failed to find row with run id %q", runID) + } + + if s.metadataStore != nil { + if err := s.metadataStore.RecordOutputArtifacts(runID, storedManifest, workflowRuntimeManifest); err != nil { + // Metadata storage failed. Log the error here, but continue to allow the run + // to be updated as per usual. + glog.Errorf("Failed to record output artifacts: %+v", err) + } + } + sql, args, err := sq. Update("run_details"). SetMap(sq.Eq{ @@ -358,17 +388,23 @@ func (s *RunStore) UpdateRun(runID string, condition string, workflowRuntimeMani Where(sq.Eq{"UUID": runID}). ToSql() if err != nil { + tx.Rollback() return util.NewInternalServerError(err, "Failed to create query to update run %s. error: '%v'", runID, err.Error()) } - result, err := s.db.Exec(sql, args...) + result, err := tx.Exec(sql, args...) if err != nil { + tx.Rollback() return util.NewInternalServerError(err, "Failed to update run %s. error: '%v'", runID, err.Error()) } if r, _ := result.RowsAffected(); r != 1 { + tx.Rollback() return util.NewInvalidInputError("Failed to update run %s. Row not found.", runID) } + if err := tx.Commit(); err != nil { + return util.NewInternalServerError(err, "failed to commit transaction") + } return nil } @@ -510,7 +546,13 @@ func (s *RunStore) toRunMetadatas(models []model.ListableDataModel) []model.Run return runMetadatas } -// factory function for run store -func NewRunStore(db *DB, time util.TimeInterface) *RunStore { - return &RunStore{db: db, resourceReferenceStore: NewResourceReferenceStore(db), time: time} +// NewRunStore creates a new RunStore. If metadataStore is non-nil, it will be +// used to record artifact metadata. +func NewRunStore(db *DB, time util.TimeInterface, metadataStore *metadata.Store) *RunStore { + return &RunStore{ + db: db, + resourceReferenceStore: NewResourceReferenceStore(db), + time: time, + metadataStore: metadataStore, + } } diff --git a/backend/src/apiserver/storage/run_store_test.go b/backend/src/apiserver/storage/run_store_test.go index 96e9bb6e4c4..71445147c9e 100644 --- a/backend/src/apiserver/storage/run_store_test.go +++ b/backend/src/apiserver/storage/run_store_test.go @@ -33,7 +33,7 @@ func initializeRunStore() (*DB, *RunStore) { expStore.CreateExperiment(&model.Experiment{Name: "exp1"}) expStore = NewExperimentStore(db, util.NewFakeTimeForEpoch(), util.NewFakeUUIDGeneratorOrFatal(defaultFakeExpIdTwo, nil)) expStore.CreateExperiment(&model.Experiment{Name: "exp2"}) - runStore := NewRunStore(db, util.NewFakeTimeForEpoch()) + runStore := NewRunStore(db, util.NewFakeTimeForEpoch(), nil) run1 := &model.RunDetail{ Run: model.Run{ diff --git a/go.mod b/go.mod index 1218b387ce2..4abb191fae6 100644 --- a/go.mod +++ b/go.mod @@ -25,13 +25,14 @@ require ( github.com/go-openapi/swag v0.17.0 github.com/go-openapi/validate v0.17.2 github.com/go-sql-driver/mysql v1.4.0 - github.com/gogo/protobuf v1.1.1 + github.com/gogo/protobuf v1.1.1 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b github.com/golang/groupcache v0.0.0-20180513044358-24b0969c4cb7 // indirect github.com/golang/protobuf v1.2.0 github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a // indirect github.com/google/go-cmp v0.2.0 github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf // indirect + github.com/google/ml-metadata v0.0.0-20190214221617-0fb82dc56ff7 github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57 // indirect github.com/google/uuid v1.0.0 github.com/googleapis/gnostic v0.2.0 // indirect @@ -85,7 +86,7 @@ require ( k8s.io/api v0.0.0-20180712090710-2d6f90ab1293 k8s.io/apiextensions-apiserver v0.0.0-20190103235604-e7617803aceb // indirect k8s.io/apimachinery v0.0.0-20180621070125-103fd098999d - k8s.io/apiserver v0.0.0-20190112184317-d55c9aeff1eb + k8s.io/apiserver v0.0.0-20190112184317-d55c9aeff1eb // indirect k8s.io/client-go v0.0.0-20180718001006-59698c7d9724 k8s.io/klog v0.1.0 // indirect k8s.io/kube-openapi v0.0.0-20180719232738-d8ea2fe547a4 // indirect diff --git a/go.sum b/go.sum index 1042c08173d..b1391d101d0 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf h1:+RRA9JqSOZFfKrOeqr2z77+8R2RKyh8PG66dcu1V0ck= github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= +github.com/google/ml-metadata v0.0.0-20190214221617-0fb82dc56ff7 h1:Db+CbWW+XCYzfL662n+i2/xQGmLy4nRFFX3fEscNstk= +github.com/google/ml-metadata v0.0.0-20190214221617-0fb82dc56ff7/go.mod h1:yO0xrdRxF2VbZGmBEPUsnKnANnPE/3kULpqDAFRPCmg= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=