Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit 08dd79b

Browse files
authored
Create a FileOutput reader if the agent produce file output (#391)
Signed-off-by: Kevin Su <[email protected]>
1 parent bf27745 commit 08dd79b

File tree

4 files changed

+46
-14
lines changed

4 files changed

+46
-14
lines changed

go/tasks/plugins/webapi/agent/integration_test.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,15 @@ func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.Crea
4848
return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil
4949
}
5050

51-
func (m *MockClient) GetTask(_ context.Context, _ *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) {
52-
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
53-
Literals: map[string]*flyteIdlCore.Literal{
54-
"arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}),
55-
},
56-
}}}, nil
51+
func (m *MockClient) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) {
52+
if req.GetTaskType() == "bigquery_query_job_task" {
53+
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
54+
Literals: map[string]*flyteIdlCore.Literal{
55+
"arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}),
56+
},
57+
}}}, nil
58+
}
59+
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil
5760
}
5861

5962
func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) {
@@ -113,6 +116,11 @@ func TestEndToEnd(t *testing.T) {
113116

114117
phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter)
115118
assert.Equal(t, true, phase.Phase().IsSuccess())
119+
120+
template.Type = "spark_job"
121+
phase = tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter)
122+
assert.Equal(t, true, phase.Phase().IsSuccess())
123+
116124
})
117125

118126
t.Run("failed to create a job", func(t *testing.T) {
@@ -251,7 +259,7 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {
251259
func newMockAgentPlugin() webapi.PluginEntry {
252260
return webapi.PluginEntry{
253261
ID: "agent-service",
254-
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task"},
262+
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job"},
255263
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
256264
return &MockPlugin{
257265
Plugin{

go/tasks/plugins/webapi/agent/plugin.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88

99
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
10-
"github.com/flyteorg/flytestdlib/config"
1110
"google.golang.org/grpc/credentials"
1211
"google.golang.org/grpc/credentials/insecure"
1312

@@ -19,8 +18,11 @@ import (
1918
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
2019
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
2120
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
21+
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
2222
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
2323
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
24+
"github.com/flyteorg/flytestdlib/config"
25+
"github.com/flyteorg/flytestdlib/logger"
2426
"github.com/flyteorg/flytestdlib/promutils"
2527
"google.golang.org/grpc"
2628
)
@@ -176,17 +178,38 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
176178
case admin.State_RETRYABLE_FAILURE:
177179
return core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
178180
case admin.State_SUCCEEDED:
179-
if resource.Outputs != nil {
180-
err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil))
181-
if err != nil {
182-
return core.PhaseInfoUndefined, err
183-
}
181+
err = writeOutput(ctx, taskCtx, resource)
182+
if err != nil {
183+
logger.Errorf(ctx, "Failed to write output with err %s", err.Error())
184+
return core.PhaseInfoUndefined, err
184185
}
185186
return core.PhaseInfoSuccess(taskInfo), nil
186187
}
187188
return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State)
188189
}
189190

191+
func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, resource *ResourceWrapper) error {
192+
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
193+
if err != nil {
194+
return err
195+
}
196+
197+
if taskTemplate.Interface == nil || taskTemplate.Interface.Outputs == nil || taskTemplate.Interface.Outputs.Variables == nil {
198+
logger.Debugf(ctx, "The task declares no outputs. Skipping writing the outputs.")
199+
return nil
200+
}
201+
202+
var opReader io.OutputReader
203+
if resource.Outputs != nil {
204+
logger.Debugf(ctx, "Agent returned an output")
205+
opReader = ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)
206+
} else {
207+
logger.Debugf(ctx, "Agent didn't return any output, assuming file based outputs.")
208+
opReader = ioutils.NewRemoteFileOutputReader(ctx, taskCtx.DataStore(), taskCtx.OutputWriter(), taskCtx.MaxDatasetSizeBytes())
209+
}
210+
return taskCtx.OutputWriter().Put(ctx, opReader)
211+
}
212+
190213
func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
191214
if id, exists := cfg.AgentForTaskTypes[taskType]; exists {
192215
if agent, exists := cfg.Agents[id]; exists {

go/tasks/plugins/webapi/agent/plugin_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func TestPlugin(t *testing.T) {
4242
assert.Equal(t, plugin.cfg.ResourceConstraints, constraints)
4343
})
4444

45-
t.Run("tet newAgentPlugin", func(t *testing.T) {
45+
t.Run("test newAgentPlugin", func(t *testing.T) {
4646
p := newAgentPlugin()
4747
assert.NotNil(t, p)
4848
assert.Equal(t, "agent-service", p.ID)

tests/end_to_end.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i
9292
outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb")
9393
outputWriter.OnGetCheckpointPrefix().Return("/checkpoint")
9494
outputWriter.OnGetPreviousCheckpointsPrefix().Return("/prev")
95+
outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil)
9596

9697
outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
9798
or := args.Get(1).(io.OutputReader)

0 commit comments

Comments
 (0)