diff --git a/go.mod b/go.mod index 19e4fd4ee..0ebcf138c 100644 --- a/go.mod +++ b/go.mod @@ -146,3 +146,5 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d + +replace github.com/flyteorg/flyteplugins => github.com/flyteorg/flyteplugins v1.0.16-0.20221011220618-4654389800fe diff --git a/go.sum b/go.sum index 963053582..c577b44a7 100644 --- a/go.sum +++ b/go.sum @@ -294,8 +294,8 @@ github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4 github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v1.1.19 h1:1CtSbuFhFHwUbKdv66PqbcER01iacAJU+snh0eTsXc4= github.com/flyteorg/flyteidl v1.1.19/go.mod h1:SLTYz2JgIKvM5MbPVlMP7uILb65fnuuZQZFHHIEYh2U= -github.com/flyteorg/flyteplugins v1.0.15 h1:LewZIw2qSyGy34OoghYeuc7N/KazeVZvD0gNYXt/ZcM= -github.com/flyteorg/flyteplugins v1.0.15/go.mod h1:GfbmRByI/rSatm/Epoj3bNyrXwIQ9NOXTVwLS6Z0p84= +github.com/flyteorg/flyteplugins v1.0.16-0.20221011220618-4654389800fe h1:SKV7Nn9aUHCVEVPP8/S+Qcl1t83bzzwz/6deAYIldPc= +github.com/flyteorg/flyteplugins v1.0.16-0.20221011220618-4654389800fe/go.mod h1:GfbmRByI/rSatm/Epoj3bNyrXwIQ9NOXTVwLS6Z0p84= github.com/flyteorg/flytestdlib v1.0.0/go.mod h1:QSVN5wIM1lM9d60eAEbX7NwweQXW96t5x4jbyftn89c= github.com/flyteorg/flytestdlib v1.0.5 h1:80A/vfpAJl+pgU6vxccbsYApZPrvyGhOIsCAFngsjnk= github.com/flyteorg/flytestdlib v1.0.5/go.mod h1:WTe0k3DmmrKFjj3hwiIbjjdCK89X63MBzBbXhQ4Yxf0= diff --git a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go index b776d8c55..f74a0f151 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go @@ -6,6 +6,9 @@ import ( "fmt" "time" + "github.com/flyteorg/flytestdlib/storage" + "golang.org/x/exp/maps" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" @@ -135,13 +138,22 @@ func (m *CatalogClient) Get(ctx context.Context, key catalog.Key) (catalog.Entry md := EventCatalogMetadata(dataset.GetId(), relevantTag, source) outputs, err := GenerateTaskOutputsFromArtifact(key.Identifier, key.TypedInterface, artifact) + var deckURI *storage.DataReference + if artifact.GetMetadata() != nil { + deckURIValue, ok := artifact.GetMetadata().KeyMap[DeckURIKey] + if ok { + reference := storage.DataReference(deckURIValue) + deckURI = &reference + } + } + if err != nil { logger.Errorf(ctx, "DataCatalog failed to get outputs from artifact %+v, err: %+v", artifact.Id, err) - return catalog.NewCatalogEntry(ioutils.NewInMemoryOutputReader(outputs, nil, nil), catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, md)), err + return catalog.NewCatalogEntry(ioutils.NewInMemoryOutputReader(outputs, deckURI, nil), catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, md)), err } logger.Infof(ctx, "Retrieved %v outputs from artifact %v, tag: %v", len(outputs.Literals), artifact.Id, tag) - return catalog.NewCatalogEntry(ioutils.NewInMemoryOutputReader(outputs, nil, nil), catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, md)), nil + return catalog.NewCatalogEntry(ioutils.NewInMemoryOutputReader(outputs, deckURI, nil), catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, md)), nil } func (m *CatalogClient) CreateDataset(ctx context.Context, key catalog.Key, metadata *datacatalog.Metadata) (*datacatalog.DatasetID, error) { @@ -240,11 +252,12 @@ func (m *CatalogClient) Put(ctx context.Context, key catalog.Key, reader io.Outp } // Create the artifact for the execution that belongs in the task - cachedArtifact, err := m.CreateArtifact(ctx, datasetID, outputs, GetArtifactMetadataForSource(metadata.TaskExecutionIdentifier)) + artifactMetadata := GetArtifactMetadataForSource(metadata.TaskExecutionIdentifier) + maps.Copy(artifactMetadata.KeyMap, reader.GetOutputMetadata(ctx)) + cachedArtifact, err := m.CreateArtifact(ctx, datasetID, outputs, artifactMetadata) if err != nil { return catalog.Status{}, errors.Wrapf(err, "failed to create dataset for ID %s", key.Identifier.String()) } - // Tag the artifact since it is the cached artifact tagName, err := GenerateArtifactTagName(ctx, inputs) if err != nil { diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go index abf36f128..34e0519e0 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go @@ -219,6 +219,7 @@ const ( execProjectKey = "exec-project" execNodeIDKey = "exec-node" execTaskAttemptKey = "exec-attempt" + DeckURIKey = "deck-uri" ) // Understanding Catalog Identifiers diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index a5685c6f7..068a06de8 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -37,6 +37,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog/datacatalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" ) @@ -591,8 +592,13 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) logger.Errorf(ctx, "failed to write cached value to datastore, err: %s", err.Error()) return handler.UnknownTransition, err } - - pluginTrns.CacheHit(tCtx.ow.GetOutputPath(), nil, entry) + deckPathValue, ok := tCtx.ow.GetReader().GetOutputMetadata(ctx)[datacatalog.DeckURIKey] + if ok { + deckPath := storage.DataReference(deckPathValue) + pluginTrns.CacheHit(tCtx.ow.GetOutputPath(), &deckPath, entry) + } else { + pluginTrns.CacheHit(tCtx.ow.GetOutputPath(), nil, entry) + } } else { logger.Infof(ctx, "No CacheHIT. Status [%s]", entry.GetStatus().GetCacheStatus().String()) pluginTrns.PopulateCacheInfo(entry) diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index bc2487a6a..5d1544d6c 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -48,6 +48,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + datacatalogClient "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog/datacatalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/fakeplugins" @@ -59,6 +60,7 @@ var eventConfig = &controllerConfig.EventConfig{ } const testClusterID = "C1" +const deckPath = "deck.html" func Test_task_setDefault(t *testing.T) { type fields struct { @@ -908,6 +910,7 @@ func Test_task_Handle_Catalog(t *testing.T) { if tt.args.catalogFetch { or := &ioMocks.OutputReader{} or.OnDeckExistsMatch(mock.Anything).Return(true, nil) + or.OnGetOutputMetadataMatch(mock.Anything).Return(map[string]string{datacatalogClient.DeckURIKey: deckPath}) or.OnReadMatch(mock.Anything).Return(&core.LiteralMap{}, nil, nil) c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewCatalogEntry(or, catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, nil)), nil) } else { @@ -935,6 +938,9 @@ func Test_task_Handle_Catalog(t *testing.T) { } if err == nil { assert.Equal(t, tt.want.handlerPhase.String(), got.Info().GetPhase().String()) + if tt.name == "cache-hit" { + assert.Equal(t, deckPath, got.Info().GetInfo().OutputInfo.DeckURI.String()) + } if assert.Equal(t, 1, len(ev.evs)) { e := ev.evs[0] assert.Equal(t, tt.want.eventPhase.String(), e.Phase.String()) @@ -1136,6 +1142,7 @@ func Test_task_Handle_Reservation(t *testing.T) { if tt.args.catalogFetch { or := &ioMocks.OutputReader{} or.OnDeckExistsMatch(mock.Anything).Return(true, nil) + or.OnGetOutputMetadataMatch(mock.Anything).Return(map[string]string{datacatalogClient.DeckURIKey: deckPath}) or.OnReadMatch(mock.Anything).Return(&core.LiteralMap{}, nil, nil) c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewCatalogEntry(or, catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, nil)), nil) } else { @@ -1157,6 +1164,9 @@ func Test_task_Handle_Reservation(t *testing.T) { } if err == nil { assert.Equal(t, tt.want.handlerPhase.String(), got.Info().GetPhase().String()) + if tt.name == "cache-hit" { + assert.Equal(t, deckPath, got.Info().GetInfo().OutputInfo.DeckURI.String()) + } if assert.Equal(t, 1, len(ev.evs)) { e := ev.evs[0] assert.Equal(t, tt.want.eventPhase.String(), e.Phase.String())