Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion flyte-single-binary-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ tasks:
- container
- sidecar
- K8S-ARRAY
- connector-service
- echo
- agent-service
default-for-task-types:
- container: container
- container_array: K8S-ARRAY
Expand All @@ -68,6 +68,61 @@ plugins:
kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }}
cloudwatch-enabled: false
stackdriver-enabled: false
connector-service:
supportedTaskTypes:
- bigquery_query_job_task
- sensor
connectors:
my-test-connector1:
endpoint: "localhost:8001"
insecure: true
timeouts:
CreateTask: 5s
GetTask: 30s
DeleteTask: 30s
defaultTimeout: 30s
webApi:
caching:
resyncInterval: 60s
my-test-connector2:
endpoint: "localhost:8002"
insecure: true
timeouts:
CreateTask: 5s
GetTask: 30s
DeleteTask: 30s
defaultTimeout: 30s
webApi:
caching:
resyncInterval: 120s
defaultConnector:
endpoint: "localhost:8000"
webApi:
caching:
resyncInterval: 120s
connectorForTaskTypes:
- noop_async_agent_task: my-test-connector1
# agent-service:
# supportedTaskTypes:
# - bigquery_query_job_task
# - sensor
# - chatgpt
# agents:
# my-test-connector1:
# endpoint: "localhost:8000"
# insecure: true
# timeouts:
# CreateTask: 5s
# GetTask: 30s
# DeleteTask: 30s
# defaultTimeout: 30s
# defaultAgent:
# endpoint: "localhost:8000"
# webApi:
# caching:
# resyncInterval: 120s
# agentForTaskTypes:
# - noop_async_agent_task: my-test-connector1

database:
postgres:
Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/flyteorg/flyte/flytestdlib v0.0.0-00010101000000-000000000000
github.com/go-test/deep v1.0.7
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru v0.5.4
github.com/imdario/mergo v0.3.13
github.com/kubeflow/training-operator v1.8.0
Expand Down Expand Up @@ -85,7 +86,6 @@ require (
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug
}
}
}

scopeName := fmt.Sprintf("cache_%s", pluginEntry.ID)
resourceCache, err := NewResourceCache(ctx, pluginEntry.ID, p, p.GetConfig().Caching,
p.GetConfig().ReadRateLimiter, iCtx.MetricsScope().NewSubScope("cache"))
p.GetConfig().ReadRateLimiter, iCtx.MetricsScope().NewSubScope(scopeName))

if err != nil {
return nil, err
Expand Down
31 changes: 30 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package webapi

import (
"sync"
"time"

"github.com/prometheus/client_golang/prometheus"
Expand All @@ -9,6 +10,12 @@ import (
"github.com/flyteorg/flyte/flytestdlib/promutils/labeled"
)

// Global metrics cache to avoid recreating metrics for the same scope
var (
metricsCache = make(map[string]*Metrics)
metricsMutex sync.RWMutex
)

type Metrics struct {
Scope promutils.Scope
ResourceReleased labeled.Counter
Expand All @@ -25,7 +32,26 @@ var (
)

func newMetrics(scope promutils.Scope) Metrics {
return Metrics{
scopeName := scope.CurrentScope()

// Check if we already have metrics for this scope
metricsMutex.RLock()
if cachedMetrics, exists := metricsCache[scopeName]; exists {
defer metricsMutex.RUnlock()
return *cachedMetrics
}
metricsMutex.RUnlock()

// Create new metrics and store globally
metricsMutex.Lock()
defer metricsMutex.Unlock()

// Double-check in case another goroutine created metrics while we were acquiring the lock
if cachedMetrics, exists := metricsCache[scopeName]; exists {
return *cachedMetrics
}

newMetrics := Metrics{
Scope: scope,
ResourceReleased: labeled.NewCounter("resource_release_success",
"Resource allocation token released", scope, labeled.EmitUnlabeledMetric),
Expand All @@ -42,4 +68,7 @@ func newMetrics(scope promutils.Scope) Metrics {
FailedUnmarshalState: labeled.NewCounter("unmarshal_state_failed",
"Failed to unmarshal state", scope, labeled.EmitUnlabeledMetric),
}

metricsCache[scopeName] = &newMetrics
return newMetrics
}
83 changes: 82 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,27 @@ import (
"github.com/flyteorg/flyte/flytestdlib/logger"
)

const defaultPluginBufferSize = 100

type PluginInfo struct {
VersionedTaskType string
DeploymentID string
}

type taskPluginRegistry struct {
m sync.Mutex
k8sPlugin []k8s.PluginEntry
corePlugin []core.PluginEntry
connectorCorePlugin map[string]map[string]core.PluginEntry
pluginChan chan PluginInfo
}

// A singleton variable that maintains a registry of all plugins. The framework uses this to access all plugins
var pluginRegistry = &taskPluginRegistry{}
var pluginRegistry = &taskPluginRegistry{
corePlugin: []core.PluginEntry{},
connectorCorePlugin: make(map[string]map[string]core.PluginEntry),
pluginChan: make(chan PluginInfo, defaultPluginBufferSize),
}

func PluginRegistry() TaskPluginRegistry {
return pluginRegistry
Expand All @@ -43,6 +56,70 @@ func (p *taskPluginRegistry) RegisterRemotePlugin(info webapi.PluginEntry) {
p.corePlugin = append(p.corePlugin, internalRemote.CreateRemotePlugin(info))
}

// RegisterConnectorCorePlugin registers a core plugin for a specific connector deployment
func (p *taskPluginRegistry) RegisterConnectorCorePlugin(info webapi.PluginEntry, deploymentID string) {
ctx := context.Background()
if info.ID == "" {
logger.Panicf(ctx, "ID is required attribute for connector core plugin")
}

if len(info.SupportedTaskTypes) == 0 {
logger.Panicf(ctx, "Plugin should be registered to handle at least one task type")
}

p.m.Lock()
defer p.m.Unlock()

if p.connectorCorePlugin == nil {
p.connectorCorePlugin = make(map[string]map[string]core.PluginEntry)
}

if p.connectorCorePlugin[info.ID] == nil {
p.connectorCorePlugin[info.ID] = make(map[string]core.PluginEntry)
}

p.connectorCorePlugin[info.ID][deploymentID] = internalRemote.CreateRemotePlugin(info)
}

// GetConnectorCorePlugin returns a specific connector core plugin for a task type and deployment ID
func (p *taskPluginRegistry) GetConnectorCorePlugin(taskType string, deploymentID string) (core.PluginEntry, bool) {
p.m.Lock()
defer p.m.Unlock()

if p.connectorCorePlugin == nil {
return core.PluginEntry{}, false
}

if plugins, exists := p.connectorCorePlugin[taskType]; exists {
if plugin, exists := plugins[deploymentID]; exists {
return plugin, true
}
}

return core.PluginEntry{}, false
}

func (p *taskPluginRegistry) IsConnectorCorePluginRegistered(taskType string, deploymentID string) bool {
p.m.Lock()
defer p.m.Unlock()

if p.connectorCorePlugin == nil {
return false
}

if plugins, exists := p.connectorCorePlugin[taskType]; exists {
if _, exists := plugins[deploymentID]; exists {
return true
}
}

return false
}

func (p *taskPluginRegistry) GetPluginChan() chan PluginInfo {
return p.pluginChan
}

func CreateRemotePlugin(pluginEntry webapi.PluginEntry) core.PluginEntry {
return internalRemote.CreateRemotePlugin(pluginEntry)
}
Expand Down Expand Up @@ -105,6 +182,10 @@ type TaskPluginRegistry interface {
RegisterK8sPlugin(info k8s.PluginEntry)
RegisterCorePlugin(info core.PluginEntry)
RegisterRemotePlugin(info webapi.PluginEntry)
RegisterConnectorCorePlugin(info webapi.PluginEntry, deploymentID string)
GetConnectorCorePlugin(taskType string, deploymentID string) (core.PluginEntry, bool)
IsConnectorCorePluginRegistered(taskType string, deploymentID string) bool
GetCorePlugins() []core.PluginEntry
GetK8sPlugins() []k8s.PluginEntry
GetPluginChan() chan PluginInfo
}
64 changes: 36 additions & 28 deletions flyteplugins/go/tasks/plugins/webapi/connector/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package connector
import (
"context"
"crypto/x509"
"fmt"
"strings"

"golang.org/x/exp/maps"
Expand All @@ -15,11 +16,13 @@ import (

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
)

const defaultTaskTypeVersion = 0
const defaultDeploymentID = "default"

type Connector struct {
// IsSync indicates whether this connector is a sync connector. Sync connectors are expected to return their
Expand Down Expand Up @@ -91,16 +94,21 @@ func getFinalContext(ctx context.Context, operation string, connector *Deploymen
return context.WithTimeout(ctx, timeout)
}

func getConnectorRegistry(ctx context.Context, cs *ClientSet) Registry {
newConnectorRegistry := make(Registry)
func watchConnectors(ctx context.Context, cs *ClientSet) {
cfg := GetConfig()
var connectorDeploymentIDs []string
var connectorDeployments []*Deployment

// Merge DefaultConnector (if endpoint is not empty)
if len(cfg.DefaultConnector.Endpoint) != 0 {
connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector)
connectorDeploymentIDs = append(connectorDeploymentIDs, defaultDeploymentID)
connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector)
}
connectorDeploymentIDs = append(connectorDeploymentIDs, maps.Keys(cfg.ConnectorDeployments)...)
connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...)
for _, connectorDeployment := range connectorDeployments {

for idx, connectorDeployment := range connectorDeployments {
deploymentID := connectorDeploymentIDs[idx]
client, ok := cs.connectorMetadataClients[connectorDeployment.Endpoint]
if !ok {
logger.Warningf(ctx, "Connector client not found in the clientSet for the endpoint: %v", connectorDeployment.Endpoint)
Expand Down Expand Up @@ -128,45 +136,40 @@ func getConnectorRegistry(ctx context.Context, cs *ClientSet) Registry {
continue
}

// If a connector's support task type plugin was not registered yet, we should do registration
connectorSupportedTaskCategories := make(map[string]struct{})
for _, connector := range res.GetAgents() {
deprecatedSupportedTaskTypes := connector.GetSupportedTaskTypes()
supportedTaskCategories := connector.GetSupportedTaskCategories()
// Process deprecated supported task types
for _, supportedTaskType := range deprecatedSupportedTaskTypes {
connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: connector.GetIsSync()}
newConnectorRegistry[supportedTaskType] = map[int32]*Connector{defaultTaskTypeVersion: connector}
connectorSupportedTaskCategories[supportedTaskType] = struct{}{}
versionedTaskType := createOrUpdatePlugin(ctx, supportedTaskType, defaultTaskTypeVersion, deploymentID, connectorDeployment, cs)
connectorSupportedTaskCategories[versionedTaskType] = struct{}{}
}

supportedTaskCategories := connector.GetSupportedTaskCategories()
// Process supported task categories
for _, supportedCategory := range supportedTaskCategories {
connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: connector.GetIsSync()}
supportedCategoryName := supportedCategory.GetName()
newConnectorRegistry[supportedCategoryName] = map[int32]*Connector{supportedCategory.GetVersion(): connector}
connectorSupportedTaskCategories[supportedCategoryName] = struct{}{}
if supportedCategory.Version != defaultTaskTypeVersion {
versionedTaskType := createOrUpdatePlugin(ctx, supportedCategory.Name, supportedCategory.Version, deploymentID, connectorDeployment, cs)
connectorSupportedTaskCategories[versionedTaskType] = struct{}{}
}
}
}
logger.Infof(ctx, "ConnectorDeployment [%v] supports the following task types: [%v]", connectorDeployment.Endpoint,
strings.Join(maps.Keys(connectorSupportedTaskCategories), ", "))
}

// Always replace the connector registry with the settings defined in the configuration
// always overwrite with connectorForTaskTypes config
for taskType, connectorDeploymentID := range cfg.ConnectorForTaskTypes {
if connectorDeployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok {
connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: false}
newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector}
if deployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok {
createOrUpdatePlugin(ctx, taskType, defaultTaskTypeVersion, connectorDeploymentID, deployment, cs)
}
}

// Ensure that the old configuration is backward compatible
for _, taskType := range cfg.SupportedTaskTypes {
if _, ok := newConnectorRegistry[taskType]; !ok {
connector := &Connector{ConnectorDeployment: &cfg.DefaultConnector, IsSync: false}
newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector}
versionedTaskType := fmt.Sprintf("%s_%d", taskType, defaultTaskTypeVersion)
if ok := pluginmachinery.PluginRegistry().IsConnectorCorePluginRegistered(versionedTaskType, defaultDeploymentID); !ok {
createOrUpdatePlugin(ctx, taskType, defaultTaskTypeVersion, defaultDeploymentID, &cfg.DefaultConnector, cs)
}
}

logger.Infof(ctx, "ConnectorDeployments support the following task types: [%v]", strings.Join(maps.Keys(newConnectorRegistry), ", "))
return newConnectorRegistry
}

func getConnectorClientSets(ctx context.Context) *ClientSet {
Expand All @@ -176,13 +179,18 @@ func getConnectorClientSets(ctx context.Context) *ClientSet {
connectorMetadataClients: make(map[string]service.AgentMetadataServiceClient),
}

var connectorDeployments []*Deployment
connectorDeployments := make(map[string]*Deployment)
cfg := GetConfig()

// Merge ConnectorDeployments
for key, deployment := range cfg.ConnectorDeployments {
connectorDeployments[key] = deployment
}

// Merge DefaultConnector (if endpoint is not empty)
if len(cfg.DefaultConnector.Endpoint) != 0 {
connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector)
connectorDeployments["default"] = &cfg.DefaultConnector
}
connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...)
for _, connectorDeployment := range connectorDeployments {
if _, ok := clientSet.connectorMetadataClients[connectorDeployment.Endpoint]; ok {
logger.Infof(ctx, "Connector client already initialized for [%v]", connectorDeployment.Endpoint)
Expand Down
Loading
Loading