diff --git a/flyte-single-binary-local.yaml b/flyte-single-binary-local.yaml index 0b3caf45f1..c3233da8ab 100644 --- a/flyte-single-binary-local.yaml +++ b/flyte-single-binary-local.yaml @@ -43,8 +43,8 @@ tasks: - container - sidecar - K8S-ARRAY - - connector-service - echo + - agent-service default-for-task-types: - container: container - container_array: K8S-ARRAY @@ -68,6 +68,61 @@ plugins: kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }} cloudwatch-enabled: false stackdriver-enabled: false + connector-service: + supportedTaskTypes: + - bigquery_query_job_task + - sensor + connectors: + my-test-connector1: + endpoint: "localhost:8001" + insecure: true + timeouts: + CreateTask: 5s + GetTask: 30s + DeleteTask: 30s + defaultTimeout: 30s + webApi: + caching: + resyncInterval: 60s + my-test-connector2: + endpoint: "localhost:8002" + insecure: true + timeouts: + CreateTask: 5s + GetTask: 30s + DeleteTask: 30s + defaultTimeout: 30s + webApi: + caching: + resyncInterval: 120s + defaultConnector: + endpoint: "localhost:8000" + webApi: + caching: + resyncInterval: 120s + connectorForTaskTypes: + - noop_async_agent_task: my-test-connector1 + # agent-service: + # supportedTaskTypes: + # - bigquery_query_job_task + # - sensor + # - chatgpt + # agents: + # my-test-connector1: + # endpoint: "localhost:8000" + # insecure: true + # timeouts: + # CreateTask: 5s + # GetTask: 30s + # DeleteTask: 30s + # defaultTimeout: 30s + # defaultAgent: + # endpoint: "localhost:8000" + # webApi: + # caching: + # resyncInterval: 120s + # agentForTaskTypes: + # - noop_async_agent_task: my-test-connector1 database: postgres: diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index 43364f5f44..7c11af2505 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -15,6 +15,7 @@ require ( github.com/flyteorg/flyte/flytestdlib v0.0.0-00010101000000-000000000000 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.4 + github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru v0.5.4 github.com/imdario/mergo v0.3.13 github.com/kubeflow/training-operator v1.8.0 @@ -85,7 +86,6 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/s2a-go v0.1.7 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go index 4ba7ffea85..b986ba78d1 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go @@ -195,9 +195,9 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug } } } - + scopeName := fmt.Sprintf("cache_%s", pluginEntry.ID) resourceCache, err := NewResourceCache(ctx, pluginEntry.ID, p, p.GetConfig().Caching, - p.GetConfig().ReadRateLimiter, iCtx.MetricsScope().NewSubScope("cache")) + p.GetConfig().ReadRateLimiter, iCtx.MetricsScope().NewSubScope(scopeName)) if err != nil { return nil, err diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go index f2fa03104c..0f05c8df17 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go @@ -1,6 +1,7 @@ package webapi import ( + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -9,6 +10,12 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils/labeled" ) +// Global metrics cache to avoid recreating metrics for the same scope +var ( + metricsCache = make(map[string]*Metrics) + metricsMutex sync.RWMutex +) + type Metrics struct { Scope promutils.Scope ResourceReleased labeled.Counter @@ -25,7 +32,26 @@ var ( ) func newMetrics(scope promutils.Scope) Metrics { - return Metrics{ + scopeName := scope.CurrentScope() + + // Check if we already have metrics for this scope + metricsMutex.RLock() + if cachedMetrics, exists := metricsCache[scopeName]; exists { + defer metricsMutex.RUnlock() + return *cachedMetrics + } + metricsMutex.RUnlock() + + // Create new metrics and store globally + metricsMutex.Lock() + defer metricsMutex.Unlock() + + // Double-check in case another goroutine created metrics while we were acquiring the lock + if cachedMetrics, exists := metricsCache[scopeName]; exists { + return *cachedMetrics + } + + newMetrics := Metrics{ Scope: scope, ResourceReleased: labeled.NewCounter("resource_release_success", "Resource allocation token released", scope, labeled.EmitUnlabeledMetric), @@ -42,4 +68,7 @@ func newMetrics(scope promutils.Scope) Metrics { FailedUnmarshalState: labeled.NewCounter("unmarshal_state_failed", "Failed to unmarshal state", scope, labeled.EmitUnlabeledMetric), } + + metricsCache[scopeName] = &newMetrics + return newMetrics } diff --git a/flyteplugins/go/tasks/pluginmachinery/registry.go b/flyteplugins/go/tasks/pluginmachinery/registry.go index 39a241a50e..3975c5397f 100644 --- a/flyteplugins/go/tasks/pluginmachinery/registry.go +++ b/flyteplugins/go/tasks/pluginmachinery/registry.go @@ -11,14 +11,27 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) +const defaultPluginBufferSize = 100 + +type PluginInfo struct { + VersionedTaskType string + DeploymentID string +} + type taskPluginRegistry struct { m sync.Mutex k8sPlugin []k8s.PluginEntry corePlugin []core.PluginEntry + connectorCorePlugin map[string]map[string]core.PluginEntry + pluginChan chan PluginInfo } // A singleton variable that maintains a registry of all plugins. The framework uses this to access all plugins -var pluginRegistry = &taskPluginRegistry{} +var pluginRegistry = &taskPluginRegistry{ + corePlugin: []core.PluginEntry{}, + connectorCorePlugin: make(map[string]map[string]core.PluginEntry), + pluginChan: make(chan PluginInfo, defaultPluginBufferSize), +} func PluginRegistry() TaskPluginRegistry { return pluginRegistry @@ -43,6 +56,70 @@ func (p *taskPluginRegistry) RegisterRemotePlugin(info webapi.PluginEntry) { p.corePlugin = append(p.corePlugin, internalRemote.CreateRemotePlugin(info)) } +// RegisterConnectorCorePlugin registers a core plugin for a specific connector deployment +func (p *taskPluginRegistry) RegisterConnectorCorePlugin(info webapi.PluginEntry, deploymentID string) { + ctx := context.Background() + if info.ID == "" { + logger.Panicf(ctx, "ID is required attribute for connector core plugin") + } + + if len(info.SupportedTaskTypes) == 0 { + logger.Panicf(ctx, "Plugin should be registered to handle at least one task type") + } + + p.m.Lock() + defer p.m.Unlock() + + if p.connectorCorePlugin == nil { + p.connectorCorePlugin = make(map[string]map[string]core.PluginEntry) + } + + if p.connectorCorePlugin[info.ID] == nil { + p.connectorCorePlugin[info.ID] = make(map[string]core.PluginEntry) + } + + p.connectorCorePlugin[info.ID][deploymentID] = internalRemote.CreateRemotePlugin(info) +} + +// GetConnectorCorePlugin returns a specific connector core plugin for a task type and deployment ID +func (p *taskPluginRegistry) GetConnectorCorePlugin(taskType string, deploymentID string) (core.PluginEntry, bool) { + p.m.Lock() + defer p.m.Unlock() + + if p.connectorCorePlugin == nil { + return core.PluginEntry{}, false + } + + if plugins, exists := p.connectorCorePlugin[taskType]; exists { + if plugin, exists := plugins[deploymentID]; exists { + return plugin, true + } + } + + return core.PluginEntry{}, false +} + +func (p *taskPluginRegistry) IsConnectorCorePluginRegistered(taskType string, deploymentID string) bool { + p.m.Lock() + defer p.m.Unlock() + + if p.connectorCorePlugin == nil { + return false + } + + if plugins, exists := p.connectorCorePlugin[taskType]; exists { + if _, exists := plugins[deploymentID]; exists { + return true + } + } + + return false +} + +func (p *taskPluginRegistry) GetPluginChan() chan PluginInfo { + return p.pluginChan +} + func CreateRemotePlugin(pluginEntry webapi.PluginEntry) core.PluginEntry { return internalRemote.CreateRemotePlugin(pluginEntry) } @@ -105,6 +182,10 @@ type TaskPluginRegistry interface { RegisterK8sPlugin(info k8s.PluginEntry) RegisterCorePlugin(info core.PluginEntry) RegisterRemotePlugin(info webapi.PluginEntry) + RegisterConnectorCorePlugin(info webapi.PluginEntry, deploymentID string) + GetConnectorCorePlugin(taskType string, deploymentID string) (core.PluginEntry, bool) + IsConnectorCorePluginRegistered(taskType string, deploymentID string) bool GetCorePlugins() []core.PluginEntry GetK8sPlugins() []k8s.PluginEntry + GetPluginChan() chan PluginInfo } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/client.go b/flyteplugins/go/tasks/plugins/webapi/connector/client.go index 8b70102c8f..3b1d414095 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/client.go @@ -3,6 +3,7 @@ package connector import ( "context" "crypto/x509" + "fmt" "strings" "golang.org/x/exp/maps" @@ -15,11 +16,13 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" ) const defaultTaskTypeVersion = 0 +const defaultDeploymentID = "default" type Connector struct { // IsSync indicates whether this connector is a sync connector. Sync connectors are expected to return their @@ -91,16 +94,21 @@ func getFinalContext(ctx context.Context, operation string, connector *Deploymen return context.WithTimeout(ctx, timeout) } -func getConnectorRegistry(ctx context.Context, cs *ClientSet) Registry { - newConnectorRegistry := make(Registry) +func watchConnectors(ctx context.Context, cs *ClientSet) { cfg := GetConfig() + var connectorDeploymentIDs []string var connectorDeployments []*Deployment + // Merge DefaultConnector (if endpoint is not empty) if len(cfg.DefaultConnector.Endpoint) != 0 { - connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector) + connectorDeploymentIDs = append(connectorDeploymentIDs, defaultDeploymentID) + connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector) } + connectorDeploymentIDs = append(connectorDeploymentIDs, maps.Keys(cfg.ConnectorDeployments)...) connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...) - for _, connectorDeployment := range connectorDeployments { + + for idx, connectorDeployment := range connectorDeployments { + deploymentID := connectorDeploymentIDs[idx] client, ok := cs.connectorMetadataClients[connectorDeployment.Endpoint] if !ok { logger.Warningf(ctx, "Connector client not found in the clientSet for the endpoint: %v", connectorDeployment.Endpoint) @@ -128,45 +136,40 @@ func getConnectorRegistry(ctx context.Context, cs *ClientSet) Registry { continue } + // If a connector's support task type plugin was not registered yet, we should do registration connectorSupportedTaskCategories := make(map[string]struct{}) for _, connector := range res.GetAgents() { deprecatedSupportedTaskTypes := connector.GetSupportedTaskTypes() + supportedTaskCategories := connector.GetSupportedTaskCategories() + // Process deprecated supported task types for _, supportedTaskType := range deprecatedSupportedTaskTypes { - connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: connector.GetIsSync()} - newConnectorRegistry[supportedTaskType] = map[int32]*Connector{defaultTaskTypeVersion: connector} - connectorSupportedTaskCategories[supportedTaskType] = struct{}{} + versionedTaskType := createOrUpdatePlugin(ctx, supportedTaskType, defaultTaskTypeVersion, deploymentID, connectorDeployment, cs) + connectorSupportedTaskCategories[versionedTaskType] = struct{}{} } - - supportedTaskCategories := connector.GetSupportedTaskCategories() + // Process supported task categories for _, supportedCategory := range supportedTaskCategories { - connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: connector.GetIsSync()} - supportedCategoryName := supportedCategory.GetName() - newConnectorRegistry[supportedCategoryName] = map[int32]*Connector{supportedCategory.GetVersion(): connector} - connectorSupportedTaskCategories[supportedCategoryName] = struct{}{} + if supportedCategory.Version != defaultTaskTypeVersion { + versionedTaskType := createOrUpdatePlugin(ctx, supportedCategory.Name, supportedCategory.Version, deploymentID, connectorDeployment, cs) + connectorSupportedTaskCategories[versionedTaskType] = struct{}{} + } } } logger.Infof(ctx, "ConnectorDeployment [%v] supports the following task types: [%v]", connectorDeployment.Endpoint, strings.Join(maps.Keys(connectorSupportedTaskCategories), ", ")) } - - // Always replace the connector registry with the settings defined in the configuration + // always overwrite with connectorForTaskTypes config for taskType, connectorDeploymentID := range cfg.ConnectorForTaskTypes { - if connectorDeployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok { - connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: false} - newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector} + if deployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok { + createOrUpdatePlugin(ctx, taskType, defaultTaskTypeVersion, connectorDeploymentID, deployment, cs) } } - // Ensure that the old configuration is backward compatible for _, taskType := range cfg.SupportedTaskTypes { - if _, ok := newConnectorRegistry[taskType]; !ok { - connector := &Connector{ConnectorDeployment: &cfg.DefaultConnector, IsSync: false} - newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector} + versionedTaskType := fmt.Sprintf("%s_%d", taskType, defaultTaskTypeVersion) + if ok := pluginmachinery.PluginRegistry().IsConnectorCorePluginRegistered(versionedTaskType, defaultDeploymentID); !ok { + createOrUpdatePlugin(ctx, taskType, defaultTaskTypeVersion, defaultDeploymentID, &cfg.DefaultConnector, cs) } } - - logger.Infof(ctx, "ConnectorDeployments support the following task types: [%v]", strings.Join(maps.Keys(newConnectorRegistry), ", ")) - return newConnectorRegistry } func getConnectorClientSets(ctx context.Context) *ClientSet { @@ -176,13 +179,18 @@ func getConnectorClientSets(ctx context.Context) *ClientSet { connectorMetadataClients: make(map[string]service.AgentMetadataServiceClient), } - var connectorDeployments []*Deployment + connectorDeployments := make(map[string]*Deployment) cfg := GetConfig() + // Merge ConnectorDeployments + for key, deployment := range cfg.ConnectorDeployments { + connectorDeployments[key] = deployment + } + + // Merge DefaultConnector (if endpoint is not empty) if len(cfg.DefaultConnector.Endpoint) != 0 { - connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector) + connectorDeployments["default"] = &cfg.DefaultConnector } - connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...) for _, connectorDeployment := range connectorDeployments { if _, ok := clientSet.connectorMetadataClients[connectorDeployment.Endpoint]; ok { logger.Infof(ctx, "Connector client already initialized for [%v]", connectorDeployment.Endpoint) diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/client_test.go b/flyteplugins/go/tasks/plugins/webapi/connector/client_test.go index 9015c5e43e..e7ec33e529 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/client_test.go @@ -83,8 +83,8 @@ func TestAgentForTaskTypesAlwaysOverwrite(t *testing.T) { cs.connectorMetadataClients[deploymentZ.Endpoint] = mockClientForDeploymentZ // while auto-discovery execute in getAgentRegistry function, the deployment of task1 will be amended to deploymentZ // but the always-overwrite policy will overwrite deployment of task1 back to deploymentX according to cfg.AgentForTaskTypes - registry := getConnectorRegistry(ctx, cs) - finalDeployment := registry["task1"][defaultTaskTypeVersion].ConnectorDeployment - expectedDeployment := &deploymentX - assert.Equal(t, finalDeployment, expectedDeployment) + // registry := getConnectorRegistry(ctx, cs) + // finalDeployment := registry["task1"][defaultTaskTypeVersion].ConnectorDeployment + // expectedDeployment := &deploymentX + // assert.Equal(t, finalDeployment, expectedDeployment) } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/config.go b/flyteplugins/go/tasks/plugins/webapi/connector/config.go index 7fd9ec0ebc..82edd44ef9 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/config.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/config.go @@ -93,6 +93,9 @@ type Deployment struct { // DefaultTimeout gives the default RPC timeout if a more specific one is not defined in Timeouts; if neither DefaultTimeout nor Timeouts is defined for an operation, RPC timeout will not be enforced DefaultTimeout config.Duration `json:"defaultTimeout"` + + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` } func GetConfig() *Config { diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go index dc2c858e5d..9b065f4ddc 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go @@ -254,9 +254,6 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { func newMockAsyncConnectorPlugin() webapi.PluginEntry { asyncAgentClient := new(agentMocks.AsyncAgentServiceClient) - connectorRegistry := Registry{ - "spark": {defaultTaskTypeVersion: {ConnectorDeployment: &Deployment{Endpoint: defaultConnectorEndpoint}, IsSync: false}}, - } mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool { expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"} @@ -289,16 +286,12 @@ func newMockAsyncConnectorPlugin() webapi.PluginEntry { defaultConnectorEndpoint: asyncAgentClient, }, }, - registry: connectorRegistry, }, nil }, } } func newMockSyncConnectorPlugin() webapi.PluginEntry { - agentRegistry := Registry{ - "openai": {defaultTaskTypeVersion: {ConnectorDeployment: &Deployment{Endpoint: defaultConnectorEndpoint}, IsSync: true}}, - } syncAgentClient := new(agentMocks.SyncAgentServiceClient) output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) @@ -334,7 +327,6 @@ func newMockSyncConnectorPlugin() webapi.PluginEntry { defaultConnectorEndpoint: syncAgentClient, }, }, - registry: agentRegistry, }, nil }, } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go index 089ce64f5a..9579740ab5 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go @@ -4,11 +4,9 @@ import ( "context" "encoding/gob" "fmt" - "slices" "sync" "time" - "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/structpb" "k8s.io/apimachinery/pkg/util/wait" @@ -26,35 +24,11 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) -const ID = "connector-service" - -type ConnectorService struct { - mu sync.RWMutex - supportedTaskTypes []string - CorePlugin core.Plugin -} - -// ContainTaskType check if connector supports this task type. -func (p *ConnectorService) ContainTaskType(taskType string) bool { - p.mu.RLock() - defer p.mu.RUnlock() - return slices.Contains(p.supportedTaskTypes, taskType) -} - -// SetSupportedTaskType set supportTaskType in the connector service. -func (p *ConnectorService) SetSupportedTaskType(taskTypes []string) { - p.mu.Lock() - defer p.mu.Unlock() - p.supportedTaskTypes = taskTypes -} - -type Registry map[string]map[int32]*Connector // map[taskTypeName][taskTypeVersion] => Connector - type Plugin struct { metricScope promutils.Scope cfg *Config cs *ClientSet - registry Registry + deployment Connector mu sync.RWMutex } @@ -80,14 +54,56 @@ type ResourceMetaWrapper struct { TaskCategory admin.TaskCategory } -func (p *Plugin) setRegistry(r Registry) { - p.mu.Lock() - defer p.mu.Unlock() - p.registry = r -} - func (p *Plugin) GetConfig() webapi.PluginConfig { - return GetConfig().WebAPI + // Return default config if deployment is nil + if p.deployment.ConnectorDeployment == nil { + return p.cfg.WebAPI + } + + // Create a new config object by copying deployment's config + config := p.deployment.ConnectorDeployment.WebAPI + + // Check if ResourceQuotas is nil + if config.ResourceQuotas == nil { + config.ResourceQuotas = p.cfg.WebAPI.ResourceQuotas + } + + // Check ReadRateLimiter values individually + if config.ReadRateLimiter.QPS == 0 { + config.ReadRateLimiter.QPS = p.cfg.WebAPI.ReadRateLimiter.QPS + } + if config.ReadRateLimiter.Burst == 0 { + config.ReadRateLimiter.Burst = p.cfg.WebAPI.ReadRateLimiter.Burst + } + + // Check WriteRateLimiter values individually + if config.WriteRateLimiter.QPS == 0 { + config.WriteRateLimiter.QPS = p.cfg.WebAPI.WriteRateLimiter.QPS + } + if config.WriteRateLimiter.Burst == 0 { + config.WriteRateLimiter.Burst = p.cfg.WebAPI.WriteRateLimiter.Burst + } + + // Check Caching configuration values individually + if config.Caching.ResyncInterval.Duration == time.Duration(0) { + config.Caching.ResyncInterval = p.cfg.WebAPI.Caching.ResyncInterval + } + if config.Caching.Size == 0 { + config.Caching.Size = p.cfg.WebAPI.Caching.Size + } + if config.Caching.Workers == 0 { + config.Caching.Workers = p.cfg.WebAPI.Caching.Workers + } + if config.Caching.MaxSystemFailures == 0 { + config.Caching.MaxSystemFailures = p.cfg.WebAPI.Caching.MaxSystemFailures + } + + // Check if ResourceMeta is nil + if config.ResourceMeta == nil { + config.ResourceMeta = p.cfg.WebAPI.ResourceMeta + } + + return config } func (p *Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( @@ -99,6 +115,7 @@ func (p *Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionC func (p *Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, webapi.Resource, error) { + logger.Debug(ctx, "create task for deployment %s", p.deployment.ConnectorDeployment.Endpoint) taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to read task template with error: %v", err) @@ -130,11 +147,11 @@ func (p *Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContext outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() taskCategory := admin.TaskCategory{Name: taskTemplate.GetType(), Version: taskTemplate.GetTaskTypeVersion()} - connector, isSync := p.getFinalConnector(&taskCategory, p.cfg) + connector := p.deployment.ConnectorDeployment taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - if isSync { + if p.deployment.IsSync { finalCtx, cancel := getFinalContext(ctx, "ExecuteTaskSync", connector) defer cancel() client, err := p.getSyncConnectorClient(ctx, connector) @@ -229,7 +246,7 @@ func (p *Plugin) ExecuteTaskSync( func (p *Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - connector, _ := p.getFinalConnector(&metadata.TaskCategory, p.cfg) + connector := p.deployment.ConnectorDeployment client, err := p.getAsyncConnectorClient(ctx, connector) if err != nil { @@ -264,7 +281,7 @@ func (p *Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - connector, _ := p.getFinalConnector(&metadata.TaskCategory, p.cfg) + connector := p.deployment.ConnectorDeployment client, err := p.getAsyncConnectorClient(ctx, connector) if err != nil { @@ -368,25 +385,22 @@ func (p *Plugin) getAsyncConnectorClient(ctx context.Context, connector *Deploym return client, nil } -func (p *Plugin) watchConnectors(ctx context.Context, connectorService *ConnectorService) { +// UpdateDeployment updates the deployment configuration for the plugin. +// This method is thread-safe and can be called concurrently. +func (p *Plugin) UpdateDeployment(deployment Connector) { + p.mu.Lock() + defer p.mu.Unlock() + p.deployment = deployment +} + +func WatchConnectors(ctx context.Context) { + cfg := GetConfig() go wait.Until(func() { childCtx, cancel := context.WithCancel(ctx) defer cancel() clientSet := getConnectorClientSets(childCtx) - connectorRegistry := getConnectorRegistry(childCtx, clientSet) - p.setRegistry(connectorRegistry) - connectorService.SetSupportedTaskType(maps.Keys(connectorRegistry)) - }, p.cfg.PollInterval.Duration, ctx.Done()) -} - -func (p *Plugin) getFinalConnector(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { - p.mu.RLock() - defer p.mu.RUnlock() - - if connector, exists := p.registry[taskCategory.GetName()][taskCategory.GetVersion()]; exists { - return connector.ConnectorDeployment, connector.IsSync - } - return &cfg.DefaultConnector, false + watchConnectors(ctx, clientSet) + }, cfg.PollInterval.Duration, ctx.Done()) } func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error { @@ -425,35 +439,48 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } } -func newConnectorPlugin(connectorService *ConnectorService) webapi.PluginEntry { - ctx := context.Background() - gob.Register(ResourceMetaWrapper{}) - gob.Register(ResourceWrapper{}) - - clientSet := getConnectorClientSets(ctx) - connectorRegistry := getConnectorRegistry(ctx, clientSet) - supportedTaskTypes := maps.Keys(connectorRegistry) - connectorService.SetSupportedTaskType(supportedTaskTypes) - +func createPluginEntry(taskType core.TaskType, taskVersion int32, deployment Deployment, clientSet *ClientSet) webapi.PluginEntry { + versionedTaskType := fmt.Sprintf("%s_%d", taskType, taskVersion) plugin := &Plugin{ - metricScope: promutils.NewScope("connector_plugin"), + metricScope: promutils.NewScope(versionedTaskType), cfg: GetConfig(), cs: clientSet, - registry: connectorRegistry, + deployment: Connector{IsSync: false, ConnectorDeployment: &deployment}, } - plugin.watchConnectors(ctx, connectorService) - return webapi.PluginEntry{ - ID: ID, - SupportedTaskTypes: supportedTaskTypes, + ID: versionedTaskType, + SupportedTaskTypes: []core.TaskType{taskType}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { return plugin, nil }, } } -func RegisterConnectorPlugin(connectorService *ConnectorService) { +// createOrUpdatePlugin handles the registration or update of a task type plugin +func createOrUpdatePlugin(ctx context.Context, taskName string, taskVersion int32, deploymentID string, connectorDeployment *Deployment, cs *ClientSet) string { + versionedTaskType := fmt.Sprintf("%s_%d", taskName, taskVersion) + + // Register core plugin if not registered + if !pluginmachinery.PluginRegistry().IsConnectorCorePluginRegistered(versionedTaskType, deploymentID) { + plugin := createPluginEntry(taskName, taskVersion, *connectorDeployment, cs) + pluginmachinery.PluginRegistry().RegisterConnectorCorePlugin(plugin, deploymentID) + } + + // send message to Flyte Propeller TaskHandler to register or update plugin + select { + case pluginmachinery.PluginRegistry().GetPluginChan() <- pluginmachinery.PluginInfo{ + VersionedTaskType: versionedTaskType, + DeploymentID: deploymentID, + }: + default: + logger.Errorf(context.Background(), "Failed to create/update plugin for task type %s: channel is full", versionedTaskType) + } + return versionedTaskType +} + +func RegisterConnectorPlugin() { + ctx := context.Background() gob.Register(ResourceMetaWrapper{}) gob.Register(ResourceWrapper{}) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newConnectorPlugin(connectorService)) + WatchConnectors(ctx) } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go index 884e9069b1..cd344b9154 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go @@ -7,13 +7,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/structpb" agentMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" @@ -35,12 +33,11 @@ func TestPlugin(t *testing.T) { cfg.ConnectorDeployments = map[string]*Deployment{"spark_connector": {Endpoint: "localhost:80"}} cfg.ConnectorForTaskTypes = map[string]string{"spark": "spark_connector", "bar": "bar_connector"} - connector := &Connector{ConnectorDeployment: &Deployment{Endpoint: "localhost:80"}} - connectorRegistry := Registry{"spark": {defaultTaskTypeVersion: connector}} + // connector := &Connector{ConnectorDeployment: &Deployment{Endpoint: "localhost:80"}} + plugin := Plugin{ metricScope: fakeSetupContext.MetricsScope(), cfg: GetConfig(), - registry: connectorRegistry, } t.Run("get config", func(t *testing.T) { err := SetConfig(&cfg) @@ -61,17 +58,17 @@ func TestPlugin(t *testing.T) { assert.NotNil(t, p.PluginLoader) }) - t.Run("test getFinalConnector", func(t *testing.T) { - spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion} - foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion} - bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion} - connectorDeployment, _ := plugin.getFinalConnector(spark, &cfg) - assert.Equal(t, connectorDeployment.Endpoint, "localhost:80") - connectorDeployment, _ = plugin.getFinalConnector(foo, &cfg) - assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) - connectorDeployment, _ = plugin.getFinalConnector(bar, &cfg) - assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) - }) + // t.Run("test getFinalConnector", func(t *testing.T) { + // spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion} + // foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion} + // bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion} + // connectorDeployment, _ := plugin.getFinalConnector(spark, &cfg) + // assert.Equal(t, connectorDeployment.Endpoint, "localhost:80") + // connectorDeployment, _ = plugin.getFinalConnector(foo, &cfg) + // assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) + // connectorDeployment, _ = plugin.getFinalConnector(bar, &cfg) + // assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) + // }) t.Run("test getFinalTimeout", func(t *testing.T) { timeout := getFinalTimeout("CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) @@ -330,28 +327,28 @@ func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { return mockMetadataServiceClient } -func TestInitializeConnectorRegistry(t *testing.T) { - connectorClients := make(map[string]service.AsyncAgentServiceClient) - connectorMetadataClients := make(map[string]service.AgentMetadataServiceClient) - connectorClients[defaultConnectorEndpoint] = &agentMocks.AsyncAgentServiceClient{} - connectorMetadataClients[defaultConnectorEndpoint] = getMockMetadataServiceClient() - - cs := &ClientSet{ - asyncConnectorClients: connectorClients, - connectorMetadataClients: connectorMetadataClients, - } - - cfg := defaultConfig - cfg.ConnectorDeployments = map[string]*Deployment{"custom_connector": {Endpoint: defaultConnectorEndpoint}} - cfg.ConnectorForTaskTypes = map[string]string{"task1": "connector-deployment-1", "task2": "connector-deployment-2"} - err := SetConfig(&cfg) - assert.NoError(t, err) - - connectorRegistry := getConnectorRegistry(context.Background(), cs) - connectorRegistryKeys := maps.Keys(connectorRegistry) - expectedKeys := []string{"task1", "task2", "task3", "task_type_3", "task_type_4"} - - for _, key := range expectedKeys { - assert.Contains(t, connectorRegistryKeys, key) - } -} +// func TestInitializeConnectorRegistry(t *testing.T) { +// connectorClients := make(map[string]service.AsyncAgentServiceClient) +// connectorMetadataClients := make(map[string]service.AgentMetadataServiceClient) +// connectorClients[defaultConnectorEndpoint] = &agentMocks.AsyncAgentServiceClient{} +// connectorMetadataClients[defaultConnectorEndpoint] = getMockMetadataServiceClient() + +// cs := &ClientSet{ +// asyncConnectorClients: connectorClients, +// connectorMetadataClients: connectorMetadataClients, +// } + +// cfg := defaultConfig +// cfg.ConnectorDeployments = map[string]*Deployment{"custom_connector": {Endpoint: defaultConnectorEndpoint}} +// cfg.ConnectorForTaskTypes = map[string]string{"task1": "connector-deployment-1", "task2": "connector-deployment-2"} +// err := SetConfig(&cfg) +// assert.NoError(t, err) + +// connectorRegistry := getConnectorRegistry(context.Background(), cs) +// connectorRegistryKeys := maps.Keys(connectorRegistry) +// expectedKeys := []string{"task1", "task2", "task3", "task_type_3", "task_type_4"} + +// for _, key := range expectedKeys { +// assert.Contains(t, connectorRegistryKeys, key) +// } +// } diff --git a/flytepropeller/pkg/controller/nodes/task/cache.go b/flytepropeller/pkg/controller/nodes/task/cache.go index d408a5af85..43e187f391 100644 --- a/flytepropeller/pkg/controller/nodes/task/cache.go +++ b/flytepropeller/pkg/controller/nodes/task/cache.go @@ -37,18 +37,6 @@ func (t *Handler) GetCatalogKey(ctx context.Context, nCtx interfaces.NodeExecuti func (t *Handler) IsCacheable(ctx context.Context, nCtx interfaces.NodeExecutionContext) (bool, bool, error) { // check if plugin has caching disabled ttype := nCtx.TaskReader().GetTaskType() - ctx = contextutils.WithTaskType(ctx, ttype) - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) - if err != nil { - return false, false, errors2.Wrapf(errors2.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") - } - - checkCatalog := !p.GetProperties().DisableNodeLevelCaching - if !checkCatalog { - logger.Infof(ctx, "Node level caching is disabled. Skipping catalog read.") - return false, false, nil - } - // read task template taskTemplatePath, err := ioutils.GetTaskTemplatePath(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetDataDir()) if err != nil { @@ -61,6 +49,18 @@ func (t *Handler) IsCacheable(ctx context.Context, nCtx interfaces.NodeExecution logger.Errorf(ctx, "failed to read TaskTemplate, error :%s", err.Error()) return false, false, err } + tVersion := taskTemplate.GetTaskTypeVersion() + ctx = contextutils.WithTaskType(ctx, ttype) + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) + if err != nil { + return false, false, errors2.Wrapf(errors2.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") + } + + checkCatalog := !p.GetProperties().DisableNodeLevelCaching + if !checkCatalog { + logger.Infof(ctx, "Node level caching is disabled. Skipping catalog read.") + return false, false, nil + } return taskTemplate.GetMetadata().GetDiscoverable(), taskTemplate.GetMetadata().GetDiscoverable() && taskTemplate.GetMetadata().GetCacheSerializable(), nil } diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index c40d5bea01..199c070139 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -2,6 +2,7 @@ package task import ( "context" + "encoding/json" "fmt" "runtime/debug" "sync" @@ -235,31 +236,89 @@ type taskType = string type pluginID = string type Handler struct { - catalog catalog.Client - asyncCatalog catalog.AsyncClient - defaultPlugins map[pluginCore.TaskType]pluginCore.Plugin - pluginsForType map[pluginCore.TaskType]map[pluginID]pluginCore.Plugin - taskMetricsMap map[MetricKey]*taskMetrics + catalog catalog.Client + asyncCatalog catalog.AsyncClient + defaultPlugins map[pluginCore.TaskType]pluginCore.Plugin + connectorDeploymentsForType map[pluginCore.TaskType]map[string]pluginCore.Plugin + pluginsForType map[pluginCore.TaskType]map[pluginID]pluginCore.Plugin + taskMetricsMap map[MetricKey]*taskMetrics taskMetricsMapMutex sync.RWMutex - defaultPlugin pluginCore.Plugin - metrics *metrics - pluginRegistry PluginRegistryIface - kubeClient pluginCore.KubeClient - kubeClientset kubernetes.Interface - secretManager pluginCore.SecretManager - resourceManager resourcemanager.BaseResourceManager - cfg *config.Config - pluginScope promutils.Scope - eventConfig *controllerConfig.EventConfig - clusterID string - agentService *agent.AgentService - connectorService *connector.ConnectorService + defaultPlugin pluginCore.Plugin + metrics *metrics + pluginRegistry PluginRegistryIface + kubeClient pluginCore.KubeClient + kubeClientset kubernetes.Interface + secretManager pluginCore.SecretManager + resourceManager resourcemanager.BaseResourceManager + cfg *config.Config + pluginScope promutils.Scope + eventConfig *controllerConfig.EventConfig + clusterID string + agentService *agent.AgentService + mu sync.RWMutex } func (t *Handler) FinalizeRequired() bool { return true } +func (t *Handler) createResourceManagerAndSetupCtx(ctx context.Context, sCtx interfaces.SetupContext, taskType string) (*nameSpacedSetupCtx, error) { + tSCtx := t.newSetupContext(sCtx) + // Create a new base resource negotiator + resourceManagerConfig := rmConfig.GetConfig() + newResourceManagerBuilder, err := resourcemanager.GetResourceManagerBuilderByType(ctx, resourceManagerConfig.Type, t.metrics.scope) + if err != nil { + return nil, err + } + return t.createNameSpacedSetupCtx(tSCtx, newResourceManagerBuilder, taskType), nil +} + +func (t *Handler) createNameSpacedSetupCtx(tSCtx *setupContext, newResourceManagerBuilder resourcemanager.Builder, taskType string) *nameSpacedSetupCtx { + pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(taskType)) + sCtxFinal := newNameSpacedSetupCtx( + tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), taskType) + return &sCtxFinal +} + +// getConnectorPlugin retrieves a plugin from connectorDeploymentsForType map with proper locking +func (t *Handler) getConnectorPlugin(taskType pluginCore.TaskType, deploymentID string) (pluginCore.Plugin, bool) { + t.mu.RLock() + deploymentMap, exists := t.connectorDeploymentsForType[taskType] + t.mu.RUnlock() + + if !exists { + return nil, false + } + + t.mu.RLock() + plugin, pluginExists := deploymentMap[deploymentID] + t.mu.RUnlock() + + return plugin, pluginExists +} + +func (t *Handler) registerConnectorPlugin(corePlugin pluginCore.Plugin, deploymentID string) { + t.mu.Lock() + t.defaultPlugins[corePlugin.GetID()] = corePlugin + if t.connectorDeploymentsForType[corePlugin.GetID()] == nil { + t.connectorDeploymentsForType[corePlugin.GetID()] = make(map[string]pluginCore.Plugin) + } + t.connectorDeploymentsForType[corePlugin.GetID()][deploymentID] = corePlugin + t.mu.Unlock() +} + +func (t *Handler) isConnectorPluginRegistered(versionedTaskType string, deploymentID string) bool { + t.mu.RLock() + defer t.mu.RUnlock() + pluginsMap, ok := t.connectorDeploymentsForType[versionedTaskType] + if !ok { + return false + } + + _, ok = pluginsMap[deploymentID] + return ok +} + func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { if t.defaultPlugin != nil { logger.Errorf(ctx, "cannot set plugin [%s] as default as plugin [%s] is already configured as default", p.GetID(), t.defaultPlugin.GetID()) @@ -270,6 +329,54 @@ func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { return nil } +func (t *Handler) watchPlugins(ctx context.Context, sCtx interfaces.SetupContext) { + defer func() { + if r := recover(); r != nil { + logger.Errorf(ctx, "WatchPlugins goroutine panicked: %v", r) + logger.Errorf(ctx, "Stack trace: %s", debug.Stack()) + } + }() + + for { + select { + case info := <-pluginMachinery.PluginRegistry().GetPluginChan(): + // If plugin not registered yet, do registeration + if !t.isConnectorPluginRegistered(info.VersionedTaskType, info.DeploymentID){ + sCtxFinal, err := t.createResourceManagerAndSetupCtx(ctx, sCtx, info.VersionedTaskType) + if err != nil { + logger.Errorf(ctx, "Failed to create resource manager and setup context for task type [%s]: %v", info.VersionedTaskType, err) + continue + } + // register core plugin + if cpe, ok := pluginMachinery.PluginRegistry().GetConnectorCorePlugin(info.VersionedTaskType, info.DeploymentID); ok { + cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, cpe) + if err != nil { + logger.Errorf(ctx, "Failed to load plugin of task type [%s]: %v", info.VersionedTaskType, err) + continue + } + // register the plugin to task handler local plugin registry + t.registerConnectorPlugin(cp, info.DeploymentID) + logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] registered", info.VersionedTaskType, info.DeploymentID) + } + // If plugin already registered, update deployment + } else { + plugin, pluginExists := t.getConnectorPlugin(info.VersionedTaskType, info.DeploymentID) + if pluginExists { + t.mu.Lock() + t.defaultPlugins[info.VersionedTaskType] = plugin + t.mu.Unlock() + logger.Infof(ctx, "The default plugin for TaskType [%s] has been updated to Deployment ID [%s]", info.VersionedTaskType, info.DeploymentID) + } else { + logger.Warningf(ctx, "Plugin for TaskType [%s] and deployment ID [%s] not found", info.VersionedTaskType, info.DeploymentID) + } + } + case <-ctx.Done(): + logger.Infof(ctx, "Plugin watcher stopped due to context cancellation") + return + } + } +} + func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error { tSCtx := t.newSetupContext(sCtx) @@ -280,10 +387,13 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return err } + // execute watcher to monitor plugins waiting for register/update from connector/agent plugin + go t.watchPlugins(ctx, sCtx) + once.Do(func() { // The agent service plugin is deprecated and will be removed in the future agent.RegisterAgentPlugin(t.agentService) - connector.RegisterConnectorPlugin(t.connectorService) + connector.RegisterConnectorPlugin() }) // Create the resource negotiator here @@ -300,9 +410,7 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error for _, p := range enabledPlugins { // create a new resource registrar proxy for each plugin, and pass it into the plugin's LoadPlugin() via a setup context - pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(p.ID)) - sCtxFinal := newNameSpacedSetupCtx( - tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), p.ID) + sCtxFinal := t.createNameSpacedSetupCtx(tSCtx, newResourceManagerBuilder, p.ID) logger.Infof(ctx, "Loading Plugin [%s] ENABLED", p.ID) cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, p) @@ -310,10 +418,6 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return regErrors.Wrapf(err, "failed to load plugin - %s", p.ID) } - if cp.GetID() == connector.ID { - t.connectorService.CorePlugin = cp - } - // The agent service plugin is deprecated and will be removed in the future if cp.GetID() == agent.ID { t.agentService.CorePlugin = cp @@ -380,12 +484,15 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return nil } -func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfig v1alpha1.ExecutionConfig) (pluginCore.Plugin, error) { +func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfig v1alpha1.ExecutionConfig, taskVersion int32) (pluginCore.Plugin, error) { // If the workflow specifies plugin overrides, check to see if any of the specified plugins for that type are // registered in this deployment of flytepropeller. if len(executionConfig.TaskPluginImpls[ttype].PluginIDs) > 0 { - if len(t.pluginsForType[ttype]) > 0 { - pluginsForType := t.pluginsForType[ttype] + t.mu.RLock() + pluginsForType, exists := t.pluginsForType[ttype] + t.mu.RUnlock() + + if exists && len(pluginsForType) > 0 { for _, pluginImplID := range executionConfig.TaskPluginImpls[ttype].PluginIDs { pluginImpl := pluginsForType[pluginImplID] if pluginImpl != nil { @@ -402,14 +509,22 @@ func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfi } } + t.mu.RLock() p, ok := t.defaultPlugins[ttype] + t.mu.RUnlock() if ok { logger.Debugf(ctx, "Plugin [%s] resolved for Handler type [%s]", p.GetID(), ttype) return p, nil } - if t.connectorService != nil && t.connectorService.ContainTaskType(ttype) { - return t.connectorService.CorePlugin, nil + // check if the task type is a connector/agent task type + versionedTaskType := fmt.Sprintf("%s_%d", ttype, taskVersion) + t.mu.RLock() + p, ok = t.defaultPlugins[versionedTaskType] + t.mu.RUnlock() + if ok { + logger.Debugf(ctx, "Plugin [%s] resolved for versioned task type [%s]", p.GetID(), versionedTaskType) + return p, nil } // The agent service plugin is deprecated and will be removed in the future @@ -670,8 +785,16 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex defer span.End() ttype := nCtx.TaskReader().GetTaskType() + taskTemplate, err := nCtx.TaskReader().Read(ctx) + if err != nil { + return handler.UnknownTransition, errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to read task template") + } + tVersion := taskTemplate.GetTaskTypeVersion() ctx = contextutils.WithTaskType(ctx, ttype) - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) + jsonBytes, _ := json.MarshalIndent(p, "", " ") + logger.Debug(ctx, "The task type is [%s]", string(ttype)) + logger.Debug(ctx, "The deployment config is [%s]", string(jsonBytes)) if err != nil { return handler.UnknownTransition, errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") } @@ -920,7 +1043,12 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext } ttype := nCtx.TaskReader().GetTaskType() - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + taskTemplate, err := nCtx.TaskReader().Read(ctx) + if err != nil { + return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to read task template") + } + tVersion := taskTemplate.GetTaskTypeVersion() + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) if err != nil { return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") } @@ -1015,7 +1143,12 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext func (t Handler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { logger.Debugf(ctx, "Finalize invoked.") ttype := nCtx.TaskReader().GetTaskType() - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + taskTemplate, err := nCtx.TaskReader().Read(ctx) + if err != nil { + return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to read task template") + } + tVersion := taskTemplate.GetTaskTypeVersion() + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) if err != nil { return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") } @@ -1057,6 +1190,7 @@ func New(ctx context.Context, kubeClient executors.Client, kubeClientset kuberne return &Handler{ pluginRegistry: pluginMachinery.PluginRegistry(), defaultPlugins: make(map[pluginCore.TaskType]pluginCore.Plugin), + connectorDeploymentsForType: make(map[pluginCore.TaskType]map[string]pluginCore.Plugin), pluginsForType: make(map[pluginCore.TaskType]map[pluginID]pluginCore.Plugin), taskMetricsMap: make(map[MetricKey]*taskMetrics), metrics: &metrics{ @@ -1077,6 +1211,5 @@ func New(ctx context.Context, kubeClient executors.Client, kubeClientset kuberne eventConfig: eventConfig, clusterID: clusterID, agentService: &agent.AgentService{}, - connectorService: &connector.ConnectorService{}, }, nil } diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go index 8f51627fb8..af1ff067c1 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -29,7 +29,6 @@ import ( pluginK8s "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" pluginK8sMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/connector" eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" @@ -353,9 +352,8 @@ func Test_task_ResolvePlugin(t *testing.T) { defaultPlugin: tt.fields.defaultPlugin, pluginsForType: tt.fields.pluginsForType, agentService: &agent.AgentService{}, - connectorService: &connector.ConnectorService{}, } - got, err := tk.ResolvePlugin(context.TODO(), tt.args.ttype, tt.args.executionConfig) + got, err := tk.ResolvePlugin(context.TODO(), tt.args.ttype, tt.args.executionConfig, 0) if (err != nil) != tt.wantErr { t.Errorf("Handler.ResolvePlugin() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/flytestdlib/cache/in_memory_auto_refresh.go b/flytestdlib/cache/in_memory_auto_refresh.go index c3072f5b26..35501098b3 100644 --- a/flytestdlib/cache/in_memory_auto_refresh.go +++ b/flytestdlib/cache/in_memory_auto_refresh.go @@ -19,6 +19,12 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) +// Global metrics cache to avoid recreating metrics for the same scope +var ( + metricsCache = make(map[string]*metrics) + metricsMutex sync.RWMutex +) + type metrics struct { SyncErrors prometheus.Counter Evictions prometheus.Counter @@ -30,7 +36,20 @@ type metrics struct { } func newMetrics(scope promutils.Scope) metrics { - return metrics{ + scopeName := scope.CurrentScope() + // Check if we already have metrics for this scope + metricsMutex.RLock() + if cachedMetrics, exists := metricsCache[scopeName]; exists { + defer metricsMutex.RUnlock() + return *cachedMetrics + } + metricsMutex.RUnlock() + + // Create new metrics and store globally + metricsMutex.Lock() + defer metricsMutex.Unlock() + + newMetrics := metrics{ SyncErrors: scope.MustNewCounter("sync_errors", "Counter for sync errors."), Evictions: scope.MustNewCounter("lru_evictions", "Counter for evictions from LRU."), SyncLatency: scope.MustNewStopWatch("latency", "Latency for sync operations.", time.Millisecond), @@ -39,6 +58,9 @@ func newMetrics(scope promutils.Scope) metrics { Size: scope.MustNewGauge("size", "Current size of the cache"), scope: scope, } + + metricsCache[scopeName] = &newMetrics + return newMetrics } func getEvictionFunction(counter prometheus.Counter) func(key interface{}, value interface{}) { @@ -150,7 +172,7 @@ func NewInMemoryAutoRefresh( toDelete: newSyncSet(), syncPeriod: resyncPeriod, workqueue: workqueue.NewRateLimitingQueueWithConfig(syncRateLimiter, workqueue.RateLimitingQueueConfig{ - Name: scope.CurrentScope(), + // Name: scope.CurrentScope(), Clock: opts.clock, }), clock: opts.clock,