diff --git a/catalog.go b/catalog.go index da3c6bb..a4d7bc6 100644 --- a/catalog.go +++ b/catalog.go @@ -2,20 +2,36 @@ package stac import ( "encoding/json" + "fmt" + "regexp" "github.com/mitchellh/mapstructure" ) type Catalog struct { - Version string `json:"stac_version"` - Id string `json:"id"` - Title string `json:"title,omitempty"` - Description string `json:"description"` - Links []*Link `json:"links"` - ConformsTo []string `json:"conformsTo,omitempty"` + Version string `json:"stac_version"` + Id string `json:"id"` + Title string `json:"title,omitempty"` + Description string `json:"description"` + Links []*Link `json:"links"` + ConformsTo []string `json:"conformsTo,omitempty"` + Extensions []Extension `json:"-"` } -var _ json.Marshaler = (*Catalog)(nil) +var ( + _ json.Marshaler = (*Catalog)(nil) + _ json.Unmarshaler = (*Catalog)(nil) +) + +var catalogExtensions = newExtensionRegistry() + +func RegisterCatalogExtension(pattern *regexp.Regexp, provider ExtensionProvider) { + catalogExtensions.register(pattern, provider) +} + +func GetCatalogExtension(uri string) Extension { + return catalogExtensions.get(uri) +} func (catalog Catalog) MarshalJSON() ([]byte, error) { collectionMap := map[string]any{ @@ -34,5 +50,63 @@ func (catalog Catalog) MarshalJSON() ([]byte, error) { return nil, decodeErr } + extensionUris := []string{} + lookup := map[string]bool{} + + for _, extension := range catalog.Extensions { + if err := extension.Encode(collectionMap); err != nil { + return nil, err + } + uris, err := GetExtensionUris(collectionMap) + if err != nil { + return nil, err + } + uris = append(uris, extension.URI()) + for _, uri := range uris { + if !lookup[uri] { + extensionUris = append(extensionUris, uri) + lookup[uri] = true + } + } + } + + SetExtensionUris(collectionMap, extensionUris) return json.Marshal(collectionMap) } + +func (catalog *Catalog) UnmarshalJSON(data []byte) error { + collectionMap := map[string]any{} + if err := json.Unmarshal(data, &collectionMap); err != nil { + return err + } + + decoder, decoderErr := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + TagName: "json", + Result: catalog, + }) + if decoderErr != nil { + return decoderErr + } + + if err := decoder.Decode(collectionMap); err != nil { + return err + } + + extensionUris, err := GetExtensionUris(collectionMap) + if err != nil { + return err + } + + for _, uri := range extensionUris { + extension := GetCatalogExtension(uri) + if extension == nil { + continue + } + if err := extension.Decode(collectionMap); err != nil { + return fmt.Errorf("decoding error for %s: %w", uri, err) + } + catalog.Extensions = append(catalog.Extensions, extension) + } + + return nil +} diff --git a/catalog_test.go b/catalog_test.go index ab88ab7..c225c77 100644 --- a/catalog_test.go +++ b/catalog_test.go @@ -2,8 +2,10 @@ package stac_test import ( "encoding/json" + "regexp" "testing" + "github.com/mitchellh/mapstructure" "github.com/planetlabs/go-stac" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -38,3 +40,135 @@ func TestCatalogMarshal(t *testing.T) { assert.JSONEq(t, expected, string(data)) } + +const ( + extensionAlias = "test-catalog-extension" + extensionUri = "https://example.com/test-catalog-extension/v1.0.0/schema.json" + extensionPattern = `https://example.com/test-catalog-extension/v1\..*/schema.json` +) + +type CatalogExtension struct { + RequiredNum float64 `json:"required_num"` + OptionalBool *bool `json:"optional_bool,omitempty"` +} + +var _ stac.Extension = (*CatalogExtension)(nil) + +func (*CatalogExtension) URI() string { + return extensionUri +} + +func (e *CatalogExtension) Encode(catalogMap map[string]any) error { + extendedProps := map[string]any{} + encoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + TagName: "json", + Result: &extendedProps, + }) + if err != nil { + return err + } + if err := encoder.Decode(e); err != nil { + return err + } + catalogMap[extensionAlias] = extendedProps + return nil +} + +func (e *CatalogExtension) Decode(catalogMap map[string]any) error { + extendedProps, present := catalogMap[extensionAlias] + if !present { + return nil + } + + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + TagName: "json", + Result: e, + }) + if err != nil { + return err + } + return decoder.Decode(extendedProps) +} + +func TestExtendedCatalogMarshal(t *testing.T) { + stac.RegisterCatalogExtension( + regexp.MustCompile(extensionPattern), + func() stac.Extension { + return &CatalogExtension{} + }, + ) + + catalog := &stac.Catalog{ + Description: "Test catalog with extension", + Id: "catalog-id", + Extensions: []stac.Extension{ + &CatalogExtension{ + RequiredNum: 42, + }, + }, + Links: []*stac.Link{}, + Version: "1.2.3", + } + + data, err := json.Marshal(catalog) + require.NoError(t, err) + + expected := `{ + "type": "Catalog", + "description": "Test catalog with extension", + "id": "catalog-id", + "test-catalog-extension": { + "required_num": 42 + }, + "links": [], + "stac_extensions": [ + "https://example.com/test-catalog-extension/v1.0.0/schema.json" + ], + "stac_version": "1.2.3" + }` + + assert.JSONEq(t, expected, string(data)) +} + +func TestExtendedCatalogUnmarshal(t *testing.T) { + stac.RegisterCatalogExtension( + regexp.MustCompile(extensionPattern), + func() stac.Extension { + return &CatalogExtension{} + }, + ) + + data := []byte(`{ + "type": "Catalog", + "description": "Test catalog with extension", + "id": "catalog-id", + "test-catalog-extension": { + "required_num": 100, + "optional_bool": true + }, + "links": [], + "stac_extensions": [ + "https://example.com/test-catalog-extension/v1.0.0/schema.json" + ], + "stac_version": "1.2.3" + }`) + + catalog := &stac.Catalog{} + require.NoError(t, json.Unmarshal(data, catalog)) + + b := true + expected := &stac.Catalog{ + Description: "Test catalog with extension", + Id: "catalog-id", + Extensions: []stac.Extension{ + &CatalogExtension{ + RequiredNum: 100, + OptionalBool: &b, + }, + }, + Links: []*stac.Link{}, + Version: "1.2.3", + } + + assert.Equal(t, expected, catalog) +}