Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
107 changes: 107 additions & 0 deletions flyteidl2/clients/go/coreutils/extract_literal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// extract_literal.go
// Utility methods to extract a native golang value from a given Literal.
// Usage:
// 1] string literal extraction
// lit, _ := MakeLiteral("test_string")
// val, _ := ExtractFromLiteral(lit)
// 2] integer literal extraction. integer would be extracted in type int64.
// lit, _ := MakeLiteral([]interface{}{1, 2, 3})
// val, _ := ExtractFromLiteral(lit)
// 3] float literal extraction. float would be extracted in type float64.
// lit, _ := MakeLiteral([]interface{}{1.0, 2.0, 3.0})
// val, _ := ExtractFromLiteral(lit)
// 4] map of boolean literal extraction.
// mapInstance := map[string]interface{}{
// "key1": []interface{}{1, 2, 3},
// "key2": []interface{}{5},
// }
// lit, _ := MakeLiteral(mapInstance)
// val, _ := ExtractFromLiteral(lit)
// For further examples check the test TestFetchLiteral in extract_literal_test.go

package coreutils

import (
"fmt"

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
)

func ExtractFromLiteral(literal *core.Literal) (interface{}, error) {
switch literalValue := literal.GetValue().(type) {
case *core.Literal_Scalar:
switch scalarValue := literalValue.Scalar.GetValue().(type) {
case *core.Scalar_Primitive:
switch scalarPrimitive := scalarValue.Primitive.GetValue().(type) {
case *core.Primitive_Integer:
scalarPrimitiveInt := scalarPrimitive.Integer
return scalarPrimitiveInt, nil
case *core.Primitive_FloatValue:
scalarPrimitiveFloat := scalarPrimitive.FloatValue
return scalarPrimitiveFloat, nil
case *core.Primitive_StringValue:
scalarPrimitiveString := scalarPrimitive.StringValue
return scalarPrimitiveString, nil
case *core.Primitive_Boolean:
scalarPrimitiveBoolean := scalarPrimitive.Boolean
return scalarPrimitiveBoolean, nil
case *core.Primitive_Datetime:
scalarPrimitiveDateTime := scalarPrimitive.Datetime.AsTime()
return scalarPrimitiveDateTime, nil
case *core.Primitive_Duration:
scalarPrimitiveDuration := scalarPrimitive.Duration.AsDuration()
return scalarPrimitiveDuration, nil
default:
return nil, fmt.Errorf("unsupported literal scalar primitive type %T", scalarValue)
}
case *core.Scalar_Binary:
return scalarValue.Binary, nil
case *core.Scalar_Blob:
return scalarValue.Blob.GetUri(), nil
case *core.Scalar_Schema:
return scalarValue.Schema.GetUri(), nil
case *core.Scalar_Generic:
return scalarValue.Generic, nil
case *core.Scalar_StructuredDataset:
return scalarValue.StructuredDataset.GetUri(), nil
case *core.Scalar_Union:
// extract the value of the union but not the actual union object
extractedVal, err := ExtractFromLiteral(scalarValue.Union.GetValue())
if err != nil {
return nil, err
}
return extractedVal, nil
case *core.Scalar_NoneType:
return nil, nil
default:
return nil, fmt.Errorf("unsupported literal scalar type %T", scalarValue)
}
case *core.Literal_Collection:
collectionValue := literalValue.Collection.GetLiterals()
collection := make([]interface{}, len(collectionValue))
for index, val := range collectionValue {
if collectionElem, err := ExtractFromLiteral(val); err == nil {
collection[index] = collectionElem
} else {
return nil, err
}
}
return collection, nil
case *core.Literal_Map:
mapLiteralValue := literalValue.Map.GetLiterals()
mapResult := make(map[string]interface{}, len(mapLiteralValue))
for key, val := range mapLiteralValue {
if val, err := ExtractFromLiteral(val); err == nil {
mapResult[key] = val
} else {
return nil, err
}
}
return mapResult, nil
case *core.Literal_OffloadedMetadata:
// Return the URI of the offloaded metadata to be used when displaying in flytectl
return literalValue.OffloadedMetadata.GetUri(), nil

}
return nil, fmt.Errorf("unsupported literal type %T", literal)
}
267 changes: 267 additions & 0 deletions flyteidl2/clients/go/coreutils/extract_literal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
// extract_literal_test.go
// Test class for the utility methods which extract a native golang value from a flyte Literal.

package coreutils

import (
"os"
"testing"
"time"

structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
)

func TestFetchLiteral(t *testing.T) {
t.Run("Primitive", func(t *testing.T) {
lit, err := MakeLiteral("test_string")
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, "test_string", val)
})

t.Run("Timestamp", func(t *testing.T) {
now := time.Now().UTC()
lit, err := MakeLiteral(now)
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, now, val)
})

t.Run("Duration", func(t *testing.T) {
duration := time.Second * 10
lit, err := MakeLiteral(duration)
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, duration, val)
})

t.Run("Array", func(t *testing.T) {
lit, err := MakeLiteral([]interface{}{1, 2, 3})
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
arr := []interface{}{int64(1), int64(2), int64(3)}
assert.Equal(t, arr, val)
})

t.Run("Map", func(t *testing.T) {
mapInstance := map[string]interface{}{
"key1": []interface{}{1, 2, 3},
"key2": []interface{}{5},
}
lit, err := MakeLiteral(mapInstance)
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
expectedMapInstance := map[string]interface{}{
"key1": []interface{}{int64(1), int64(2), int64(3)},
"key2": []interface{}{int64(5)},
}
assert.Equal(t, expectedMapInstance, val)
})

t.Run("Map_Booleans", func(t *testing.T) {
mapInstance := map[string]interface{}{
"key1": []interface{}{true, false, true},
"key2": []interface{}{false},
}
lit, err := MakeLiteral(mapInstance)
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, mapInstance, val)
})

t.Run("Map_Floats", func(t *testing.T) {
mapInstance := map[string]interface{}{
"key1": []interface{}{1.0, 2.0, 3.0},
"key2": []interface{}{1.0},
}
lit, err := MakeLiteral(mapInstance)
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
expectedMapInstance := map[string]interface{}{
"key1": []interface{}{float64(1.0), float64(2.0), float64(3.0)},
"key2": []interface{}{float64(1.0)},
}
assert.Equal(t, expectedMapInstance, val)
})

t.Run("NestedMap", func(t *testing.T) {
mapInstance := map[string]interface{}{
"key1": map[string]interface{}{"key11": 1.0, "key12": 2.0, "key13": 3.0},
"key2": map[string]interface{}{"key21": 1.0},
}
lit, err := MakeLiteral(mapInstance)
assert.NoError(t, err)
val, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
expectedMapInstance := map[string]interface{}{
"key1": map[string]interface{}{"key11": float64(1.0), "key12": float64(2.0), "key13": float64(3.0)},
"key2": map[string]interface{}{"key21": float64(1.0)},
}
assert.Equal(t, expectedMapInstance, val)
})

t.Run("Binary", func(t *testing.T) {
s := MakeBinaryLiteral([]byte{'h'})
assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue())
_, err := ExtractFromLiteral(s)
assert.Nil(t, err)
})

t.Run("NoneType", func(t *testing.T) {
p, err := MakeLiteral(nil)
assert.NoError(t, err)
assert.NotNil(t, p.GetScalar())
_, err = ExtractFromLiteral(p)
assert.Nil(t, err)
})

t.Run("Generic", func(t *testing.T) {
os.Setenv(FlyteUseOldDcFormat, "true")
literalVal := map[string]interface{}{
"x": 1,
"y": "ystringvalue",
}
var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}}
lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
fieldsMap := map[string]*structpb.Value{
"x": {
Kind: &structpb.Value_NumberValue{NumberValue: 1},
},
"y": {
Kind: &structpb.Value_StringValue{StringValue: "ystringvalue"},
},
}
expectedStructVal := &structpb.Struct{
Fields: fieldsMap,
}
extractedStructValue := extractedLiteralVal.(*structpb.Struct)
assert.Equal(t, len(expectedStructVal.GetFields()), len(extractedStructValue.GetFields()))
for key, val := range expectedStructVal.GetFields() {
assert.Equal(t, val.GetKind(), extractedStructValue.GetFields()[key].GetKind())
}
os.Unsetenv(FlyteUseOldDcFormat)
})

t.Run("Generic Passed As String", func(t *testing.T) {
literalVal := "{\"x\": 1,\"y\": \"ystringvalue\"}"
var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}}
lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
fieldsMap := map[string]*structpb.Value{
"x": {
Kind: &structpb.Value_NumberValue{NumberValue: 1},
},
"y": {
Kind: &structpb.Value_StringValue{StringValue: "ystringvalue"},
},
}
expectedStructVal := &structpb.Struct{
Fields: fieldsMap,
}
extractedStructValue := extractedLiteralVal.(*structpb.Struct)
assert.Equal(t, len(expectedStructVal.GetFields()), len(extractedStructValue.GetFields()))
for key, val := range expectedStructVal.GetFields() {
assert.Equal(t, val.GetKind(), extractedStructValue.GetFields()[key].GetKind())
}
})

t.Run("Structured dataset", func(t *testing.T) {
literalVal := "s3://blah/blah/blah"
var dataSetColumns []*core.StructuredDatasetType_DatasetColumn
dataSetColumns = append(dataSetColumns, &core.StructuredDatasetType_DatasetColumn{
Name: "Price",
LiteralType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_FLOAT,
},
},
})
var literalType = &core.LiteralType{Type: &core.LiteralType_StructuredDatasetType{StructuredDatasetType: &core.StructuredDatasetType{
Columns: dataSetColumns,
Format: "testFormat",
}}}

lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Offloaded metadata", func(t *testing.T) {
literalVal := "s3://blah/blah/blah"
var storedLiteralType = &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_INTEGER,
},
},
},
}
offloadedLiteral := &core.Literal{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: literalVal,
InferredType: storedLiteralType,
},
},
}
extractedLiteralVal, err := ExtractFromLiteral(offloadedLiteral)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union", func(t *testing.T) {
literalVal := int64(1)
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union with None", func(t *testing.T) {
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, nil)

assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Nil(t, extractedLiteralVal)
})
}
Loading
Loading