From ec46a3a4375d6fe1c948c6f25146bb572717c651 Mon Sep 17 00:00:00 2001 From: Jake Van Vorhis <83739412+jakedoublev@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:26:47 -0800 Subject: [PATCH] feat(policy): limit/offset throughout LIST service RPCs/db (#1669) Closes #55 --- docs/configuration.md | 18 + opentdf-dev.yaml | 5 + opentdf-example.yaml | 5 + opentdf-with-hsm.yaml | 2 + service/cmd/policy.go | 8 +- service/integration/attribute_values_test.go | 191 +++++- service/integration/attributes_test.go | 203 ++++-- service/integration/config.go | 4 - service/integration/kas_registry_test.go | 242 ++++++-- service/integration/namespaces_test.go | 169 +++-- service/integration/resource_mappings_test.go | 197 ++++-- service/integration/subject_mappings_test.go | 167 ++++- service/internal/fixtures/db.go | 12 +- service/internal/fixtures/fixtures.go | 6 - service/pkg/db/errors.go | 6 + service/policy/attributes/attributes.go | 25 +- service/policy/config/config.go | 38 ++ service/policy/db/attribute_fqn.go | 8 +- service/policy/db/attribute_values.go | 60 +- service/policy/db/attributes.go | 63 +- .../policy/db/key_access_server_registry.go | 63 +- service/policy/db/namespaces.go | 69 +- service/policy/db/policy.go | 32 +- service/policy/db/query.sql | 235 ++++--- service/policy/db/query.sql.go | 587 +++++++++++++----- service/policy/db/resource_mapping.go | 62 +- service/policy/db/subject_mappings.go | 60 +- service/policy/db/utils.go | 21 + service/policy/db/utils_test.go | 92 +++ .../kasregistry/key_access_server_registry.go | 25 +- service/policy/namespaces/namespaces.go | 15 +- .../resourcemapping/resource_mapping.go | 22 +- .../policy/subjectmapping/subject_mapping.go | 22 +- service/policy/unsafe/unsafe.go | 10 +- 34 files changed, 2178 insertions(+), 566 deletions(-) create mode 100644 service/policy/config/config.go create mode 100644 service/policy/db/utils_test.go diff --git a/docs/configuration.md b/docs/configuration.md index b9ffd50e6..9997a71a5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -219,6 +219,24 @@ services: query: data.opentdf.entitlements.attributes ``` +### Policy + +Root level key `policy` + +| Field | Description | Default | Environment Variables | +| ---------------------------- | ------------------------------------------------------ | ------- | -------------------------------------------------- | +| `list_request_limit_default` | Policy List request limit default when not provided | 1000 | OPENTDF_SERVICES_POLICY_LIST_REQUEST_LIMIT_DEFAULT | +| `list_request_limit_max` | Policy List request limit maximum enforced by services | 2500 | OPENTDF_SERVICES_POLICY_LIST_REQUEST_LIMIT_MAX | + +Example: + +```yaml +services: + policy: + list_request_limit_default: 1000 + list_request_limit_max: 2500 +``` + ### Casbin Endpoint Authorization OpenTDF uses Casbin to manage authorization policies. This document provides an overview of how to configure and manage the default authorization policy in OpenTDF. diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index 7c1f5d2b9..5ea3bf2fe 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -32,6 +32,11 @@ services: from: email: true username: true + # policy is enabled by default in mode 'all' + # policy: + # enabled: true + # list_request_limit_default: 1000 + # list_request_limit_max: 2500 server: tls: enabled: false diff --git a/opentdf-example.yaml b/opentdf-example.yaml index 6fd6549ae..1033c066c 100644 --- a/opentdf-example.yaml +++ b/opentdf-example.yaml @@ -23,6 +23,11 @@ services: from: email: true username: true + # policy is enabled by default in mode 'all' + # policy: + # enabled: true + # list_request_limit_default: 1000 + # list_request_limit_max: 2500 server: auth: enabled: true diff --git a/opentdf-with-hsm.yaml b/opentdf-with-hsm.yaml index 476a45922..92871daae 100644 --- a/opentdf-with-hsm.yaml +++ b/opentdf-with-hsm.yaml @@ -15,6 +15,8 @@ services: rsacertid: r1 policy: enabled: true + # list_request_limit_default: 1000 + # list_request_limit_max: 2500 entityresolution: enabled: true url: http://localhost:8888/auth diff --git a/service/cmd/policy.go b/service/cmd/policy.go index cadcf4ecc..eea0bca5f 100644 --- a/service/cmd/policy.go +++ b/service/cmd/policy.go @@ -83,7 +83,13 @@ func policyDBClient(conf *config.Config) (policydb.PolicyDBClient, error) { return policydb.PolicyDBClient{}, err } - return policydb.NewClient(dbClient, logger), nil + // This command connects directly to the database so runtime policy config list limit settings can be ignored + var ( + limitDefault int32 = 1000 + limitMax int32 = 2500 + ) + + return policydb.NewClient(dbClient, logger, limitMax, limitDefault), nil } func init() { diff --git a/service/integration/attribute_values_test.go b/service/integration/attribute_values_test.go index 882daa7a3..60a1b5a1f 100644 --- a/service/integration/attribute_values_test.go +++ b/service/integration/attribute_values_test.go @@ -13,8 +13,8 @@ import ( "github.com/opentdf/platform/protocol/go/policy/unsafe" "github.com/opentdf/platform/service/internal/fixtures" "github.com/opentdf/platform/service/pkg/db" - policydb "github.com/opentdf/platform/service/policy/db" "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" ) var absentAttributeValueUUID = "78909865-8888-9999-9999-000000000000" @@ -44,30 +44,143 @@ func (s *AttributeValuesSuite) TearDownSuite() { s.f.TearDown() } -func (s *AttributeValuesSuite) Test_ListAttributeValues() { +func (s *AttributeValuesSuite) Test_ListAttributeValues_WithAttributeID_Succeeds() { attrID := s.f.GetAttributeValueKey("example.com/attr/attr1/value/value1").AttributeDefinitionID - list, err := s.db.PolicyClient.ListAttributeValues(s.ctx, attrID, policydb.StateActive) + listRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + AttributeId: attrID, + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, + }) s.Require().NoError(err) - s.NotNil(list) + s.NotNil(listRsp) + listed := listRsp.GetValues() // ensure list contains the two test fixtures and that response matches expected data f1 := s.f.GetAttributeValueKey("example.com/attr/attr1/value/value1") f2 := s.f.GetAttributeValueKey("example.com/attr/attr1/value/value2") - for _, item := range list { - if item.GetId() == f1.ID { - s.Equal(f1.ID, item.GetId()) - s.Equal(f1.Value, item.GetValue()) - // s.Equal(f1.AttributeDefinitionId, item.AttributeId) - } else if item.GetId() == f2.ID { - s.Equal(f2.ID, item.GetId()) - s.Equal(f2.Value, item.GetValue()) - // s.Equal(f2.AttributeDefinitionId, item.AttributeId) + for _, val := range listed { + if val.GetId() == f1.ID { + s.Equal(f1.ID, val.GetId()) + s.Equal(f1.Value, val.GetValue()) + s.Equal(f1.AttributeDefinitionID, val.GetAttribute().GetId()) + } else if val.GetId() == f2.ID { + s.Equal(f2.ID, val.GetId()) + s.Equal(f2.Value, val.GetValue()) + s.Equal(f2.AttributeDefinitionID, val.GetAttribute().GetId()) } } } +func (s *AttributeValuesSuite) Test_ListAttributeValues_NoPagination_Succeeds() { + allFixtureValueFqns := map[string]bool{ + "https://example.com/attr/attr1/value/value1": false, + "https://example.com/attr/attr1/value/value2": false, + "https://example.com/attr/attr2/value/value1": false, + "https://example.com/attr/attr2/value/value2": false, + "https://example.net/attr/attr1/value/value1": false, + "https://example.net/attr/attr1/value/value2": false, + "https://scenario.com/attr/working_group/value/blue": false, + "https://deactivated.io/attr/deactivated_attr/value/deactivated_value": false, + } + listRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + // mark every listed value true + for _, val := range listRsp.GetValues() { + allFixtureValueFqns[val.GetFqn()] = true + } + // ensure all fixtures were found by unbounded list + for fqn, found := range allFixtureValueFqns { + if !found { + s.Failf("failed to list fixture", fqn) + } + } +} + +func (s *AttributeValuesSuite) Test_ListAttributeValues_Limit_Succeeds() { + var limit int32 = 2 + listRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetValues() + s.Equal(len(listed), int(limit)) + + for _, val := range listed { + s.NotEmpty(val.GetFqn()) + s.NotEmpty(val.GetId()) + s.NotEmpty(val.GetValue()) + } + + // request with one below maximum + listRsp, err = s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax - 1, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + + // request with exactly maximum + listRsp, err = s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) +} + +func (s *NamespacesSuite) Test_ListAttributeValues_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *AttributeValuesSuite) Test_ListAttributeValues_Offset_Succeeds() { + req := &attributes.ListAttributeValuesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + } + // make initial list request to compare against + listRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, req) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetValues() + + // set the offset pagination + offset := 5 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + offsetListRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, req) + s.Require().NoError(err) + s.NotNil(offsetListRsp) + offsetListed := offsetListRsp.GetValues() + + // length is reduced by the offset amount + s.Equal(len(offsetListed), len(listed)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, val := range offsetListed { + s.True(proto.Equal(val, listed[i+offset])) + } +} + func (s *AttributeValuesSuite) Test_GetAttributeValue() { f := s.f.GetAttributeValueKey("example.com/attr/attr1/value/value1") v, err := s.db.PolicyClient.GetAttributeValue(s.ctx, f.ID) @@ -133,7 +246,7 @@ func (s *AttributeValuesSuite) Test_CreateAttributeValue_SetsActiveStateTrueByDe attrDef := s.f.GetAttributeKey("example.net/attr/attr1") req := &attributes.CreateAttributeValueRequest{ - Value: "testing create gives active true by default", + Value: "testing-create-gives-active-true-by-default", } createdValue, err := s.db.PolicyClient.CreateAttributeValue(s.ctx, attrDef.ID, req) s.Require().NoError(err) @@ -495,15 +608,18 @@ func setupDeactivateAttributeValue(s *AttributeValuesSuite) (string, string, str func (s *AttributeValuesSuite) Test_DeactivateAttribute_Cascades_List() { type test struct { name string - testFunc func(state string) bool - state string + testFunc func(state common.ActiveStateEnum) bool + state common.ActiveStateEnum isFound bool } - listNamespaces := func(state string) bool { - listedNamespaces, err := s.db.PolicyClient.ListNamespaces(s.ctx, state) + listNamespaces := func(state common.ActiveStateEnum) bool { + listedNamespacesRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: state, + }) s.Require().NoError(err) - s.NotNil(listedNamespaces) + s.NotNil(listedNamespacesRsp) + listedNamespaces := listedNamespacesRsp.GetNamespaces() for _, ns := range listedNamespaces { if stillActiveNsID == ns.GetId() { return true @@ -512,10 +628,13 @@ func (s *AttributeValuesSuite) Test_DeactivateAttribute_Cascades_List() { return false } - listAttributes := func(state string) bool { - listedAttrs, err := s.db.PolicyClient.ListAttributes(s.ctx, state, "") + listAttributes := func(state common.ActiveStateEnum) bool { + listedAttrsRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: state, + }) s.Require().NoError(err) - s.NotNil(listedAttrs) + s.NotNil(listedAttrsRsp) + listedAttrs := listedAttrsRsp.GetAttributes() for _, a := range listedAttrs { if stillActiveAttributeID == a.GetId() { return true @@ -524,10 +643,14 @@ func (s *AttributeValuesSuite) Test_DeactivateAttribute_Cascades_List() { return false } - listValues := func(state string) bool { - listedVals, err := s.db.PolicyClient.ListAttributeValues(s.ctx, stillActiveAttributeID, state) + listValues := func(state common.ActiveStateEnum) bool { + listedValsRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + State: state, + AttributeId: stillActiveAttributeID, + }) s.Require().NoError(err) - s.NotNil(listedVals) + s.NotNil(listedValsRsp) + listedVals := listedValsRsp.GetValues() for _, v := range listedVals { if deactivatedAttrValueID == v.GetId() { return true @@ -540,55 +663,55 @@ func (s *AttributeValuesSuite) Test_DeactivateAttribute_Cascades_List() { { name: "namespace is NOT found in LIST of INACTIVE", testFunc: listNamespaces, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: false, }, { name: "namespace is found when filtering for ACTIVE state", testFunc: listNamespaces, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: true, }, { name: "namespace is found when filtering for ANY state", testFunc: listNamespaces, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, { name: "attribute is NOT found when filtering for INACTIVE state", testFunc: listAttributes, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: false, }, { name: "attribute is found when filtering for ANY state", testFunc: listAttributes, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, { name: "attribute is found when filtering for ACTIVE state", testFunc: listAttributes, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: true, }, { name: "value is NOT found in LIST of ACTIVE", testFunc: listValues, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: false, }, { name: "value is found when filtering for INACTIVE state", testFunc: listValues, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: true, }, { name: "value is found when filtering for ANY state", testFunc: listValues, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, } diff --git a/service/integration/attributes_test.go b/service/integration/attributes_test.go index af98d819f..e99d0ddcf 100644 --- a/service/integration/attributes_test.go +++ b/service/integration/attributes_test.go @@ -18,6 +18,7 @@ import ( "github.com/opentdf/platform/service/pkg/db" policydb "github.com/opentdf/platform/service/policy/db" "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" ) type AttributesSuite struct { @@ -369,17 +370,20 @@ func (s *AttributesSuite) Test_GetAttribute_ContainsKASGrants() { s.Equal(kas.GetName(), gotGrants[0].GetName()) } -func (s *AttributesSuite) Test_ListAttributes() { +func (s *AttributesSuite) Test_ListAttributes_NoPagination_Succeeds() { fixtures := s.getAttributeFixtures() - list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateActive, "") + r := &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + } + listRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, r) s.Require().NoError(err) - s.NotNil(list) + s.NotNil(listRsp) // all fixtures are listed for _, f := range fixtures { var found bool - for _, l := range list { + for _, l := range listRsp.GetAttributes() { if f.ID == l.GetId() { found = true break @@ -389,6 +393,87 @@ func (s *AttributesSuite) Test_ListAttributes() { } } +func (s *AttributesSuite) Test_ListAttributes_Limit_Succeeds() { + var limit int32 = 2 + listRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetAttributes() + s.Equal(len(listed), int(limit)) + + for _, definition := range listed { + s.NotEmpty(definition.GetFqn()) + s.NotEmpty(definition.GetId()) + s.NotEmpty(definition.GetName()) + } + + // request with one below maximum + listRsp, err = s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax - 1, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + + // exactly maximum + listRsp, err = s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) +} + +func (s *NamespacesSuite) Test_ListAttributes_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *AttributesSuite) Test_ListAttributes_Offset_Succeeds() { + req := &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + } + // make initial list request to compare against + listRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, req) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetAttributes() + + // set the offset pagination + offset := 2 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + offsetListRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, req) + s.Require().NoError(err) + s.NotNil(offsetListRsp) + offsetListed := offsetListRsp.GetAttributes() + + // length is reduced by the offset amount + s.Equal(len(offsetListed), len(listed)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, attr := range offsetListed { + s.True(proto.Equal(attr, listed[i+offset])) + } +} + func (s *AttributesSuite) Test_ListAttributes_FqnsIncluded() { // create an attribute attr := &attributes.CreateAttributeRequest{ @@ -401,11 +486,15 @@ func (s *AttributesSuite) Test_ListAttributes_FqnsIncluded() { s.Require().NoError(err) s.NotNil(createdAttr) - list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateActive, fixtureNamespaceID) + r := &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, + Namespace: fixtureNamespaceID, + } + list, err := s.db.PolicyClient.ListAttributes(s.ctx, r) s.Require().NoError(err) s.NotNil(list) - for _, a := range list { + for _, a := range list.GetAttributes() { // attr fqn s.NotEqual("", a.GetFqn()) s.Equal(fmt.Sprintf("https://%s/attr/%s", a.GetNamespace().GetName(), a.GetName()), a.GetFqn()) @@ -428,37 +517,45 @@ func (s *AttributesSuite) Test_ListAttributes_ByNamespaceIdOrName() { for _, f := range s.getAttributeFixtures() { namespaces[f.NamespaceID] = "" } + r := &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + } // list attributes by namespace id for nsID := range namespaces { - list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, nsID) + r.Namespace = nsID + rsp, err := s.db.PolicyClient.ListAttributes(s.ctx, r) s.Require().NoError(err) - s.NotNil(list) - s.NotEmpty(list) - for _, l := range list { + s.NotNil(rsp) + listed := rsp.GetAttributes() + s.NotEmpty(listed) + for _, l := range listed { s.Equal(nsID, l.GetNamespace().GetId()) } - namespaces[nsID] = list[0].GetNamespace().GetName() + namespaces[nsID] = listed[0].GetNamespace().GetName() } // list attributes by namespace name for _, nsName := range namespaces { - list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, nsName) + r.Namespace = nsName + rsp, err := s.db.PolicyClient.ListAttributes(s.ctx, r) s.Require().NoError(err) - s.NotNil(list) - s.NotEmpty(list) - for _, l := range list { + s.NotNil(rsp) + listed := rsp.GetAttributes() + s.NotEmpty(listed) + for _, l := range listed { s.Equal(nsName, l.GetNamespace().GetName()) } } // list attributes by namespace name with case insensitivity for _, nsName := range namespaces { - upperNsName := strings.ToUpper(nsName) - list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, upperNsName) + r.Namespace = strings.ToUpper(nsName) + rsp, err := s.db.PolicyClient.ListAttributes(s.ctx, r) s.Require().NoError(err) - s.NotNil(list) - s.NotEmpty(list) - for _, l := range list { + s.NotNil(rsp) + listed := rsp.GetAttributes() + s.NotEmpty(listed) + for _, l := range listed { s.Equal(nsName, l.GetNamespace().GetName()) } } @@ -812,10 +909,14 @@ func (s *AttributesSuite) Test_UnsafeDeleteAttribute() { s.NotEqual("", ns.GetId()) // attribute should not be listed anymore - list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, fixtureNamespaceID) + rsp, err := s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Namespace: fixtureNamespaceID, + }) s.Require().NoError(err) - s.NotNil(list) - for _, l := range list { + s.NotNil(rsp) + listed := rsp.GetAttributes() + for _, l := range listed { s.NotEqual(createdAttr.GetId(), l.GetId()) } @@ -908,16 +1009,19 @@ func setupCascadeDeactivateAttribute(s *AttributesSuite) (string, string, string func (s *AttributesSuite) Test_DeactivateAttribute_Cascades_List() { type test struct { name string - testFunc func(state string) bool - state string + testFunc func(state common.ActiveStateEnum) bool + state common.ActiveStateEnum isFound bool } - listNamespaces := func(state string) bool { - listedNamespaces, err := s.db.PolicyClient.ListNamespaces(s.ctx, state) + listNamespaces := func(state common.ActiveStateEnum) bool { + nsListRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: state, + }) s.Require().NoError(err) - s.NotNil(listedNamespaces) - for _, ns := range listedNamespaces { + s.NotNil(nsListRsp) + listed := nsListRsp.GetNamespaces() + for _, ns := range listed { if stillActiveNsID == ns.GetId() { return true } @@ -925,11 +1029,14 @@ func (s *AttributesSuite) Test_DeactivateAttribute_Cascades_List() { return false } - listAttributes := func(state string) bool { - listedAttrs, err := s.db.PolicyClient.ListAttributes(s.ctx, state, "") + listAttributes := func(state common.ActiveStateEnum) bool { + listAttrsRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: state, + }) s.Require().NoError(err) - s.NotNil(listedAttrs) - for _, a := range listedAttrs { + s.NotNil(listAttrsRsp) + listed := listAttrsRsp.GetAttributes() + for _, a := range listed { if deactivatedAttrID == a.GetId() { return true } @@ -937,11 +1044,15 @@ func (s *AttributesSuite) Test_DeactivateAttribute_Cascades_List() { return false } - listValues := func(state string) bool { - listedVals, err := s.db.PolicyClient.ListAttributeValues(s.ctx, deactivatedAttrID, state) + listValues := func(state common.ActiveStateEnum) bool { + valsListRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + AttributeId: deactivatedAttrID, + State: state, + }) s.Require().NoError(err) - s.NotNil(listedVals) - for _, v := range listedVals { + s.NotNil(valsListRsp) + listed := valsListRsp.GetValues() + for _, v := range listed { if deactivatedAttrValueID == v.GetId() { return true } @@ -953,55 +1064,55 @@ func (s *AttributesSuite) Test_DeactivateAttribute_Cascades_List() { { name: "namespace is NOT found in LIST of INACTIVE", testFunc: listNamespaces, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: false, }, { name: "namespace is found when filtering for ACTIVE state", testFunc: listNamespaces, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: true, }, { name: "namespace is found when filtering for ANY state", testFunc: listNamespaces, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, { name: "attribute is found when filtering for INACTIVE state", testFunc: listAttributes, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: true, }, { name: "attribute is found when filtering for ANY state", testFunc: listAttributes, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, { name: "attribute is NOT found when filtering for ACTIVE state", testFunc: listAttributes, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: false, }, { name: "value is NOT found in LIST of ACTIVE", testFunc: listValues, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: false, }, { name: "value is found when filtering for INACTIVE state", testFunc: listValues, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: true, }, { name: "value is found when filtering for ANY state", testFunc: listValues, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, } diff --git a/service/integration/config.go b/service/integration/config.go index 1956f8b2f..64ad2b2b7 100644 --- a/service/integration/config.go +++ b/service/integration/config.go @@ -2,10 +2,6 @@ package integration import "github.com/opentdf/platform/service/internal/config" -const ( - nonExistentAttributeValueUUID = "78909865-8888-9999-9999-000000000000" -) - var Config *config.Config func init() { diff --git a/service/integration/kas_registry_test.go b/service/integration/kas_registry_test.go index 87e472363..993660f25 100644 --- a/service/integration/kas_registry_test.go +++ b/service/integration/kas_registry_test.go @@ -14,6 +14,7 @@ import ( "github.com/opentdf/platform/protocol/go/policy/namespaces" "github.com/opentdf/platform/service/internal/fixtures" "github.com/opentdf/platform/service/pkg/db" + "google.golang.org/protobuf/proto" "github.com/stretchr/testify/suite" ) @@ -49,23 +50,91 @@ func (s *KasRegistrySuite) getKasRegistryFixtures() []fixtures.FixtureDataKasReg } } -func (s *KasRegistrySuite) Test_ListKeyAccessServers() { +func (s *KasRegistrySuite) Test_ListKeyAccessServers_NoPagination_Succeeds() { fixtures := s.getKasRegistryFixtures() - list, err := s.db.PolicyClient.ListKeyAccessServers(s.ctx) - s.Require().NoError(err) - s.NotNil(list) - for _, fixture := range fixtures { - for _, kas := range list { - if kas.GetId() == fixture.ID { - if kas.GetPublicKey().GetRemote() != "" { - s.Equal(fixture.PubKey.Remote, kas.GetPublicKey().GetRemote()) - } else { - s.Equal(fixture.PubKey.Cached, kas.GetPublicKey().GetCached()) - } - s.Equal(fixture.URI, kas.GetUri()) - s.Equal(fixture.Name, kas.GetName()) + listRsp, err := s.db.PolicyClient.ListKeyAccessServers(s.ctx, &kasregistry.ListKeyAccessServersRequest{}) + s.Require().NoError(err) + s.NotNil(listRsp) + + listed := listRsp.GetKeyAccessServers() + s.NotEmpty(listed) + + // ensure we find each fixture in the list response + for _, f := range fixtures { + found := false + for _, kasr := range listed { + if kasr.GetId() == f.ID { + found = true } } + s.True(found) + } +} + +func (s *KasRegistrySuite) Test_ListKeyAccessServers_Limit_Succeeds() { + var limit int32 = 2 + listRsp, err := s.db.PolicyClient.ListKeyAccessServers(s.ctx, &kasregistry.ListKeyAccessServersRequest{ + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + + listed := listRsp.GetKeyAccessServers() + s.Equal(len(listed), int(limit)) + + for _, kas := range listed { + s.NotEmpty(kas.GetId()) + s.NotEmpty(kas.GetUri()) + s.NotNil(kas.GetPublicKey()) + } + + // request with one below maximum + listRsp, err = s.db.PolicyClient.ListKeyAccessServers(s.ctx, &kasregistry.ListKeyAccessServersRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax - 1, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) +} + +func (s *NamespacesSuite) Test_ListKeyAccessServers_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListKeyAccessServers(s.ctx, &kasregistry.ListKeyAccessServersRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *KasRegistrySuite) Test_ListKeyAccessServers_Offset_Succeeds() { + req := &kasregistry.ListKeyAccessServersRequest{} + // make initial list request to compare against + listRsp, err := s.db.PolicyClient.ListKeyAccessServers(s.ctx, req) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetKeyAccessServers() + + // set the offset pagination + offset := 1 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + offsetListRsp, err := s.db.PolicyClient.ListKeyAccessServers(s.ctx, req) + s.Require().NoError(err) + s.NotNil(offsetListRsp) + offsetListed := offsetListRsp.GetKeyAccessServers() + + // length is reduced by the offset amount + s.Equal(len(offsetListed), len(listed)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, val := range offsetListed { + s.True(proto.Equal(val, listed[i+offset])) } } @@ -591,9 +660,13 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasId() { s.NotNil(valGrant) // list grants by KAS ID - listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, firstKAS.GetId(), "", "") + listRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + KasId: firstKAS.GetId(), + }) s.Require().NoError(err) - s.NotNil(listedGrants) + s.NotNil(listRsp) + + listedGrants := listRsp.GetGrants() s.Len(listedGrants, 1) g := listedGrants[0] s.Equal(firstKAS.GetId(), g.GetKeyAccessServer().GetId()) @@ -604,9 +677,13 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasId() { s.Empty(g.GetNamespaceGrants()) // list grants by the other KAS ID - listedGrants, err = s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, otherKAS.GetId(), "", "") + listRsp, err = s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + KasId: otherKAS.GetId(), + }) s.Require().NoError(err) - s.NotNil(listedGrants) + s.NotNil(listRsp) + + listedGrants = listRsp.GetGrants() s.Len(listedGrants, 1) g = listedGrants[0] s.Equal(otherKAS.GetId(), g.GetKeyAccessServer().GetId()) @@ -619,9 +696,11 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasId() { func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasId_NoResultsIfNotFound() { // list grants by KAS ID - listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, nonExistentKasRegistryID, "", "") + listRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + KasId: nonExistentKasRegistryID, + }) s.Require().NoError(err) - s.Empty(listedGrants) + s.Empty(listRsp.GetGrants()) } func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasUri() { @@ -647,10 +726,12 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasUri() { s.NotNil(createdGrant) // list grants by KAS URI - listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", fixtureKAS.URI, "") - + listGrantsRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + KasUri: fixtureKAS.URI, + }) s.Require().NoError(err) - s.NotNil(listedGrants) + s.NotNil(listGrantsRsp) + listedGrants := listGrantsRsp.GetGrants() s.GreaterOrEqual(len(listedGrants), 1) for _, g := range listedGrants { s.Equal(fixtureKAS.ID, g.GetKeyAccessServer().GetId()) @@ -661,9 +742,11 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasUri() { func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasUri_NoResultsIfNotFound() { // list grants by KAS ID - listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", "https://notfound.com/kas/uri", "") + listGrantsRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + KasUri: "https://notfound.com/kas/uri", + }) s.Require().NoError(err) - s.Empty(listedGrants) + s.Empty(listGrantsRsp.GetGrants()) } func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasName() { @@ -689,23 +772,38 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasName() { s.NotNil(createdGrant) // list grants by KAS URI - listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", "", fixtureKAS.Name) + listedGrantsRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, + &kasregistry.ListKeyAccessServerGrantsRequest{ + KasName: fixtureKAS.Name, + }) s.Require().NoError(err) - s.NotNil(listedGrants) + s.NotNil(listedGrantsRsp) + listedGrants := listedGrantsRsp.GetGrants() s.GreaterOrEqual(len(listedGrants), 1) + found := false for _, g := range listedGrants { - s.Equal(fixtureKAS.ID, g.GetKeyAccessServer().GetId()) - s.Equal(fixtureKAS.URI, g.GetKeyAccessServer().GetUri()) - s.Equal(fixtureKAS.Name, g.GetKeyAccessServer().GetName()) + if g.GetKeyAccessServer().GetId() == fixtureKAS.ID { + s.Equal(fixtureKAS.URI, g.GetKeyAccessServer().GetUri()) + s.Equal(fixtureKAS.Name, g.GetKeyAccessServer().GetName()) + for _, attrGrant := range g.GetAttributeGrants() { + if attrGrant.GetId() == createdAttr.GetId() { + found = true + s.Equal(attrGrant.GetFqn(), createdAttr.GetFqn()) + } + } + } } + s.True(found) } func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_KasName_NoResultsIfNotFound() { // list grants by KAS ID - listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", "", "unknown_kas") + listGrantsRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + KasName: "unknown-name", + }) s.Require().NoError(err) - s.Empty(listedGrants) + s.Empty(listGrantsRsp.GetGrants()) } func (s *KasRegistrySuite) Test_ListAllKeyAccessServerGrants() { @@ -812,11 +910,14 @@ func (s *KasRegistrySuite) Test_ListAllKeyAccessServerGrants() { s.NotNil(nsAnotherGrant) // list all grants - listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", "", "") + listGrantsRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{}) s.Require().NoError(err) - s.NotNil(listedGrants) - s.GreaterOrEqual(len(listedGrants), 2) + s.NotNil(listGrantsRsp) + listedGrants := listGrantsRsp.GetGrants() + s.GreaterOrEqual(len(listedGrants), 1) + s.GreaterOrEqual(len(listedGrants), 2) + foundCount := 0 for _, g := range listedGrants { switch g.GetKeyAccessServer().GetId() { case firstKAS.GetId(): @@ -828,6 +929,9 @@ func (s *KasRegistrySuite) Test_ListAllKeyAccessServerGrants() { s.Len(g.GetNamespaceGrants(), 1) s.Equal(createdNs.GetId(), g.GetNamespaceGrants()[0].GetId()) s.Equal(nsFQN, g.GetNamespaceGrants()[0].GetFqn()) + + foundCount++ + case secondKAS.GetId(): // should have expected value grant s.Len(g.GetValueGrants(), 1) @@ -837,8 +941,76 @@ func (s *KasRegistrySuite) Test_ListAllKeyAccessServerGrants() { s.Len(g.GetNamespaceGrants(), 1) s.Equal(createdNs.GetId(), g.GetNamespaceGrants()[0].GetId()) s.Equal(nsFQN, g.GetNamespaceGrants()[0].GetFqn()) + + foundCount++ } } + s.Equal(2, foundCount) +} + +func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_Limit_Succeeds() { + var limit int32 = 2 + listRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + + listed := listRsp.GetGrants() + s.Equal(len(listed), int(limit)) + + for _, grant := range listed { + s.NotNil(grant.GetKeyAccessServer()) + } + + // request with one below maximum + listRsp, err = s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax - 1, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) +} + +func (s *NamespacesSuite) Test_ListKeyAccessServerGrants_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, &kasregistry.ListKeyAccessServerGrantsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *KasRegistrySuite) Test_ListKeyAccessServerGrants_Offset_Succeeds() { + req := &kasregistry.ListKeyAccessServerGrantsRequest{} + // make initial list request to compare against + listRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, req) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetGrants() + + // set the offset pagination + offset := 1 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + offsetListRsp, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, req) + s.Require().NoError(err) + s.NotNil(offsetListRsp) + offsetListed := offsetListRsp.GetGrants() + + // length is reduced by the offset amount + s.Equal(len(offsetListed), len(listed)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, val := range offsetListed { + s.True(proto.Equal(val, listed[i+offset])) + } } func TestKasRegistrySuite(t *testing.T) { diff --git a/service/integration/namespaces_test.go b/service/integration/namespaces_test.go index bd36a4ec9..249b33f87 100644 --- a/service/integration/namespaces_test.go +++ b/service/integration/namespaces_test.go @@ -13,8 +13,8 @@ import ( "github.com/opentdf/platform/protocol/go/policy/namespaces" "github.com/opentdf/platform/service/internal/fixtures" "github.com/opentdf/platform/service/pkg/db" - policydb "github.com/opentdf/platform/service/policy/db" "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" ) type NamespacesSuite struct { @@ -129,13 +129,87 @@ func (s *NamespacesSuite) Test_GetNamespace_DoesNotExist_ShouldFail() { s.Nil(ns) } -func (s *NamespacesSuite) Test_ListNamespaces() { +func (s *NamespacesSuite) Test_ListNamespaces_NoPagination_Succeeds() { testData := s.getActiveNamespaceFixtures() - gotNamespaces, err := s.db.PolicyClient.ListNamespaces(s.ctx, policydb.StateActive) + listNamespacesRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, + }) + s.Require().NoError(err) + s.NotNil(listNamespacesRsp) + listed := listNamespacesRsp.GetNamespaces() + s.GreaterOrEqual(len(listed), len(testData)) + + for _, f := range testData { + found := false + for _, ns := range listed { + if ns.GetId() == f.ID { + found = true + } + } + s.True(found) + } +} + +func (s *NamespacesSuite) Test_ListNamespaces_Limit_Succeeds() { + var limit int32 = 2 + listRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) s.Require().NoError(err) - s.NotNil(gotNamespaces) - s.GreaterOrEqual(len(gotNamespaces), len(testData)) + s.NotNil(listRsp) + listed := listRsp.GetNamespaces() + s.Equal(len(listed), int(limit)) + + for _, ns := range listed { + s.NotEmpty(ns.GetFqn()) + s.NotEmpty(ns.GetId()) + s.NotEmpty(ns.GetName()) + } +} + +func (s *NamespacesSuite) Test_ListNamespaces_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *NamespacesSuite) Test_ListNamespaces_Offset_Succeeds() { + req := &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + } + // make initial list request to compare against + listRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, req) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetNamespaces() + + // set the offset pagination + offset := 4 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + offsetListRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, req) + s.Require().NoError(err) + s.NotNil(offsetListRsp) + offsetListed := offsetListRsp.GetNamespaces() + + // length is reduced by the offset amount + s.Equal(len(offsetListed), len(listed)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, ns := range offsetListed { + s.True(proto.Equal(ns, listed[i+offset])) + } } func (s *NamespacesSuite) Test_UpdateNamespace() { @@ -218,10 +292,13 @@ func (s *NamespacesSuite) Test_DeactivateNamespace() { s.False(inactive.GetActive().GetValue()) // Deactivated namespace should not be found on List - gotNamespaces, err := s.db.PolicyClient.ListNamespaces(s.ctx, policydb.StateActive) + listNamespacesRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, + }) s.Require().NoError(err) - s.NotNil(gotNamespaces) - for _, ns := range gotNamespaces { + s.NotNil(listNamespacesRsp) + listed := listNamespacesRsp.GetNamespaces() + for _, ns := range listed { s.NotEqual(n.GetId(), ns.GetId()) } @@ -269,16 +346,20 @@ func setupCascadeDeactivateNamespace(s *NamespacesSuite) (string, string, string func (s *NamespacesSuite) Test_DeactivateNamespace_Cascades_List() { type test struct { name string - testFunc func(state string) bool - state string + testFunc func(state common.ActiveStateEnum) bool + state common.ActiveStateEnum isFound bool } - listNamespaces := func(state string) bool { - listedNamespaces, err := s.db.PolicyClient.ListNamespaces(s.ctx, state) + listNamespaces := func(state common.ActiveStateEnum) bool { + listNamespacesRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: state, + }) s.Require().NoError(err) - s.NotNil(listedNamespaces) - for _, ns := range listedNamespaces { + s.NotNil(listNamespacesRsp) + + listed := listNamespacesRsp.GetNamespaces() + for _, ns := range listed { if deactivatedNsID == ns.GetId() { return true } @@ -286,11 +367,15 @@ func (s *NamespacesSuite) Test_DeactivateNamespace_Cascades_List() { return false } - listAttributes := func(state string) bool { - listedAttrs, err := s.db.PolicyClient.ListAttributes(s.ctx, state, "") + listAttributes := func(state common.ActiveStateEnum) bool { + listAttrsRsp, err := s.db.PolicyClient.ListAttributes(s.ctx, &attributes.ListAttributesRequest{ + State: state, + }) s.Require().NoError(err) - s.NotNil(listedAttrs) - for _, a := range listedAttrs { + s.NotNil(listAttrsRsp) + + listed := listAttrsRsp.GetAttributes() + for _, a := range listed { if deactivatedAttrID == a.GetId() { return true } @@ -298,11 +383,15 @@ func (s *NamespacesSuite) Test_DeactivateNamespace_Cascades_List() { return false } - listValues := func(state string) bool { - listedVals, err := s.db.PolicyClient.ListAttributeValues(s.ctx, deactivatedAttrID, state) + listValues := func(state common.ActiveStateEnum) bool { + listedValsRsp, err := s.db.PolicyClient.ListAttributeValues(s.ctx, &attributes.ListAttributeValuesRequest{ + AttributeId: deactivatedAttrID, + State: state, + }) s.Require().NoError(err) - s.NotNil(listedVals) - for _, v := range listedVals { + s.NotNil(listedValsRsp) + listed := listedValsRsp.GetValues() + for _, v := range listed { if deactivatedAttrValueID == v.GetId() { return true } @@ -314,55 +403,55 @@ func (s *NamespacesSuite) Test_DeactivateNamespace_Cascades_List() { { name: "namespace is NOT found in LIST of ACTIVE", testFunc: listNamespaces, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: false, }, { name: "namespace is found when filtering for INACTIVE state", testFunc: listNamespaces, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: true, }, { name: "namespace is found when filtering for ANY state", testFunc: listNamespaces, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, { name: "attribute is found when filtering for INACTIVE state", testFunc: listAttributes, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: true, }, { name: "attribute is found when filtering for ANY state", testFunc: listAttributes, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, { name: "attribute is NOT found when filtering for ACTIVE state", testFunc: listAttributes, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: false, }, { name: "value is NOT found in LIST of ACTIVE", testFunc: listValues, - state: policydb.StateActive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, isFound: false, }, { name: "value is found when filtering for INACTIVE state", testFunc: listValues, - state: policydb.StateInactive, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, isFound: true, }, { name: "value is found when filtering for ANY state", testFunc: listValues, - state: policydb.StateAny, + state: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, isFound: true, }, } @@ -554,11 +643,14 @@ func (s *NamespacesSuite) Test_UnsafeReactivateNamespace_SetsActiveStatusOfNames s.True(active.GetActive().GetValue()) // test that the namespace is found in the list of active namespaces - gotNamespaces, err := s.db.PolicyClient.ListNamespaces(s.ctx, policydb.StateActive) + listNamespacesRsp, err := s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, + }) s.Require().NoError(err) - s.NotNil(gotNamespaces) + s.NotNil(listNamespacesRsp) + listed := listNamespacesRsp.GetNamespaces() found := false - for _, ns := range gotNamespaces { + for _, ns := range listed { if n.GetId() == ns.GetId() { found = true break @@ -567,11 +659,14 @@ func (s *NamespacesSuite) Test_UnsafeReactivateNamespace_SetsActiveStatusOfNames s.True(found) // test that the namespace is not found in the list of inactive namespaces - gotNamespaces, err = s.db.PolicyClient.ListNamespaces(s.ctx, policydb.StateInactive) + listNamespacesRsp, err = s.db.PolicyClient.ListNamespaces(s.ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE, + }) s.Require().NoError(err) - s.NotNil(gotNamespaces) + s.NotNil(listNamespacesRsp) + listed = listNamespacesRsp.GetNamespaces() found = false - for _, ns := range gotNamespaces { + for _, ns := range listed { if n.GetId() == ns.GetId() { found = true break diff --git a/service/integration/resource_mappings_test.go b/service/integration/resource_mappings_test.go index e8d7d704f..93e5bfb10 100644 --- a/service/integration/resource_mappings_test.go +++ b/service/integration/resource_mappings_test.go @@ -7,15 +7,19 @@ import ( "testing" "github.com/opentdf/platform/protocol/go/common" + "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/resourcemapping" "github.com/opentdf/platform/service/internal/fixtures" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" ) -var unknownNamespaceID = "64257d69-c007-4893-931a-434f1819a4f7" -var unknownResourceMappingGroupID = "c70cad07-21b4-4cb1-9095-bce54615536a" -var unknownResourceMappingID = "45674556-8888-9999-9999-000001230000" +const ( + unknownNamespaceID = "64257d69-c007-4893-931a-434f1819a4f7" + unknownResourceMappingGroupID = "c70cad07-21b4-4cb1-9095-bce54615536a" + unknownResourceMappingID = "45674556-8888-9999-9999-000001230000" +) type ResourceMappingsSuite struct { suite.Suite @@ -76,14 +80,15 @@ func (s *ResourceMappingsSuite) getResourceMappingAttributeValueFixtures() []fix Resource Mapping Groups */ -func (s *ResourceMappingsSuite) Test_ListResourceMappingGroups() { +func (s *ResourceMappingsSuite) Test_ListResourceMappingGroups_NoPagination_Succeeds() { testData := s.getResourceMappingGroupFixtures() - rmGroups, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, &resourcemapping.ListResourceMappingGroupsRequest{}) + listRmGroupsRsp, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, &resourcemapping.ListResourceMappingGroupsRequest{}) s.Require().NoError(err) - s.NotNil(rmGroups) + s.NotNil(listRmGroupsRsp) + listed := listRmGroupsRsp.GetResourceMappingGroups() for _, testRmGroup := range testData { found := false - for _, rmGroup := range rmGroups { + for _, rmGroup := range listed { if testRmGroup.ID == rmGroup.GetId() { found = true break @@ -93,17 +98,75 @@ func (s *ResourceMappingsSuite) Test_ListResourceMappingGroups() { } } -func (s *ResourceMappingsSuite) Test_ListResourceMappingGroupsWithNamespaceIdSucceeds() { +func (s *ResourceMappingsSuite) Test_ListResourceMappingGroups_Limit_Succeeds() { + var limit int32 = 2 + listRsp, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, &resourcemapping.ListResourceMappingGroupsRequest{ + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetResourceMappingGroups() + s.Equal(len(listed), int(limit)) + + for _, rmg := range listed { + s.NotEmpty(rmg.GetNamespaceId()) + s.NotEmpty(rmg.GetId()) + s.NotEmpty(rmg.GetName()) + } +} + +func (s *NamespacesSuite) Test_ListResourceMappingGroups_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, &resourcemapping.ListResourceMappingGroupsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *ResourceMappingsSuite) Test_ListResourceMappingGroups_Offset_Succeeds() { + req := &resourcemapping.ListResourceMappingGroupsRequest{} + // make initial list request to compare against + listRsp, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, req) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetResourceMappingGroups() + + // set the offset pagination + offset := 2 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + offsetListRsp, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, req) + s.Require().NoError(err) + s.NotNil(offsetListRsp) + offsetListed := offsetListRsp.GetResourceMappingGroups() + + // length is reduced by the offset amount + s.Equal(len(offsetListed), len(listed)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, val := range offsetListed { + s.True(proto.Equal(val, listed[i+offset])) + } +} + +func (s *ResourceMappingsSuite) Test_ListResourceMappingGroups_WithNamespaceId_Succeeds() { scenarioDotComRmGroup := s.f.GetResourceMappingGroupKey("scenario.com_ns_group_1") - rmGroups, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, &resourcemapping.ListResourceMappingGroupsRequest{ + rmGroupsRsp, err := s.db.PolicyClient.ListResourceMappingGroups(s.ctx, &resourcemapping.ListResourceMappingGroupsRequest{ NamespaceId: scenarioDotComRmGroup.NamespaceID, }) s.Require().NoError(err) - s.NotNil(rmGroups) - s.Len(rmGroups, 1) - s.Equal(scenarioDotComRmGroup.ID, rmGroups[0].GetId()) - s.Equal(scenarioDotComRmGroup.NamespaceID, rmGroups[0].GetNamespaceId()) - s.Equal(scenarioDotComRmGroup.Name, rmGroups[0].GetName()) + s.NotNil(rmGroupsRsp) + list := rmGroupsRsp.GetResourceMappingGroups() + s.Len(list, 1) + s.Equal(scenarioDotComRmGroup.ID, list[0].GetId()) + s.Equal(scenarioDotComRmGroup.NamespaceID, list[0].GetNamespaceId()) + s.Equal(scenarioDotComRmGroup.Name, list[0].GetName()) } func (s *ResourceMappingsSuite) Test_GetResourceMappingGroup() { @@ -368,22 +431,6 @@ func (s *ResourceMappingsSuite) Test_CreateResourceMappingWithUnknownAttributeVa s.Require().ErrorIs(err, db.ErrForeignKeyViolation) } -func (s *ResourceMappingsSuite) Test_CreateResourceMappingWithEmptyTermsSucceeds() { - metadata := &common.MetadataMutable{} - - attrValue := s.f.GetAttributeValueKey("example.com/attr/attr2/value/value2") - mapping := &resourcemapping.CreateResourceMappingRequest{ - AttributeValueId: attrValue.ID, - Metadata: metadata, - Terms: []string{}, - } - createdMapping, err := s.db.PolicyClient.CreateResourceMapping(s.ctx, mapping) - s.Require().NoError(err) - s.NotNil(createdMapping) - s.NotNil(createdMapping.GetTerms()) - s.Empty(createdMapping.GetTerms()) -} - func (s *ResourceMappingsSuite) Test_CreateResourceMappingWithGroupIdSucceeds() { metadata := &common.MetadataMutable{} @@ -416,7 +463,7 @@ func (s *ResourceMappingsSuite) Test_CreateResourceMappingWithUnknownGroupIdFail s.Nil(createdMapping) } -func (s *ResourceMappingsSuite) Test_ListResourceMappings() { +func (s *ResourceMappingsSuite) Test_ListResourceMappings_NoPagination_Succeeds() { testMappings := make(map[string]fixtures.FixtureDataResourceMapping) for _, testMapping := range s.getResourceMappingFixtures() { testMappings[testMapping.ID] = testMapping @@ -427,18 +474,20 @@ func (s *ResourceMappingsSuite) Test_ListResourceMappings() { testValues[testValue.ID] = testValue } - req := &resourcemapping.ListResourceMappingsRequest{} - mappings, err := s.db.PolicyClient.ListResourceMappings(s.ctx, req) + listRsp, err := s.db.PolicyClient.ListResourceMappings(s.ctx, &resourcemapping.ListResourceMappingsRequest{}) s.Require().NoError(err) - s.NotNil(mappings) + s.NotNil(listRsp) + + list := listRsp.GetResourceMappings() + s.NotEmpty(list) testMappingCount := len(testMappings) foundCount := 0 - for _, mapping := range mappings { + for _, mapping := range list { testMapping, ok := testMappings[mapping.GetId()] if !ok { - // todo: DB is not cleaned up between tests, so ignore any unexpected mappings + // only validating presence of all fixtures within the list response continue } foundCount++ @@ -462,13 +511,71 @@ func (s *ResourceMappingsSuite) Test_ListResourceMappings() { s.Equal(testMappingCount, foundCount) } -func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupId() { +func (s *ResourceMappingsSuite) Test_ListResourceMappings_Limit_Succeeds() { + var limit int32 = 4 + listRsp, err := s.db.PolicyClient.ListResourceMappings(s.ctx, &resourcemapping.ListResourceMappingsRequest{ + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + + listed := listRsp.GetResourceMappings() + s.Equal(len(listed), int(limit)) + + for _, rm := range listed { + s.NotEmpty(rm.GetId()) + s.NotEmpty(rm.GetAttributeValue()) + } +} + +func (s *NamespacesSuite) Test_ListResourceMappings_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListResourceMappings(s.ctx, &resourcemapping.ListResourceMappingsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *ResourceMappingsSuite) Test_ListResourceMappings_Offset_Succeeds() { + req := &resourcemapping.ListResourceMappingsRequest{} + // make initial list request to compare against + listRsp, err := s.db.PolicyClient.ListResourceMappings(s.ctx, req) + s.Require().NoError(err) + s.NotNil(listRsp) + listed := listRsp.GetResourceMappings() + + // set the offset pagination + offset := 2 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + offsetListRsp, err := s.db.PolicyClient.ListResourceMappings(s.ctx, req) + s.Require().NoError(err) + s.NotNil(offsetListRsp) + offsetListed := offsetListRsp.GetResourceMappings() + + // length is reduced by the offset amount + s.Equal(len(offsetListed), len(listed)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, rm := range offsetListed { + s.True(proto.Equal(rm, listed[i+offset])) + } +} + +func (s *ResourceMappingsSuite) Test_ListResourceMappings_ByGroupId_Succeeds() { req := &resourcemapping.ListResourceMappingsRequest{ GroupId: s.getResourceMappingGroupFixtures()[0].ID, } - mappings, err := s.db.PolicyClient.ListResourceMappings(s.ctx, req) + listRsp, err := s.db.PolicyClient.ListResourceMappings(s.ctx, req) s.Require().NoError(err) - s.NotNil(mappings) + s.NotNil(listRsp) + mappings := listRsp.GetResourceMappings() for _, mapping := range mappings { expectedGroupID := req.GetGroupId() actualGroupID := mapping.GetGroup().GetId() @@ -477,7 +584,7 @@ func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupId() { } } -func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqns() { +func (s *ResourceMappingsSuite) Test_ListResourceMappings_ByGroupFqns_Succeeds() { scenarioDotComNs := s.getScenarioDotComNamespace() scenarioDotComGroup := s.f.GetResourceMappingGroupKey("scenario.com_ns_group_1") scenarioDotComGroupMapping := s.f.GetResourceMappingKey("resource_mapping_to_attribute_value3") @@ -519,7 +626,7 @@ func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqns() { s.Equal("https://scenario.com/attr/working_group/value/blue", value.GetFqn()) } -func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqnsWithEmptyOrNilFqnsFails() { +func (s *ResourceMappingsSuite) Test_ListResourceMappings_ByGroupFqns_WithEmptyOrNilFqns_Fails() { fqnRmGroupMap, err := s.db.PolicyClient.ListResourceMappingsByGroupFqns(s.ctx, nil) s.Require().Error(err) s.Nil(fqnRmGroupMap) @@ -529,20 +636,20 @@ func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqnsWithEmptyOrN s.Nil(fqnRmGroupMap) } -func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqnsWithInvalidFqnsFails() { +func (s *ResourceMappingsSuite) Test_ListResourceMappings_ByGroupFqns_WithInvalidFqns_Fails() { fqnRmGroupMap, err := s.db.PolicyClient.ListResourceMappingsByGroupFqns(s.ctx, []string{"invalid_fqn"}) s.Require().Error(err) s.Nil(fqnRmGroupMap) } -func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqnsWithUnknownFqnsFails() { +func (s *ResourceMappingsSuite) Test_ListResourceMappings_ByGroupFqns_WithUnknownFqns_Fails() { unknownFqn := "https://unknown.com/resm/unknown_group" fqnRmGroupMap, err := s.db.PolicyClient.ListResourceMappingsByGroupFqns(s.ctx, []string{unknownFqn}) s.Require().Error(err) s.Nil(fqnRmGroupMap) } -func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqnsWithKnownAndUnknownFqnsSucceeds() { +func (s *ResourceMappingsSuite) Test_ListResourceMappings_ByGroupFqns_WithKnownAndUnknownFqns_Succeeds() { exampleDotComNs := s.getExampleDotComNamespace() exampleDotComRmGroup1 := s.f.GetResourceMappingGroupKey("example.com_ns_group_1") @@ -562,7 +669,7 @@ func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqnsWithKnownAnd s.Nil(unknownResp) } -func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqnsWithKnownAndInvalidFqnsSucceeds() { +func (s *ResourceMappingsSuite) Test_ListResourceMappingsByGroupFqns_WithKnownAndInvalidFqns_Succeeds() { exampleDotComNs := s.getExampleDotComNamespace() exampleDotComRmGroup1 := s.f.GetResourceMappingGroupKey("example.com_ns_group_1") diff --git a/service/integration/subject_mappings_test.go b/service/integration/subject_mappings_test.go index 6d28e5f1f..a70ff0fc4 100644 --- a/service/integration/subject_mappings_test.go +++ b/service/integration/subject_mappings_test.go @@ -11,8 +11,11 @@ import ( "github.com/opentdf/platform/service/internal/fixtures" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" ) +const nonExistentAttributeValueUUID = "78909865-8888-9999-9999-000000000000" + type SubjectMappingsSuite struct { suite.Suite f fixtures.Fixtures @@ -377,10 +380,12 @@ func (s *SubjectMappingsSuite) TestGetSubjectMapping_NonExistentId_Fails() { s.Require().ErrorIs(err, db.ErrNotFound) } -func (s *SubjectMappingsSuite) TestListSubjectMappings() { - list, err := s.db.PolicyClient.ListSubjectMappings(s.ctx) +func (s *SubjectMappingsSuite) Test_ListSubjectMappings_NoPagination_Succeeds() { + listRsp, err := s.db.PolicyClient.ListSubjectMappings(context.Background(), &subjectmapping.ListSubjectMappingsRequest{}) s.Require().NoError(err) - s.NotNil(list) + s.NotNil(listRsp) + listed := listRsp.GetSubjectMappings() + s.NotEmpty(listed) fixture1 := s.f.GetSubjectMappingKey("subject_mapping_subject_attribute1") found1 := false @@ -388,7 +393,7 @@ func (s *SubjectMappingsSuite) TestListSubjectMappings() { found2 := false fixture3 := s.f.GetSubjectMappingKey("subject_mapping_subject_attribute3") found3 := false - s.GreaterOrEqual(len(list), 3) + s.GreaterOrEqual(len(listed), 3) assertEqual := func(sm *policy.SubjectMapping, fixture fixtures.FixtureDataSubjectMapping) { s.Equal(fixture.AttributeValueID, sm.GetAttributeValue().GetId()) @@ -396,7 +401,7 @@ func (s *SubjectMappingsSuite) TestListSubjectMappings() { s.Equal(fixture.SubjectConditionSetID, sm.GetSubjectConditionSet().GetId()) s.Equal(len(fixture.Actions), len(sm.GetActions())) } - for _, sm := range list { + for _, sm := range listed { if sm.GetId() == fixture1.ID { assertEqual(sm, fixture1) found1 = true @@ -415,6 +420,77 @@ func (s *SubjectMappingsSuite) TestListSubjectMappings() { s.True(found3) } +func (s *SubjectMappingsSuite) Test_ListSubjectMappings_Limit_Succeeds() { + var limit int32 = 3 + listRsp, err := s.db.PolicyClient.ListSubjectMappings(context.Background(), &subjectmapping.ListSubjectMappingsRequest{ + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + + listed := listRsp.GetSubjectMappings() + s.NotEmpty(listed) + + for _, sm := range listed { + s.NotEmpty(sm.GetId()) + s.NotEmpty(sm.GetAttributeValue()) + s.NotNil(sm.GetSubjectConditionSet()) + } + + // request with one below maximum + listRsp, err = s.db.PolicyClient.ListSubjectMappings(context.Background(), &subjectmapping.ListSubjectMappingsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax - 1, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) +} + +func (s *NamespacesSuite) Test_ListSubjectMappings_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListSubjectMappings(context.Background(), &subjectmapping.ListSubjectMappingsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *SubjectMappingsSuite) Test_ListSubjectMappings_Offset_Succeeds() { + req := &subjectmapping.ListSubjectMappingsRequest{} + totalListRsp, err := s.db.PolicyClient.ListSubjectMappings(context.Background(), req) + s.Require().NoError(err) + s.NotNil(totalListRsp) + + totalList := totalListRsp.GetSubjectMappings() + s.NotEmpty(totalList) + + // set the offset pagination + offset := 2 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + + offetListRsp, err := s.db.PolicyClient.ListSubjectMappings(context.Background(), req) + s.Require().NoError(err) + s.NotNil(offetListRsp) + + offsetList := offetListRsp.GetSubjectMappings() + s.NotEmpty(offsetList) + + // length is reduced by the offset amount + s.Equal(len(offsetList), len(totalList)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, sm := range offsetList { + s.True(proto.Equal(sm, totalList[i+offset])) + } +} + func (s *SubjectMappingsSuite) TestDeleteSubjectMapping() { // create a new subject mapping, delete it, and verify get fails with not found fixtureAttrValID := s.f.GetAttributeValueKey("example.com/attr/attr2/value/value1").ID @@ -577,10 +653,11 @@ func (s *SubjectMappingsSuite) TestGetSubjectConditionSet_NonExistentId_Fails() s.Require().ErrorIs(err, db.ErrNotFound) } -func (s *SubjectMappingsSuite) TestListSubjectConditionSet() { - list, err := s.db.PolicyClient.ListSubjectConditionSets(s.ctx) +func (s *SubjectMappingsSuite) Test_ListSubjectConditionSet_NoPagination_Succeeds() { + listRsp, err := s.db.PolicyClient.ListSubjectConditionSets(context.Background(), &subjectmapping.ListSubjectConditionSetsRequest{}) s.Require().NoError(err) - s.NotNil(list) + s.NotNil(listRsp) + listed := listRsp.GetSubjectConditionSets() fixture1 := s.f.GetSubjectConditionSetKey("subject_condition_set1") found1 := false @@ -591,8 +668,8 @@ func (s *SubjectMappingsSuite) TestListSubjectConditionSet() { fixture4 := s.f.GetSubjectConditionSetKey("subject_condition_simple_in") found4 := false - s.GreaterOrEqual(len(list), 3) - for _, scs := range list { + s.GreaterOrEqual(len(listed), 3) + for _, scs := range listed { switch scs.GetId() { case fixture1.ID: found1 = true @@ -610,6 +687,76 @@ func (s *SubjectMappingsSuite) TestListSubjectConditionSet() { s.True(found4) } +func (s *SubjectMappingsSuite) Test_ListSubjectConditionSet_Limit_Succeeds() { + var limit int32 = 3 + listRsp, err := s.db.PolicyClient.ListSubjectConditionSets(context.Background(), &subjectmapping.ListSubjectConditionSetsRequest{ + Pagination: &policy.PageRequest{ + Limit: limit, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) + + listed := listRsp.GetSubjectConditionSets() + s.NotEmpty(listed) + + for _, sm := range listed { + s.NotEmpty(sm.GetId()) + s.NotEmpty(sm.GetSubjectSets()) + } + + // request with one below maximum + listRsp, err = s.db.PolicyClient.ListSubjectConditionSets(context.Background(), &subjectmapping.ListSubjectConditionSetsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax - 1, + }, + }) + s.Require().NoError(err) + s.NotNil(listRsp) +} + +func (s *NamespacesSuite) Test_ListSubjectConditionSets_Limit_TooLarge_Fails() { + listRsp, err := s.db.PolicyClient.ListSubjectConditionSets(context.Background(), &subjectmapping.ListSubjectConditionSetsRequest{ + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Nil(listRsp) +} + +func (s *SubjectMappingsSuite) Test_ListSubjectConditionSet_Offset_Succeeds() { + req := &subjectmapping.ListSubjectConditionSetsRequest{} + totalListRsp, err := s.db.PolicyClient.ListSubjectConditionSets(context.Background(), req) + s.Require().NoError(err) + s.NotNil(totalListRsp) + + totalList := totalListRsp.GetSubjectConditionSets() + s.NotEmpty(totalList) + + // set the offset pagination + offset := 5 + req.Pagination = &policy.PageRequest{ + Offset: int32(offset), + } + + offetListRsp, err := s.db.PolicyClient.ListSubjectConditionSets(context.Background(), req) + s.Require().NoError(err) + s.NotNil(offetListRsp) + + offsetList := offetListRsp.GetSubjectConditionSets() + s.NotEmpty(offsetList) + + // length is reduced by the offset amount + s.Equal(len(offsetList), len(totalList)-offset) + + // objects are equal between offset and original list beginning at offset index + for i, scs := range offsetList { + s.True(proto.Equal(scs, totalList[i+offset])) + } +} + func (s *SubjectMappingsSuite) TestDeleteSubjectConditionSet() { // create a new subject condition set, delete it, and verify get fails with not found newConditionSet := &subjectmapping.SubjectConditionSetCreate{ diff --git a/service/internal/fixtures/db.go b/service/internal/fixtures/db.go index 5fc628cf6..dc70ce6c2 100644 --- a/service/internal/fixtures/db.go +++ b/service/internal/fixtures/db.go @@ -13,10 +13,18 @@ import ( policydb "github.com/opentdf/platform/service/policy/db" ) +var ( + // Configured default LIST Limit when working with fixtures + fixtureLimitDefault int32 = 1000 + fixtureLimitMax int32 = 5000 +) + type DBInterface struct { Client *db.Client PolicyClient policydb.PolicyDBClient Schema string + LimitDefault int32 + LimitMax int32 } func NewDBInterface(cfg config.Config) DBInterface { @@ -42,7 +50,9 @@ func NewDBInterface(cfg config.Config) DBInterface { return DBInterface{ Client: c, Schema: config.Schema, - PolicyClient: policydb.NewClient(c, logger), + PolicyClient: policydb.NewClient(c, logger, fixtureLimitMax, fixtureLimitDefault), + LimitDefault: fixtureLimitDefault, + LimitMax: fixtureLimitMax, } } diff --git a/service/internal/fixtures/fixtures.go b/service/internal/fixtures/fixtures.go index ed4ebbdcb..5c00892aa 100644 --- a/service/internal/fixtures/fixtures.go +++ b/service/internal/fixtures/fixtures.go @@ -47,12 +47,6 @@ type FixtureDataAttributeValue struct { Active bool `yaml:"active"` } -type FixtureDataValueMember struct { - ID string `yaml:"id"` - ValueID string `yaml:"value_id"` - MemberID string `yaml:"member_id"` -} - type FixtureDataAttributeValueKeyAccessServer struct { ValueID string `yaml:"value_id"` KeyAccessServerID string `yaml:"key_access_server_id"` diff --git a/service/pkg/db/errors.go b/service/pkg/db/errors.go index 49188f14d..74a159dd4 100644 --- a/service/pkg/db/errors.go +++ b/service/pkg/db/errors.go @@ -21,6 +21,7 @@ var ( ErrEnumValueInvalid = errors.New("ErrEnumValueInvalid: not a valid enum value") ErrUUIDInvalid = errors.New("ErrUUIDInvalid: value not a valid UUID") ErrMissingValue = errors.New("ErrMissingValue: value must be included") + ErrListLimitTooLarge = errors.New("ErrListLimitTooLarge: requested limit greater than configured maximum") ) // Get helpful error message for PostgreSQL violation @@ -97,6 +98,7 @@ const ( ErrTextUUIDInvalid = "invalid input syntax for type uuid" ErrTextRestrictViolation = "intended action would violate a restriction" ErrTextFqnMissingValue = "FQN must specify a valid value and be of format 'https:///attr//value/'" + ErrTextListLimitTooLarge = "requested pagination limit must be less than or equal to configured limit" ) func StatusifyError(err error, fallbackErr string, log ...any) error { @@ -125,6 +127,10 @@ func StatusifyError(err error, fallbackErr string, log ...any) error { slog.Error(ErrTextRestrictViolation, l...) return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextRestrictViolation)) } + if errors.Is(err, ErrListLimitTooLarge) { + slog.Error(ErrTextListLimitTooLarge, l...) + return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextListLimitTooLarge)) + } slog.Error(err.Error(), l...) return connect.NewError(connect.CodeInternal, errors.New(fallbackErr)) } diff --git a/service/policy/attributes/attributes.go b/service/policy/attributes/attributes.go index 8ca210ac0..da59557ff 100644 --- a/service/policy/attributes/attributes.go +++ b/service/policy/attributes/attributes.go @@ -13,12 +13,14 @@ import ( "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/opentdf/platform/service/pkg/serviceregistry" + policyconfig "github.com/opentdf/platform/service/policy/config" policydb "github.com/opentdf/platform/service/policy/db" ) type AttributesService struct { //nolint:revive // AttributesService is a valid name for this struct dbClient policydb.PolicyDBClient logger *logger.Logger + config *policyconfig.Config } func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[attributesconnect.AttributesServiceHandler] { @@ -30,8 +32,12 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer ConnectRPCFunc: attributesconnect.NewAttributesServiceHandler, GRPCGateayFunc: attributes.RegisterAttributesServiceHandlerFromEndpoint, RegisterFunc: func(srp serviceregistry.RegistrationParams) (attributesconnect.AttributesServiceHandler, serviceregistry.HandlerServer) { - as := &AttributesService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} - return as, nil + cfg := policyconfig.GetSharedPolicyConfig(srp) + return &AttributesService{ + dbClient: policydb.NewClient(srp.DBClient, srp.Logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)), + logger: srp.Logger, + config: cfg, + }, nil }, }, } @@ -67,16 +73,13 @@ func (s AttributesService) CreateAttribute(ctx context.Context, func (s *AttributesService) ListAttributes(ctx context.Context, req *connect.Request[attributes.ListAttributesRequest], ) (*connect.Response[attributes.ListAttributesResponse], error) { - state := policydb.GetDBStateTypeTransformedEnum(req.Msg.GetState()) - namespace := req.Msg.GetNamespace() + state := req.Msg.GetState().String() s.logger.Debug("listing attribute definitions", slog.String("state", state)) - rsp := &attributes.ListAttributesResponse{} - list, err := s.dbClient.ListAttributes(ctx, state, namespace) + rsp, err := s.dbClient.ListAttributes(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } - rsp.Attributes = list return connect.NewResponse(rsp), nil } @@ -205,17 +208,13 @@ func (s *AttributesService) CreateAttributeValue(ctx context.Context, req *conne } func (s *AttributesService) ListAttributeValues(ctx context.Context, req *connect.Request[attributes.ListAttributeValuesRequest]) (*connect.Response[attributes.ListAttributeValuesResponse], error) { - rsp := &attributes.ListAttributeValuesResponse{} - - state := policydb.GetDBStateTypeTransformedEnum(req.Msg.GetState()) + state := req.Msg.GetState().String() s.logger.Debug("listing attribute values", slog.String("attributeId", req.Msg.GetAttributeId()), slog.String("state", state)) - list, err := s.dbClient.ListAttributeValues(ctx, req.Msg.GetAttributeId(), state) + rsp, err := s.dbClient.ListAttributeValues(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed, slog.String("attributeId", req.Msg.GetAttributeId())) } - rsp.Values = list - return connect.NewResponse(rsp), nil } diff --git a/service/policy/config/config.go b/service/policy/config/config.go new file mode 100644 index 000000000..2c2c5ec3d --- /dev/null +++ b/service/policy/config/config.go @@ -0,0 +1,38 @@ +package config + +import ( + "fmt" + + "github.com/creasty/defaults" + "github.com/mitchellh/mapstructure" + "github.com/opentdf/platform/service/pkg/serviceregistry" +) + +// Global policy config to share among policy services +type Config struct { + // Default pagination list limit when not provided in request + ListRequestLimitDefault int `mapstructure:"list_request_limit_default" default:"1000"` + // Maximum pagination list limit allowed by policy services + ListRequestLimitMax int `mapstructure:"list_request_limit_max" default:"2500"` +} + +func GetSharedPolicyConfig(srp serviceregistry.RegistrationParams) *Config { + policyCfg := new(Config) + + if err := defaults.Set(policyCfg); err != nil { + panic(fmt.Errorf("failed to set defaults for policy service config: %w", err)) + } + + // Only decode config if it exists + if srp.Config != nil { + if err := mapstructure.Decode(srp.Config, &policyCfg); err != nil { + panic(fmt.Errorf("invalid policy svc cfg [%v] %w", srp.Config, err)) + } + } + + if policyCfg.ListRequestLimitMax <= policyCfg.ListRequestLimitDefault { + panic(fmt.Errorf("policy svc config request limit maximum [%d] must be greater than request limit default [%d]", policyCfg.ListRequestLimitMax, policyCfg.ListRequestLimitDefault)) + } + + return policyCfg +} diff --git a/service/policy/db/attribute_fqn.go b/service/policy/db/attribute_fqn.go index 30a5eeb8a..3374817fa 100644 --- a/service/policy/db/attribute_fqn.go +++ b/service/policy/db/attribute_fqn.go @@ -5,7 +5,9 @@ import ( "fmt" "strings" + "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/namespaces" "github.com/opentdf/platform/service/pkg/db" ) @@ -26,14 +28,16 @@ func (c *PolicyDBClient) AttrFqnReindex(ctx context.Context) (res struct { //nol }, ) { // Get all namespaces - ns, err := c.ListNamespaces(ctx, StateAny) + ns, err := c.ListNamespaces(ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + }) if err != nil { panic(fmt.Errorf("could not get namespaces: %w", err)) } // Reindex all namespaces reindexedRecords := []UpsertAttributeNamespaceFqnRow{} - for _, n := range ns { + for _, n := range ns.GetNamespaces() { rows, err := c.Queries.UpsertAttributeNamespaceFqn(ctx, n.GetId()) if err != nil { panic(fmt.Errorf("could not update namespace [%s] FQN: %w", n.GetId(), err)) diff --git a/service/policy/db/attribute_values.go b/service/policy/db/attribute_values.go index 5cddd269d..952b39073 100644 --- a/service/policy/db/attribute_values.go +++ b/service/policy/db/attribute_values.go @@ -74,18 +74,28 @@ func (c PolicyDBClient) GetAttributeValue(ctx context.Context, id string) (*poli }, nil } -func (c PolicyDBClient) ListAttributeValues(ctx context.Context, attributeID string, state string) ([]*policy.Value, error) { +func (c PolicyDBClient) ListAttributeValues(ctx context.Context, r *attributes.ListAttributeValuesRequest) (*attributes.ListAttributeValuesResponse, error) { + state := getDBStateTypeTransformedEnum(r.GetState()) + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + active := pgtype.Bool{ Valid: false, } - if state != "" && state != StateAny { - active = pgtypeBool(state == StateActive) + if state != stateAny { + active = pgtypeBool(state == stateActive) } list, err := c.Queries.ListAttributeValues(ctx, ListAttributeValuesParams{ - AttributeDefinitionID: attributeID, + AttributeDefinitionID: r.GetAttributeId(), Active: active, + Limit: limit, + Offset: offset, }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) @@ -110,13 +120,49 @@ func (c PolicyDBClient) ListAttributeValues(ctx context.Context, attributeID str Fqn: av.Fqn.String, } } + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } - return attributeValues, nil + return &attributes.ListAttributeValuesResponse{ + Values: attributeValues, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } +// Loads all attribute values into memory by making iterative db roundtrip requests of defaultObjectListAllLimit size func (c PolicyDBClient) ListAllAttributeValues(ctx context.Context) ([]*policy.Value, error) { - // call ListAttributeValues method with "empty" param values to make the query return all rows - return c.ListAttributeValues(ctx, "", StateAny) + var nextOffset int32 + valsList := make([]*policy.Value, 0) + + for { + listed, err := c.ListAttributeValues(ctx, &attributes.ListAttributeValuesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: c.listCfg.limitMax, + Offset: nextOffset, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to list all attributes: %w", err) + } + + nextOffset = listed.GetPagination().GetNextOffset() + valsList = append(valsList, listed.GetValues()...) + + // offset becomes zero when list is exhausted + if nextOffset <= 0 { + break + } + } + return valsList, nil } func (c PolicyDBClient) UpdateAttributeValue(ctx context.Context, r *attributes.UpdateAttributeValueRequest) (*policy.Value, error) { diff --git a/service/policy/db/attributes.go b/service/policy/db/attributes.go index acf3f2e59..86c8a93d2 100644 --- a/service/policy/db/attributes.go +++ b/service/policy/db/attributes.go @@ -114,7 +114,10 @@ func hydrateAttribute(row *attributeQueryRow) (*policy.Attribute, error) { // CRUD operations /// -func (c PolicyDBClient) ListAttributes(ctx context.Context, state string, namespace string) ([]*policy.Attribute, error) { +func (c PolicyDBClient) ListAttributes(ctx context.Context, r *attributes.ListAttributesRequest) (*attributes.ListAttributesResponse, error) { + namespace := r.GetNamespace() + state := getDBStateTypeTransformedEnum(r.GetState()) + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) var ( active = pgtype.Bool{ Valid: false, @@ -123,8 +126,13 @@ func (c PolicyDBClient) ListAttributes(ctx context.Context, state string, namesp namespaceName = "" ) - if state != "" && state != StateAny { - active = pgtypeBool(state == StateActive) + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + + if state != stateAny { + active = pgtypeBool(state == stateActive) } if namespace != "" { @@ -139,6 +147,8 @@ func (c PolicyDBClient) ListAttributes(ctx context.Context, state string, namesp Active: active, NamespaceID: namespaceID, NamespaceName: namespaceName, + Limit: limit, + Offset: offset, }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) @@ -163,12 +173,49 @@ func (c PolicyDBClient) ListAttributes(ctx context.Context, state string, namesp } } - return policyAttributes, nil + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + + return &attributes.ListAttributesResponse{ + Attributes: policyAttributes, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } +// Loads all attributes into memory by making iterative db roundtrip requests of defaultObjectListAllLimit size func (c PolicyDBClient) ListAllAttributes(ctx context.Context) ([]*policy.Attribute, error) { - // call general List method with empty params to get all attributes - return c.ListAttributes(ctx, "", "") + var nextOffset int32 + attrsList := make([]*policy.Attribute, 0) + + for { + listed, err := c.ListAttributes(ctx, &attributes.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: c.listCfg.limitMax, + Offset: nextOffset, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to list all attributes: %w", err) + } + + nextOffset = listed.GetPagination().GetNextOffset() + attrsList = append(attrsList, listed.GetAttributes()...) + + // offset becomes zero when list is exhausted + if nextOffset <= 0 { + break + } + } + return attrsList, nil } func (c PolicyDBClient) GetAttribute(ctx context.Context, id string) (*policy.Attribute, error) { @@ -253,7 +300,9 @@ func (c PolicyDBClient) GetAttributeByFqn(ctx context.Context, fqn string) (*pol } func (c PolicyDBClient) GetAttributesByNamespace(ctx context.Context, namespaceID string) ([]*policy.Attribute, error) { - list, err := c.Queries.ListAttributesSummary(ctx, namespaceID) + list, err := c.Queries.ListAttributesSummary(ctx, ListAttributesSummaryParams{ + NamespaceID: namespaceID, + }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) } diff --git a/service/policy/db/key_access_server_registry.go b/service/policy/db/key_access_server_registry.go index 45cf9d11b..1ae09e63f 100644 --- a/service/policy/db/key_access_server_registry.go +++ b/service/policy/db/key_access_server_registry.go @@ -12,8 +12,18 @@ import ( "google.golang.org/protobuf/encoding/protojson" ) -func (c PolicyDBClient) ListKeyAccessServers(ctx context.Context) ([]*policy.KeyAccessServer, error) { - list, err := c.Queries.ListKeyAccessServers(ctx) +func (c PolicyDBClient) ListKeyAccessServers(ctx context.Context, r *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + + list, err := c.Queries.ListKeyAccessServers(ctx, ListKeyAccessServersParams{ + Offset: offset, + Limit: limit, + }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) } @@ -37,13 +47,26 @@ func (c PolicyDBClient) ListKeyAccessServers(ctx context.Context) ([]*policy.Key keyAccessServer.Id = kas.ID keyAccessServer.Uri = kas.Uri keyAccessServer.PublicKey = publicKey - keyAccessServer.Name = kas.Name.String + keyAccessServer.Name = kas.KasName.String keyAccessServer.Metadata = metadata keyAccessServers[i] = keyAccessServer } + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } - return keyAccessServers, nil + return &kasregistry.ListKeyAccessServersResponse{ + KeyAccessServers: keyAccessServers, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } func (c PolicyDBClient) GetKeyAccessServer(ctx context.Context, id string) (*policy.KeyAccessServer, error) { @@ -171,11 +194,19 @@ func (c PolicyDBClient) DeleteKeyAccessServer(ctx context.Context, id string) (* }, nil } -func (c PolicyDBClient) ListKeyAccessServerGrants(ctx context.Context, kasID, kasURI, kasName string) ([]*kasregistry.KeyAccessServerGrants, error) { +func (c PolicyDBClient) ListKeyAccessServerGrants(ctx context.Context, r *kasregistry.ListKeyAccessServerGrantsRequest) (*kasregistry.ListKeyAccessServerGrantsResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + params := ListKeyAccessServerGrantsParams{ - KasID: kasID, - KasUri: kasURI, - KasName: kasName, + KasID: r.GetKasId(), + KasUri: r.GetKasUri(), + KasName: r.GetKasName(), + Offset: offset, + Limit: limit, } listRows, err := c.Queries.ListKeyAccessServerGrants(ctx, params) if err != nil { @@ -213,6 +244,18 @@ func (c PolicyDBClient) ListKeyAccessServerGrants(ctx context.Context, kasID, ka NamespaceGrants: namespaceGrants, } } - - return grants, nil + var total int32 + var nextOffset int32 + if len(listRows) > 0 { + total = int32(listRows[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + return &kasregistry.ListKeyAccessServerGrantsResponse{ + Grants: grants, + Pagination: &policy.PageResponse{ + CurrentOffset: params.Offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } diff --git a/service/policy/db/namespaces.go b/service/policy/db/namespaces.go index d1fd0a0c2..d4a19467c 100644 --- a/service/policy/db/namespaces.go +++ b/service/policy/db/namespaces.go @@ -44,21 +44,32 @@ func (c PolicyDBClient) GetNamespace(ctx context.Context, id string) (*policy.Na }, nil } -func (c PolicyDBClient) ListNamespaces(ctx context.Context, state string) ([]*policy.Namespace, error) { +func (c PolicyDBClient) ListNamespaces(ctx context.Context, r *namespaces.ListNamespacesRequest) (*namespaces.ListNamespacesResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + active := pgtype.Bool{ Valid: false, } - - if state != "" && state != StateAny { - active = pgtypeBool(state == StateActive) + state := getDBStateTypeTransformedEnum(r.GetState()) + if state != stateAny { + active = pgtypeBool(state == stateActive) } - list, err := c.Queries.ListNamespaces(ctx, active) + list, err := c.Queries.ListNamespaces(ctx, ListNamespacesParams{ + Active: active, + Limit: limit, + Offset: offset, + }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) } - namespaces := make([]*policy.Namespace, len(list)) + nsList := make([]*policy.Namespace, len(list)) for i, ns := range list { metadata := &common.Metadata{} @@ -66,7 +77,7 @@ func (c PolicyDBClient) ListNamespaces(ctx context.Context, state string) ([]*po return nil, err } - namespaces[i] = &policy.Namespace{ + nsList[i] = &policy.Namespace{ Id: ns.ID, Name: ns.Name, Active: &wrapperspb.BoolValue{Value: ns.Active}, @@ -75,7 +86,49 @@ func (c PolicyDBClient) ListNamespaces(ctx context.Context, state string) ([]*po } } - return namespaces, nil + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + + return &namespaces.ListNamespacesResponse{ + Namespaces: nsList, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil +} + +// Loads all namespaces into memory by making iterative db roundtrip requests of defaultObjectListAllLimit size +func (c PolicyDBClient) ListAllNamespaces(ctx context.Context) ([]*policy.Namespace, error) { + var nextOffset int32 + nsList := make([]*policy.Namespace, 0) + + for { + listed, err := c.ListNamespaces(ctx, &namespaces.ListNamespacesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + Pagination: &policy.PageRequest{ + Limit: c.listCfg.limitMax, + Offset: nextOffset, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to list all namespaces: %w", err) + } + + nextOffset = listed.GetPagination().GetNextOffset() + nsList = append(nsList, listed.GetNamespaces()...) + + // offset becomes zero when list is exhausted + if nextOffset <= 0 { + break + } + } + return nsList, nil } func (c PolicyDBClient) CreateNamespace(ctx context.Context, r *namespaces.CreateNamespaceRequest) (*policy.Namespace, error) { diff --git a/service/policy/db/policy.go b/service/policy/db/policy.go index cb2304500..390362b3d 100644 --- a/service/policy/db/policy.go +++ b/service/policy/db/policy.go @@ -7,33 +7,41 @@ import ( ) const ( - StateInactive = "INACTIVE" - StateActive = "ACTIVE" - StateAny = "ANY" - StateUnspecified = "UNSPECIFIED" + stateInactive transformedState = "INACTIVE" + stateActive transformedState = "ACTIVE" + stateAny transformedState = "ANY" + stateUnspecified transformedState = "UNSPECIFIED" ) +type transformedState string + +type ListConfig struct { + limitDefault int32 + limitMax int32 +} + type PolicyDBClient struct { *db.Client logger *logger.Logger *Queries + listCfg ListConfig } -func NewClient(c *db.Client, logger *logger.Logger) PolicyDBClient { - return PolicyDBClient{c, logger, New(c.Pgx)} +func NewClient(c *db.Client, logger *logger.Logger, configuredListLimitMax, configuredListLimitDefault int32) PolicyDBClient { + return PolicyDBClient{c, logger, New(c.Pgx), ListConfig{limitDefault: configuredListLimitDefault, limitMax: configuredListLimitMax}} } -func GetDBStateTypeTransformedEnum(state common.ActiveStateEnum) string { +func getDBStateTypeTransformedEnum(state common.ActiveStateEnum) transformedState { switch state.String() { case common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE.String(): - return StateActive + return stateActive case common.ActiveStateEnum_ACTIVE_STATE_ENUM_INACTIVE.String(): - return StateInactive + return stateInactive case common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY.String(): - return StateAny + return stateAny case common.ActiveStateEnum_ACTIVE_STATE_ENUM_UNSPECIFIED.String(): - return StateActive + return stateActive default: - return StateActive + return stateActive } } diff --git a/service/policy/db/query.sql b/service/policy/db/query.sql index fdb7c2956..dc36dd37a 100644 --- a/service/policy/db/query.sql +++ b/service/policy/db/query.sql @@ -3,60 +3,87 @@ ---------------------------------------------------------------- -- name: ListKeyAccessServerGrants :many +WITH listed AS ( + SELECT + COUNT(*) OVER() AS total, + kas.id AS kas_id, + kas.uri AS kas_uri, + kas.name AS kas_name, + kas.public_key AS kas_public_key, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT( + 'labels', kas.metadata -> 'labels', + 'created_at', kas.created_at, + 'updated_at', kas.updated_at + )) AS kas_metadata, + JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'id', attrkag.attribute_definition_id, + 'fqn', fqns_on_attr.fqn + )) FILTER (WHERE attrkag.attribute_definition_id IS NOT NULL) AS attributes_grants, + JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'id', valkag.attribute_value_id, + 'fqn', fqns_on_vals.fqn + )) FILTER (WHERE valkag.attribute_value_id IS NOT NULL) AS values_grants, + JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'id', nskag.namespace_id, + 'fqn', fqns_on_ns.fqn + )) FILTER (WHERE nskag.namespace_id IS NOT NULL) AS namespace_grants + FROM + key_access_servers kas + LEFT JOIN + attribute_definition_key_access_grants attrkag + ON kas.id = attrkag.key_access_server_id + LEFT JOIN + attribute_fqns fqns_on_attr + ON attrkag.attribute_definition_id = fqns_on_attr.attribute_id + AND fqns_on_attr.value_id IS NULL + LEFT JOIN + attribute_value_key_access_grants valkag + ON kas.id = valkag.key_access_server_id + LEFT JOIN + attribute_fqns fqns_on_vals + ON valkag.attribute_value_id = fqns_on_vals.value_id + LEFT JOIN + attribute_namespace_key_access_grants nskag + ON kas.id = nskag.key_access_server_id + LEFT JOIN + attribute_fqns fqns_on_ns + ON nskag.namespace_id = fqns_on_ns.namespace_id + AND fqns_on_ns.attribute_id IS NULL AND fqns_on_ns.value_id IS NULL + WHERE (NULLIF(@kas_id, '') IS NULL OR kas.id = @kas_id::uuid) + AND (NULLIF(@kas_uri, '') IS NULL OR kas.uri = @kas_uri::varchar) + AND (NULLIF(@kas_name, '') IS NULL OR kas.name = @kas_name::varchar) + GROUP BY + kas.id +) SELECT - kas.id AS kas_id, - kas.uri AS kas_uri, - kas.name AS kas_name, - kas.public_key AS kas_public_key, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT( - 'labels', kas.metadata -> 'labels', - 'created_at', kas.created_at, - 'updated_at', kas.updated_at - )) AS kas_metadata, - JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( - 'id', attrkag.attribute_definition_id, - 'fqn', fqns_on_attr.fqn - )) FILTER (WHERE attrkag.attribute_definition_id IS NOT NULL) AS attributes_grants, - JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( - 'id', valkag.attribute_value_id, - 'fqn', fqns_on_vals.fqn - )) FILTER (WHERE valkag.attribute_value_id IS NOT NULL) AS values_grants, - JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( - 'id', nskag.namespace_id, - 'fqn', fqns_on_ns.fqn - )) FILTER (WHERE nskag.namespace_id IS NOT NULL) AS namespace_grants -FROM - key_access_servers kas -LEFT JOIN - attribute_definition_key_access_grants attrkag - ON kas.id = attrkag.key_access_server_id -LEFT JOIN - attribute_fqns fqns_on_attr - ON attrkag.attribute_definition_id = fqns_on_attr.attribute_id - AND fqns_on_attr.value_id IS NULL -LEFT JOIN - attribute_value_key_access_grants valkag - ON kas.id = valkag.key_access_server_id -LEFT JOIN - attribute_fqns fqns_on_vals - ON valkag.attribute_value_id = fqns_on_vals.value_id -LEFT JOIN - attribute_namespace_key_access_grants nskag - ON kas.id = nskag.key_access_server_id -LEFT JOIN - attribute_fqns fqns_on_ns - ON nskag.namespace_id = fqns_on_ns.namespace_id - AND fqns_on_ns.attribute_id IS NULL AND fqns_on_ns.value_id IS NULL -WHERE (NULLIF(@kas_id, '') IS NULL OR kas.id = @kas_id::uuid) - AND (NULLIF(@kas_uri, '') IS NULL OR kas.uri = @kas_uri::varchar) - AND (NULLIF(@kas_name, '') IS NULL OR kas.name = @kas_name::varchar) -GROUP BY - kas.id; + listed.kas_id, + listed.kas_uri, + listed.kas_name, + listed.kas_public_key, + listed.kas_metadata, + listed.attributes_grants, + listed.values_grants, + listed.namespace_grants, + listed.total +FROM listed +LIMIT @limit_ +OFFSET @offset_; -- name: ListKeyAccessServers :many -SELECT id, uri, public_key, name, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -FROM key_access_servers; +WITH counted AS ( + SELECT COUNT(kas.id) AS total + FROM key_access_servers kas +) +SELECT kas.id, + kas.uri, + kas.public_key, + kas.name AS kas_name, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', kas.metadata -> 'labels', 'created_at', kas.created_at, 'updated_at', kas.updated_at)) as metadata, + counted.total +FROM key_access_servers kas +CROSS JOIN counted +LIMIT @limit_ +OFFSET @offset_; -- name: GetKeyAccessServer :one SELECT id, uri, public_key, name, @@ -206,6 +233,10 @@ RETURNING ---------------------------------------------------------------- -- name: ListAttributesDetail :many +WITH counted AS ( + SELECT COUNT(ad.id) AS total + FROM attribute_definitions ad +) SELECT ad.id, ad.name as attribute_name, @@ -222,8 +253,10 @@ SELECT 'fqn', CONCAT(fqns.fqn, '/value/', avt.value) ) ORDER BY ARRAY_POSITION(ad.values_order, avt.id) ) AS values, - fqns.fqn + fqns.fqn, + counted.total FROM attribute_definitions ad +CROSS JOIN counted LEFT JOIN attribute_namespaces n ON n.id = ad.namespace_id LEFT JOIN ( SELECT @@ -249,9 +282,14 @@ WHERE (sqlc.narg('active')::BOOLEAN IS NULL OR ad.active = sqlc.narg('active')) AND (NULLIF(@namespace_id, '') IS NULL OR ad.namespace_id = @namespace_id::uuid) AND (NULLIF(@namespace_name, '') IS NULL OR n.name = @namespace_name) -GROUP BY ad.id, n.name, fqns.fqn; +GROUP BY ad.id, n.name, fqns.fqn, counted.total +LIMIT @limit_ +OFFSET @offset_; -- name: ListAttributesSummary :many +WITH counted AS ( + SELECT COUNT(ad.id) AS total FROM attribute_definitions ad +) SELECT ad.id, ad.name as attribute_name, @@ -259,11 +297,15 @@ SELECT JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', ad.metadata -> 'labels', 'created_at', ad.created_at, 'updated_at', ad.updated_at)) AS metadata, ad.namespace_id, ad.active, - n.name as namespace_name + n.name as namespace_name, + counted.total FROM attribute_definitions ad +CROSS JOIN counted LEFT JOIN attribute_namespaces n ON n.id = ad.namespace_id WHERE ad.namespace_id = $1 -GROUP BY ad.id, n.name; +GROUP BY ad.id, n.name, counted.total +LIMIT @limit_ +OFFSET @offset_; -- name: ListAttributesByDefOrValueFqns :many -- get the attribute definition for the provided value or definition fqn @@ -468,21 +510,27 @@ WHERE attribute_definition_id = $1 AND key_access_server_id = $2; ---------------------------------------------------------------- -- name: ListAttributeValues :many - +WITH counted AS ( + SELECT COUNT(av.id) AS total + FROM attribute_values av +) SELECT av.id, av.value, av.active, JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', av.metadata -> 'labels', 'created_at', av.created_at, 'updated_at', av.updated_at)) as metadata, av.attribute_definition_id, - fqns.fqn + fqns.fqn, + counted.total FROM attribute_values av +CROSS JOIN counted LEFT JOIN attribute_fqns fqns ON av.id = fqns.value_id WHERE ( (sqlc.narg('active')::BOOLEAN IS NULL OR av.active = sqlc.narg('active')) AND (NULLIF(@attribute_definition_id, '') IS NULL OR av.attribute_definition_id = @attribute_definition_id::UUID) ) -GROUP BY av.id, fqns.fqn; +LIMIT @limit_ +OFFSET @offset_; -- name: GetAttributeValue :one SELECT @@ -537,10 +585,20 @@ WHERE attribute_value_id = $1 AND key_access_server_id = $2; ---------------------------------------------------------------- -- name: ListResourceMappingGroups :many -SELECT id, namespace_id, name, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -FROM resource_mapping_groups -WHERE (NULLIF(@namespace_id, '') IS NULL OR namespace_id = @namespace_id::uuid); +WITH counted AS ( + SELECT COUNT(rmg.id) AS total + FROM resource_mapping_groups rmg +) +SELECT rmg.id, + rmg.namespace_id, + rmg.name, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', rmg.metadata -> 'labels', 'created_at', rmg.created_at, 'updated_at', rmg.updated_at)) as metadata, + counted.total +FROM resource_mapping_groups rmg +CROSS JOIN counted +WHERE (NULLIF(@namespace_id, '') IS NULL OR rmg.namespace_id = @namespace_id::uuid) +LIMIT @limit_ +OFFSET @offset_; -- name: GetResourceMappingGroup :one SELECT id, namespace_id, name, @@ -569,17 +627,25 @@ DELETE FROM resource_mapping_groups WHERE id = $1; ---------------------------------------------------------------- -- name: ListResourceMappings :many +WITH counted AS ( + SELECT COUNT(rm.id) AS total + FROM resource_mappings rm +) SELECT m.id, JSON_BUILD_OBJECT('id', av.id, 'value', av.value, 'fqn', fqns.fqn) as attribute_value, m.terms, JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', m.metadata -> 'labels', 'created_at', m.created_at, 'updated_at', m.updated_at)) as metadata, - COALESCE(m.group_id::TEXT, '')::TEXT as group_id + COALESCE(m.group_id::TEXT, '')::TEXT as group_id, + counted.total FROM resource_mappings m +CROSS JOIN counted LEFT JOIN attribute_values av on m.attribute_value_id = av.id LEFT JOIN attribute_fqns fqns on av.id = fqns.value_id WHERE (NULLIF(@group_id, '') IS NULL OR m.group_id = @group_id::UUID) -GROUP BY av.id, m.id, fqns.fqn; +GROUP BY av.id, m.id, fqns.fqn, counted.total +LIMIT @limit_ +OFFSET @offset_; -- name: ListResourceMappingsByFullyQualifiedGroup :many -- CTE to cache the group JSON build since it will be the same for all mappings of the group @@ -646,15 +712,22 @@ DELETE FROM resource_mappings WHERE id = $1; ---------------------------------------------------------------- -- name: ListNamespaces :many +WITH counted AS ( + SELECT COUNT(id) AS total FROM attribute_namespaces +) SELECT ns.id, ns.name, ns.active, JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', ns.metadata -> 'labels', 'created_at', ns.created_at, 'updated_at', ns.updated_at)) as metadata, - fqns.fqn + fqns.fqn, + counted.total FROM attribute_namespaces ns +CROSS JOIN counted LEFT JOIN attribute_fqns fqns ON ns.id = fqns.namespace_id AND fqns.attribute_id IS NULL -WHERE (sqlc.narg('active')::BOOLEAN IS NULL OR ns.active = sqlc.narg('active')::BOOLEAN); +WHERE (sqlc.narg('active')::BOOLEAN IS NULL OR ns.active = sqlc.narg('active')::BOOLEAN) +LIMIT @limit_ +OFFSET @offset_; -- name: GetNamespace :one SELECT @@ -706,11 +779,19 @@ WHERE namespace_id = $1 AND key_access_server_id = $2; ---------------------------------------------------------------- -- name: ListSubjectConditionSets :many +WITH counted AS ( + SELECT COUNT(scs.id) AS total + FROM subject_condition_set scs +) SELECT - id, - condition, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -FROM subject_condition_set; + scs.id, + scs.condition, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', scs.metadata -> 'labels', 'created_at', scs.created_at, 'updated_at', scs.updated_at)) as metadata, + counted.total +FROM subject_condition_set scs +CROSS JOIN counted +LIMIT @limit_ +OFFSET @offset_; -- name: GetSubjectConditionSet :one SELECT @@ -745,6 +826,10 @@ RETURNING id; ---------------------------------------------------------------- -- name: ListSubjectMappings :many +WITH counted AS ( + SELECT COUNT(sm.id) AS total + FROM subject_mappings sm +) SELECT sm.id, sm.actions, @@ -754,11 +839,15 @@ SELECT 'metadata', JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', scs.metadata->'labels', 'created_at', scs.created_at, 'updated_at', scs.updated_at)), 'subject_sets', scs.condition ) AS subject_condition_set, - JSON_BUILD_OBJECT('id', av.id,'value', av.value,'active', av.active) AS attribute_value + JSON_BUILD_OBJECT('id', av.id,'value', av.value,'active', av.active) AS attribute_value, + counted.total FROM subject_mappings sm +CROSS JOIN counted LEFT JOIN attribute_values av ON sm.attribute_value_id = av.id LEFT JOIN subject_condition_set scs ON scs.id = sm.subject_condition_set_id -GROUP BY av.id, sm.id, scs.id; +GROUP BY av.id, sm.id, scs.id, counted.total +LIMIT @limit_ +OFFSET @offset_; -- name: GetSubjectMapping :one SELECT @@ -813,4 +902,4 @@ SET WHERE id = $1; -- name: DeleteSubjectMapping :execrows -DELETE FROM subject_mappings WHERE id = $1; +DELETE FROM subject_mappings WHERE id = $1; \ No newline at end of file diff --git a/service/policy/db/query.sql.go b/service/policy/db/query.sql.go index e79ce3a9a..3e21083c0 100644 --- a/service/policy/db/query.sql.go +++ b/service/policy/db/query.sql.go @@ -904,26 +904,34 @@ func (q *Queries) GetSubjectMapping(ctx context.Context, id string) (GetSubjectM const listAttributeValues = `-- name: ListAttributeValues :many - +WITH counted AS ( + SELECT COUNT(av.id) AS total + FROM attribute_values av +) SELECT av.id, av.value, av.active, JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', av.metadata -> 'labels', 'created_at', av.created_at, 'updated_at', av.updated_at)) as metadata, av.attribute_definition_id, - fqns.fqn + fqns.fqn, + counted.total FROM attribute_values av +CROSS JOIN counted LEFT JOIN attribute_fqns fqns ON av.id = fqns.value_id WHERE ( ($1::BOOLEAN IS NULL OR av.active = $1) AND (NULLIF($2, '') IS NULL OR av.attribute_definition_id = $2::UUID) ) -GROUP BY av.id, fqns.fqn +LIMIT $4 +OFFSET $3 ` type ListAttributeValuesParams struct { Active pgtype.Bool `json:"active"` AttributeDefinitionID interface{} `json:"attribute_definition_id"` + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` } type ListAttributeValuesRow struct { @@ -933,28 +941,41 @@ type ListAttributeValuesRow struct { Metadata []byte `json:"metadata"` AttributeDefinitionID string `json:"attribute_definition_id"` Fqn pgtype.Text `json:"fqn"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // ATTRIBUTE VALUES // -------------------------------------------------------------- // +// WITH counted AS ( +// SELECT COUNT(av.id) AS total +// FROM attribute_values av +// ) // SELECT // av.id, // av.value, // av.active, // JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', av.metadata -> 'labels', 'created_at', av.created_at, 'updated_at', av.updated_at)) as metadata, // av.attribute_definition_id, -// fqns.fqn +// fqns.fqn, +// counted.total // FROM attribute_values av +// CROSS JOIN counted // LEFT JOIN attribute_fqns fqns ON av.id = fqns.value_id // WHERE ( // ($1::BOOLEAN IS NULL OR av.active = $1) AND // (NULLIF($2, '') IS NULL OR av.attribute_definition_id = $2::UUID) // ) -// GROUP BY av.id, fqns.fqn +// LIMIT $4 +// OFFSET $3 func (q *Queries) ListAttributeValues(ctx context.Context, arg ListAttributeValuesParams) ([]ListAttributeValuesRow, error) { - rows, err := q.db.Query(ctx, listAttributeValues, arg.Active, arg.AttributeDefinitionID) + rows, err := q.db.Query(ctx, listAttributeValues, + arg.Active, + arg.AttributeDefinitionID, + arg.Offset, + arg.Limit, + ) if err != nil { return nil, err } @@ -969,6 +990,7 @@ func (q *Queries) ListAttributeValues(ctx context.Context, arg ListAttributeValu &i.Metadata, &i.AttributeDefinitionID, &i.Fqn, + &i.Total, ); err != nil { return nil, err } @@ -1271,6 +1293,10 @@ func (q *Queries) ListAttributesByDefOrValueFqns(ctx context.Context, fqns []str const listAttributesDetail = `-- name: ListAttributesDetail :many +WITH counted AS ( + SELECT COUNT(ad.id) AS total + FROM attribute_definitions ad +) SELECT ad.id, ad.name as attribute_name, @@ -1287,8 +1313,10 @@ SELECT 'fqn', CONCAT(fqns.fqn, '/value/', avt.value) ) ORDER BY ARRAY_POSITION(ad.values_order, avt.id) ) AS values, - fqns.fqn + fqns.fqn, + counted.total FROM attribute_definitions ad +CROSS JOIN counted LEFT JOIN attribute_namespaces n ON n.id = ad.namespace_id LEFT JOIN ( SELECT @@ -1314,13 +1342,17 @@ WHERE ($1::BOOLEAN IS NULL OR ad.active = $1) AND (NULLIF($2, '') IS NULL OR ad.namespace_id = $2::uuid) AND (NULLIF($3, '') IS NULL OR n.name = $3) -GROUP BY ad.id, n.name, fqns.fqn +GROUP BY ad.id, n.name, fqns.fqn, counted.total +LIMIT $5 +OFFSET $4 ` type ListAttributesDetailParams struct { Active pgtype.Bool `json:"active"` NamespaceID interface{} `json:"namespace_id"` NamespaceName interface{} `json:"namespace_name"` + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` } type ListAttributesDetailRow struct { @@ -1333,12 +1365,17 @@ type ListAttributesDetailRow struct { NamespaceName pgtype.Text `json:"namespace_name"` Values []byte `json:"values"` Fqn pgtype.Text `json:"fqn"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // ATTRIBUTES // -------------------------------------------------------------- // +// WITH counted AS ( +// SELECT COUNT(ad.id) AS total +// FROM attribute_definitions ad +// ) // SELECT // ad.id, // ad.name as attribute_name, @@ -1355,8 +1392,10 @@ type ListAttributesDetailRow struct { // 'fqn', CONCAT(fqns.fqn, '/value/', avt.value) // ) ORDER BY ARRAY_POSITION(ad.values_order, avt.id) // ) AS values, -// fqns.fqn +// fqns.fqn, +// counted.total // FROM attribute_definitions ad +// CROSS JOIN counted // LEFT JOIN attribute_namespaces n ON n.id = ad.namespace_id // LEFT JOIN ( // SELECT @@ -1382,9 +1421,17 @@ type ListAttributesDetailRow struct { // ($1::BOOLEAN IS NULL OR ad.active = $1) AND // (NULLIF($2, '') IS NULL OR ad.namespace_id = $2::uuid) AND // (NULLIF($3, '') IS NULL OR n.name = $3) -// GROUP BY ad.id, n.name, fqns.fqn +// GROUP BY ad.id, n.name, fqns.fqn, counted.total +// LIMIT $5 +// OFFSET $4 func (q *Queries) ListAttributesDetail(ctx context.Context, arg ListAttributesDetailParams) ([]ListAttributesDetailRow, error) { - rows, err := q.db.Query(ctx, listAttributesDetail, arg.Active, arg.NamespaceID, arg.NamespaceName) + rows, err := q.db.Query(ctx, listAttributesDetail, + arg.Active, + arg.NamespaceID, + arg.NamespaceName, + arg.Offset, + arg.Limit, + ) if err != nil { return nil, err } @@ -1402,6 +1449,7 @@ func (q *Queries) ListAttributesDetail(ctx context.Context, arg ListAttributesDe &i.NamespaceName, &i.Values, &i.Fqn, + &i.Total, ); err != nil { return nil, err } @@ -1414,6 +1462,9 @@ func (q *Queries) ListAttributesDetail(ctx context.Context, arg ListAttributesDe } const listAttributesSummary = `-- name: ListAttributesSummary :many +WITH counted AS ( + SELECT COUNT(ad.id) AS total FROM attribute_definitions ad +) SELECT ad.id, ad.name as attribute_name, @@ -1421,13 +1472,23 @@ SELECT JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', ad.metadata -> 'labels', 'created_at', ad.created_at, 'updated_at', ad.updated_at)) AS metadata, ad.namespace_id, ad.active, - n.name as namespace_name + n.name as namespace_name, + counted.total FROM attribute_definitions ad +CROSS JOIN counted LEFT JOIN attribute_namespaces n ON n.id = ad.namespace_id WHERE ad.namespace_id = $1 -GROUP BY ad.id, n.name +GROUP BY ad.id, n.name, counted.total +LIMIT $3 +OFFSET $2 ` +type ListAttributesSummaryParams struct { + NamespaceID string `json:"namespace_id"` + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + type ListAttributesSummaryRow struct { ID string `json:"id"` AttributeName string `json:"attribute_name"` @@ -1436,10 +1497,14 @@ type ListAttributesSummaryRow struct { NamespaceID string `json:"namespace_id"` Active bool `json:"active"` NamespaceName pgtype.Text `json:"namespace_name"` + Total int64 `json:"total"` } // ListAttributesSummary // +// WITH counted AS ( +// SELECT COUNT(ad.id) AS total FROM attribute_definitions ad +// ) // SELECT // ad.id, // ad.name as attribute_name, @@ -1447,13 +1512,17 @@ type ListAttributesSummaryRow struct { // JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', ad.metadata -> 'labels', 'created_at', ad.created_at, 'updated_at', ad.updated_at)) AS metadata, // ad.namespace_id, // ad.active, -// n.name as namespace_name +// n.name as namespace_name, +// counted.total // FROM attribute_definitions ad +// CROSS JOIN counted // LEFT JOIN attribute_namespaces n ON n.id = ad.namespace_id // WHERE ad.namespace_id = $1 -// GROUP BY ad.id, n.name -func (q *Queries) ListAttributesSummary(ctx context.Context, namespaceID string) ([]ListAttributesSummaryRow, error) { - rows, err := q.db.Query(ctx, listAttributesSummary, namespaceID) +// GROUP BY ad.id, n.name, counted.total +// LIMIT $3 +// OFFSET $2 +func (q *Queries) ListAttributesSummary(ctx context.Context, arg ListAttributesSummaryParams) ([]ListAttributesSummaryRow, error) { + rows, err := q.db.Query(ctx, listAttributesSummary, arg.NamespaceID, arg.Offset, arg.Limit) if err != nil { return nil, err } @@ -1469,6 +1538,7 @@ func (q *Queries) ListAttributesSummary(ctx context.Context, namespaceID string) &i.NamespaceID, &i.Active, &i.NamespaceName, + &i.Total, ); err != nil { return nil, err } @@ -1482,58 +1552,76 @@ func (q *Queries) ListAttributesSummary(ctx context.Context, namespaceID string) const listKeyAccessServerGrants = `-- name: ListKeyAccessServerGrants :many +WITH listed AS ( + SELECT + COUNT(*) OVER() AS total, + kas.id AS kas_id, + kas.uri AS kas_uri, + kas.name AS kas_name, + kas.public_key AS kas_public_key, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT( + 'labels', kas.metadata -> 'labels', + 'created_at', kas.created_at, + 'updated_at', kas.updated_at + )) AS kas_metadata, + JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'id', attrkag.attribute_definition_id, + 'fqn', fqns_on_attr.fqn + )) FILTER (WHERE attrkag.attribute_definition_id IS NOT NULL) AS attributes_grants, + JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'id', valkag.attribute_value_id, + 'fqn', fqns_on_vals.fqn + )) FILTER (WHERE valkag.attribute_value_id IS NOT NULL) AS values_grants, + JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'id', nskag.namespace_id, + 'fqn', fqns_on_ns.fqn + )) FILTER (WHERE nskag.namespace_id IS NOT NULL) AS namespace_grants + FROM + key_access_servers kas + LEFT JOIN + attribute_definition_key_access_grants attrkag + ON kas.id = attrkag.key_access_server_id + LEFT JOIN + attribute_fqns fqns_on_attr + ON attrkag.attribute_definition_id = fqns_on_attr.attribute_id + AND fqns_on_attr.value_id IS NULL + LEFT JOIN + attribute_value_key_access_grants valkag + ON kas.id = valkag.key_access_server_id + LEFT JOIN + attribute_fqns fqns_on_vals + ON valkag.attribute_value_id = fqns_on_vals.value_id + LEFT JOIN + attribute_namespace_key_access_grants nskag + ON kas.id = nskag.key_access_server_id + LEFT JOIN + attribute_fqns fqns_on_ns + ON nskag.namespace_id = fqns_on_ns.namespace_id + AND fqns_on_ns.attribute_id IS NULL AND fqns_on_ns.value_id IS NULL + WHERE (NULLIF($3, '') IS NULL OR kas.id = $3::uuid) + AND (NULLIF($4, '') IS NULL OR kas.uri = $4::varchar) + AND (NULLIF($5, '') IS NULL OR kas.name = $5::varchar) + GROUP BY + kas.id +) SELECT - kas.id AS kas_id, - kas.uri AS kas_uri, - kas.name AS kas_name, - kas.public_key AS kas_public_key, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT( - 'labels', kas.metadata -> 'labels', - 'created_at', kas.created_at, - 'updated_at', kas.updated_at - )) AS kas_metadata, - JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( - 'id', attrkag.attribute_definition_id, - 'fqn', fqns_on_attr.fqn - )) FILTER (WHERE attrkag.attribute_definition_id IS NOT NULL) AS attributes_grants, - JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( - 'id', valkag.attribute_value_id, - 'fqn', fqns_on_vals.fqn - )) FILTER (WHERE valkag.attribute_value_id IS NOT NULL) AS values_grants, - JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( - 'id', nskag.namespace_id, - 'fqn', fqns_on_ns.fqn - )) FILTER (WHERE nskag.namespace_id IS NOT NULL) AS namespace_grants -FROM - key_access_servers kas -LEFT JOIN - attribute_definition_key_access_grants attrkag - ON kas.id = attrkag.key_access_server_id -LEFT JOIN - attribute_fqns fqns_on_attr - ON attrkag.attribute_definition_id = fqns_on_attr.attribute_id - AND fqns_on_attr.value_id IS NULL -LEFT JOIN - attribute_value_key_access_grants valkag - ON kas.id = valkag.key_access_server_id -LEFT JOIN - attribute_fqns fqns_on_vals - ON valkag.attribute_value_id = fqns_on_vals.value_id -LEFT JOIN - attribute_namespace_key_access_grants nskag - ON kas.id = nskag.key_access_server_id -LEFT JOIN - attribute_fqns fqns_on_ns - ON nskag.namespace_id = fqns_on_ns.namespace_id - AND fqns_on_ns.attribute_id IS NULL AND fqns_on_ns.value_id IS NULL -WHERE (NULLIF($1, '') IS NULL OR kas.id = $1::uuid) - AND (NULLIF($2, '') IS NULL OR kas.uri = $2::varchar) - AND (NULLIF($3, '') IS NULL OR kas.name = $3::varchar) -GROUP BY - kas.id + listed.kas_id, + listed.kas_uri, + listed.kas_name, + listed.kas_public_key, + listed.kas_metadata, + listed.attributes_grants, + listed.values_grants, + listed.namespace_grants, + listed.total +FROM listed +LIMIT $2 +OFFSET $1 ` type ListKeyAccessServerGrantsParams struct { + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` KasID interface{} `json:"kas_id"` KasUri interface{} `json:"kas_uri"` KasName interface{} `json:"kas_name"` @@ -1548,63 +1636,86 @@ type ListKeyAccessServerGrantsRow struct { AttributesGrants []byte `json:"attributes_grants"` ValuesGrants []byte `json:"values_grants"` NamespaceGrants []byte `json:"namespace_grants"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // KEY ACCESS SERVERS // -------------------------------------------------------------- // +// WITH listed AS ( +// SELECT +// COUNT(*) OVER() AS total, +// kas.id AS kas_id, +// kas.uri AS kas_uri, +// kas.name AS kas_name, +// kas.public_key AS kas_public_key, +// JSON_STRIP_NULLS(JSON_BUILD_OBJECT( +// 'labels', kas.metadata -> 'labels', +// 'created_at', kas.created_at, +// 'updated_at', kas.updated_at +// )) AS kas_metadata, +// JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( +// 'id', attrkag.attribute_definition_id, +// 'fqn', fqns_on_attr.fqn +// )) FILTER (WHERE attrkag.attribute_definition_id IS NOT NULL) AS attributes_grants, +// JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( +// 'id', valkag.attribute_value_id, +// 'fqn', fqns_on_vals.fqn +// )) FILTER (WHERE valkag.attribute_value_id IS NOT NULL) AS values_grants, +// JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( +// 'id', nskag.namespace_id, +// 'fqn', fqns_on_ns.fqn +// )) FILTER (WHERE nskag.namespace_id IS NOT NULL) AS namespace_grants +// FROM +// key_access_servers kas +// LEFT JOIN +// attribute_definition_key_access_grants attrkag +// ON kas.id = attrkag.key_access_server_id +// LEFT JOIN +// attribute_fqns fqns_on_attr +// ON attrkag.attribute_definition_id = fqns_on_attr.attribute_id +// AND fqns_on_attr.value_id IS NULL +// LEFT JOIN +// attribute_value_key_access_grants valkag +// ON kas.id = valkag.key_access_server_id +// LEFT JOIN +// attribute_fqns fqns_on_vals +// ON valkag.attribute_value_id = fqns_on_vals.value_id +// LEFT JOIN +// attribute_namespace_key_access_grants nskag +// ON kas.id = nskag.key_access_server_id +// LEFT JOIN +// attribute_fqns fqns_on_ns +// ON nskag.namespace_id = fqns_on_ns.namespace_id +// AND fqns_on_ns.attribute_id IS NULL AND fqns_on_ns.value_id IS NULL +// WHERE (NULLIF($3, '') IS NULL OR kas.id = $3::uuid) +// AND (NULLIF($4, '') IS NULL OR kas.uri = $4::varchar) +// AND (NULLIF($5, '') IS NULL OR kas.name = $5::varchar) +// GROUP BY +// kas.id +// ) // SELECT -// kas.id AS kas_id, -// kas.uri AS kas_uri, -// kas.name AS kas_name, -// kas.public_key AS kas_public_key, -// JSON_STRIP_NULLS(JSON_BUILD_OBJECT( -// 'labels', kas.metadata -> 'labels', -// 'created_at', kas.created_at, -// 'updated_at', kas.updated_at -// )) AS kas_metadata, -// JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( -// 'id', attrkag.attribute_definition_id, -// 'fqn', fqns_on_attr.fqn -// )) FILTER (WHERE attrkag.attribute_definition_id IS NOT NULL) AS attributes_grants, -// JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( -// 'id', valkag.attribute_value_id, -// 'fqn', fqns_on_vals.fqn -// )) FILTER (WHERE valkag.attribute_value_id IS NOT NULL) AS values_grants, -// JSON_AGG(DISTINCT JSONB_BUILD_OBJECT( -// 'id', nskag.namespace_id, -// 'fqn', fqns_on_ns.fqn -// )) FILTER (WHERE nskag.namespace_id IS NOT NULL) AS namespace_grants -// FROM -// key_access_servers kas -// LEFT JOIN -// attribute_definition_key_access_grants attrkag -// ON kas.id = attrkag.key_access_server_id -// LEFT JOIN -// attribute_fqns fqns_on_attr -// ON attrkag.attribute_definition_id = fqns_on_attr.attribute_id -// AND fqns_on_attr.value_id IS NULL -// LEFT JOIN -// attribute_value_key_access_grants valkag -// ON kas.id = valkag.key_access_server_id -// LEFT JOIN -// attribute_fqns fqns_on_vals -// ON valkag.attribute_value_id = fqns_on_vals.value_id -// LEFT JOIN -// attribute_namespace_key_access_grants nskag -// ON kas.id = nskag.key_access_server_id -// LEFT JOIN -// attribute_fqns fqns_on_ns -// ON nskag.namespace_id = fqns_on_ns.namespace_id -// AND fqns_on_ns.attribute_id IS NULL AND fqns_on_ns.value_id IS NULL -// WHERE (NULLIF($1, '') IS NULL OR kas.id = $1::uuid) -// AND (NULLIF($2, '') IS NULL OR kas.uri = $2::varchar) -// AND (NULLIF($3, '') IS NULL OR kas.name = $3::varchar) -// GROUP BY -// kas.id +// listed.kas_id, +// listed.kas_uri, +// listed.kas_name, +// listed.kas_public_key, +// listed.kas_metadata, +// listed.attributes_grants, +// listed.values_grants, +// listed.namespace_grants, +// listed.total +// FROM listed +// LIMIT $2 +// OFFSET $1 func (q *Queries) ListKeyAccessServerGrants(ctx context.Context, arg ListKeyAccessServerGrantsParams) ([]ListKeyAccessServerGrantsRow, error) { - rows, err := q.db.Query(ctx, listKeyAccessServerGrants, arg.KasID, arg.KasUri, arg.KasName) + rows, err := q.db.Query(ctx, listKeyAccessServerGrants, + arg.Offset, + arg.Limit, + arg.KasID, + arg.KasUri, + arg.KasName, + ) if err != nil { return nil, err } @@ -1621,6 +1732,7 @@ func (q *Queries) ListKeyAccessServerGrants(ctx context.Context, arg ListKeyAcce &i.AttributesGrants, &i.ValuesGrants, &i.NamespaceGrants, + &i.Total, ); err != nil { return nil, err } @@ -1633,26 +1745,54 @@ func (q *Queries) ListKeyAccessServerGrants(ctx context.Context, arg ListKeyAcce } const listKeyAccessServers = `-- name: ListKeyAccessServers :many -SELECT id, uri, public_key, name, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -FROM key_access_servers +WITH counted AS ( + SELECT COUNT(kas.id) AS total + FROM key_access_servers kas +) +SELECT kas.id, + kas.uri, + kas.public_key, + kas.name AS kas_name, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', kas.metadata -> 'labels', 'created_at', kas.created_at, 'updated_at', kas.updated_at)) as metadata, + counted.total +FROM key_access_servers kas +CROSS JOIN counted +LIMIT $2 +OFFSET $1 ` +type ListKeyAccessServersParams struct { + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + type ListKeyAccessServersRow struct { ID string `json:"id"` Uri string `json:"uri"` PublicKey []byte `json:"public_key"` - Name pgtype.Text `json:"name"` + KasName pgtype.Text `json:"kas_name"` Metadata []byte `json:"metadata"` + Total int64 `json:"total"` } // ListKeyAccessServers // -// SELECT id, uri, public_key, name, -// JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -// FROM key_access_servers -func (q *Queries) ListKeyAccessServers(ctx context.Context) ([]ListKeyAccessServersRow, error) { - rows, err := q.db.Query(ctx, listKeyAccessServers) +// WITH counted AS ( +// SELECT COUNT(kas.id) AS total +// FROM key_access_servers kas +// ) +// SELECT kas.id, +// kas.uri, +// kas.public_key, +// kas.name AS kas_name, +// JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', kas.metadata -> 'labels', 'created_at', kas.created_at, 'updated_at', kas.updated_at)) as metadata, +// counted.total +// FROM key_access_servers kas +// CROSS JOIN counted +// LIMIT $2 +// OFFSET $1 +func (q *Queries) ListKeyAccessServers(ctx context.Context, arg ListKeyAccessServersParams) ([]ListKeyAccessServersRow, error) { + rows, err := q.db.Query(ctx, listKeyAccessServers, arg.Offset, arg.Limit) if err != nil { return nil, err } @@ -1664,8 +1804,9 @@ func (q *Queries) ListKeyAccessServers(ctx context.Context) ([]ListKeyAccessServ &i.ID, &i.Uri, &i.PublicKey, - &i.Name, + &i.KasName, &i.Metadata, + &i.Total, ); err != nil { return nil, err } @@ -1679,40 +1820,61 @@ func (q *Queries) ListKeyAccessServers(ctx context.Context) ([]ListKeyAccessServ const listNamespaces = `-- name: ListNamespaces :many +WITH counted AS ( + SELECT COUNT(id) AS total FROM attribute_namespaces +) SELECT ns.id, ns.name, ns.active, JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', ns.metadata -> 'labels', 'created_at', ns.created_at, 'updated_at', ns.updated_at)) as metadata, - fqns.fqn + fqns.fqn, + counted.total FROM attribute_namespaces ns +CROSS JOIN counted LEFT JOIN attribute_fqns fqns ON ns.id = fqns.namespace_id AND fqns.attribute_id IS NULL WHERE ($1::BOOLEAN IS NULL OR ns.active = $1::BOOLEAN) +LIMIT $3 +OFFSET $2 ` +type ListNamespacesParams struct { + Active pgtype.Bool `json:"active"` + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + type ListNamespacesRow struct { ID string `json:"id"` Name string `json:"name"` Active bool `json:"active"` Metadata []byte `json:"metadata"` Fqn pgtype.Text `json:"fqn"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // NAMESPACES // -------------------------------------------------------------- // +// WITH counted AS ( +// SELECT COUNT(id) AS total FROM attribute_namespaces +// ) // SELECT // ns.id, // ns.name, // ns.active, // JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', ns.metadata -> 'labels', 'created_at', ns.created_at, 'updated_at', ns.updated_at)) as metadata, -// fqns.fqn +// fqns.fqn, +// counted.total // FROM attribute_namespaces ns +// CROSS JOIN counted // LEFT JOIN attribute_fqns fqns ON ns.id = fqns.namespace_id AND fqns.attribute_id IS NULL // WHERE ($1::BOOLEAN IS NULL OR ns.active = $1::BOOLEAN) -func (q *Queries) ListNamespaces(ctx context.Context, active pgtype.Bool) ([]ListNamespacesRow, error) { - rows, err := q.db.Query(ctx, listNamespaces, active) +// LIMIT $3 +// OFFSET $2 +func (q *Queries) ListNamespaces(ctx context.Context, arg ListNamespacesParams) ([]ListNamespacesRow, error) { + rows, err := q.db.Query(ctx, listNamespaces, arg.Active, arg.Offset, arg.Limit) if err != nil { return nil, err } @@ -1726,6 +1888,7 @@ func (q *Queries) ListNamespaces(ctx context.Context, active pgtype.Bool) ([]Lis &i.Active, &i.Metadata, &i.Fqn, + &i.Total, ); err != nil { return nil, err } @@ -1739,29 +1902,56 @@ func (q *Queries) ListNamespaces(ctx context.Context, active pgtype.Bool) ([]Lis const listResourceMappingGroups = `-- name: ListResourceMappingGroups :many -SELECT id, namespace_id, name, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -FROM resource_mapping_groups -WHERE (NULLIF($1, '') IS NULL OR namespace_id = $1::uuid) +WITH counted AS ( + SELECT COUNT(rmg.id) AS total + FROM resource_mapping_groups rmg +) +SELECT rmg.id, + rmg.namespace_id, + rmg.name, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', rmg.metadata -> 'labels', 'created_at', rmg.created_at, 'updated_at', rmg.updated_at)) as metadata, + counted.total +FROM resource_mapping_groups rmg +CROSS JOIN counted +WHERE (NULLIF($1, '') IS NULL OR rmg.namespace_id = $1::uuid) +LIMIT $3 +OFFSET $2 ` +type ListResourceMappingGroupsParams struct { + NamespaceID interface{} `json:"namespace_id"` + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + type ListResourceMappingGroupsRow struct { ID string `json:"id"` NamespaceID string `json:"namespace_id"` Name string `json:"name"` Metadata []byte `json:"metadata"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // RESOURCE MAPPING GROUPS // -------------------------------------------------------------- // -// SELECT id, namespace_id, name, -// JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -// FROM resource_mapping_groups -// WHERE (NULLIF($1, '') IS NULL OR namespace_id = $1::uuid) -func (q *Queries) ListResourceMappingGroups(ctx context.Context, namespaceID interface{}) ([]ListResourceMappingGroupsRow, error) { - rows, err := q.db.Query(ctx, listResourceMappingGroups, namespaceID) +// WITH counted AS ( +// SELECT COUNT(rmg.id) AS total +// FROM resource_mapping_groups rmg +// ) +// SELECT rmg.id, +// rmg.namespace_id, +// rmg.name, +// JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', rmg.metadata -> 'labels', 'created_at', rmg.created_at, 'updated_at', rmg.updated_at)) as metadata, +// counted.total +// FROM resource_mapping_groups rmg +// CROSS JOIN counted +// WHERE (NULLIF($1, '') IS NULL OR rmg.namespace_id = $1::uuid) +// LIMIT $3 +// OFFSET $2 +func (q *Queries) ListResourceMappingGroups(ctx context.Context, arg ListResourceMappingGroupsParams) ([]ListResourceMappingGroupsRow, error) { + rows, err := q.db.Query(ctx, listResourceMappingGroups, arg.NamespaceID, arg.Offset, arg.Limit) if err != nil { return nil, err } @@ -1774,6 +1964,7 @@ func (q *Queries) ListResourceMappingGroups(ctx context.Context, namespaceID int &i.NamespaceID, &i.Name, &i.Metadata, + &i.Total, ); err != nil { return nil, err } @@ -1787,44 +1978,67 @@ func (q *Queries) ListResourceMappingGroups(ctx context.Context, namespaceID int const listResourceMappings = `-- name: ListResourceMappings :many +WITH counted AS ( + SELECT COUNT(rm.id) AS total + FROM resource_mappings rm +) SELECT m.id, JSON_BUILD_OBJECT('id', av.id, 'value', av.value, 'fqn', fqns.fqn) as attribute_value, m.terms, JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', m.metadata -> 'labels', 'created_at', m.created_at, 'updated_at', m.updated_at)) as metadata, - COALESCE(m.group_id::TEXT, '')::TEXT as group_id + COALESCE(m.group_id::TEXT, '')::TEXT as group_id, + counted.total FROM resource_mappings m +CROSS JOIN counted LEFT JOIN attribute_values av on m.attribute_value_id = av.id LEFT JOIN attribute_fqns fqns on av.id = fqns.value_id WHERE (NULLIF($1, '') IS NULL OR m.group_id = $1::UUID) -GROUP BY av.id, m.id, fqns.fqn +GROUP BY av.id, m.id, fqns.fqn, counted.total +LIMIT $3 +OFFSET $2 ` +type ListResourceMappingsParams struct { + GroupID interface{} `json:"group_id"` + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + type ListResourceMappingsRow struct { ID string `json:"id"` AttributeValue []byte `json:"attribute_value"` Terms []string `json:"terms"` Metadata []byte `json:"metadata"` GroupID string `json:"group_id"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // RESOURCE MAPPING // -------------------------------------------------------------- // +// WITH counted AS ( +// SELECT COUNT(rm.id) AS total +// FROM resource_mappings rm +// ) // SELECT // m.id, // JSON_BUILD_OBJECT('id', av.id, 'value', av.value, 'fqn', fqns.fqn) as attribute_value, // m.terms, // JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', m.metadata -> 'labels', 'created_at', m.created_at, 'updated_at', m.updated_at)) as metadata, -// COALESCE(m.group_id::TEXT, '')::TEXT as group_id +// COALESCE(m.group_id::TEXT, '')::TEXT as group_id, +// counted.total // FROM resource_mappings m +// CROSS JOIN counted // LEFT JOIN attribute_values av on m.attribute_value_id = av.id // LEFT JOIN attribute_fqns fqns on av.id = fqns.value_id // WHERE (NULLIF($1, '') IS NULL OR m.group_id = $1::UUID) -// GROUP BY av.id, m.id, fqns.fqn -func (q *Queries) ListResourceMappings(ctx context.Context, groupID interface{}) ([]ListResourceMappingsRow, error) { - rows, err := q.db.Query(ctx, listResourceMappings, groupID) +// GROUP BY av.id, m.id, fqns.fqn, counted.total +// LIMIT $3 +// OFFSET $2 +func (q *Queries) ListResourceMappings(ctx context.Context, arg ListResourceMappingsParams) ([]ListResourceMappingsRow, error) { + rows, err := q.db.Query(ctx, listResourceMappings, arg.GroupID, arg.Offset, arg.Limit) if err != nil { return nil, err } @@ -1838,6 +2052,7 @@ func (q *Queries) ListResourceMappings(ctx context.Context, groupID interface{}) &i.Terms, &i.Metadata, &i.GroupID, + &i.Total, ); err != nil { return nil, err } @@ -1949,30 +2164,52 @@ func (q *Queries) ListResourceMappingsByFullyQualifiedGroup(ctx context.Context, const listSubjectConditionSets = `-- name: ListSubjectConditionSets :many +WITH counted AS ( + SELECT COUNT(scs.id) AS total + FROM subject_condition_set scs +) SELECT - id, - condition, - JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -FROM subject_condition_set + scs.id, + scs.condition, + JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', scs.metadata -> 'labels', 'created_at', scs.created_at, 'updated_at', scs.updated_at)) as metadata, + counted.total +FROM subject_condition_set scs +CROSS JOIN counted +LIMIT $2 +OFFSET $1 ` +type ListSubjectConditionSetsParams struct { + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + type ListSubjectConditionSetsRow struct { ID string `json:"id"` Condition []byte `json:"condition"` Metadata []byte `json:"metadata"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // SUBJECT CONDITION SETS // -------------------------------------------------------------- // +// WITH counted AS ( +// SELECT COUNT(scs.id) AS total +// FROM subject_condition_set scs +// ) // SELECT -// id, -// condition, -// JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', metadata -> 'labels', 'created_at', created_at, 'updated_at', updated_at)) as metadata -// FROM subject_condition_set -func (q *Queries) ListSubjectConditionSets(ctx context.Context) ([]ListSubjectConditionSetsRow, error) { - rows, err := q.db.Query(ctx, listSubjectConditionSets) +// scs.id, +// scs.condition, +// JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', scs.metadata -> 'labels', 'created_at', scs.created_at, 'updated_at', scs.updated_at)) as metadata, +// counted.total +// FROM subject_condition_set scs +// CROSS JOIN counted +// LIMIT $2 +// OFFSET $1 +func (q *Queries) ListSubjectConditionSets(ctx context.Context, arg ListSubjectConditionSetsParams) ([]ListSubjectConditionSetsRow, error) { + rows, err := q.db.Query(ctx, listSubjectConditionSets, arg.Offset, arg.Limit) if err != nil { return nil, err } @@ -1980,7 +2217,12 @@ func (q *Queries) ListSubjectConditionSets(ctx context.Context) ([]ListSubjectCo var items []ListSubjectConditionSetsRow for rows.Next() { var i ListSubjectConditionSetsRow - if err := rows.Scan(&i.ID, &i.Condition, &i.Metadata); err != nil { + if err := rows.Scan( + &i.ID, + &i.Condition, + &i.Metadata, + &i.Total, + ); err != nil { return nil, err } items = append(items, i) @@ -1993,6 +2235,10 @@ func (q *Queries) ListSubjectConditionSets(ctx context.Context) ([]ListSubjectCo const listSubjectMappings = `-- name: ListSubjectMappings :many +WITH counted AS ( + SELECT COUNT(sm.id) AS total + FROM subject_mappings sm +) SELECT sm.id, sm.actions, @@ -2002,25 +2248,39 @@ SELECT 'metadata', JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', scs.metadata->'labels', 'created_at', scs.created_at, 'updated_at', scs.updated_at)), 'subject_sets', scs.condition ) AS subject_condition_set, - JSON_BUILD_OBJECT('id', av.id,'value', av.value,'active', av.active) AS attribute_value + JSON_BUILD_OBJECT('id', av.id,'value', av.value,'active', av.active) AS attribute_value, + counted.total FROM subject_mappings sm +CROSS JOIN counted LEFT JOIN attribute_values av ON sm.attribute_value_id = av.id LEFT JOIN subject_condition_set scs ON scs.id = sm.subject_condition_set_id -GROUP BY av.id, sm.id, scs.id +GROUP BY av.id, sm.id, scs.id, counted.total +LIMIT $2 +OFFSET $1 ` +type ListSubjectMappingsParams struct { + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + type ListSubjectMappingsRow struct { ID string `json:"id"` Actions []byte `json:"actions"` Metadata []byte `json:"metadata"` SubjectConditionSet []byte `json:"subject_condition_set"` AttributeValue []byte `json:"attribute_value"` + Total int64 `json:"total"` } // -------------------------------------------------------------- // SUBJECT MAPPINGS // -------------------------------------------------------------- // +// WITH counted AS ( +// SELECT COUNT(sm.id) AS total +// FROM subject_mappings sm +// ) // SELECT // sm.id, // sm.actions, @@ -2030,13 +2290,17 @@ type ListSubjectMappingsRow struct { // 'metadata', JSON_STRIP_NULLS(JSON_BUILD_OBJECT('labels', scs.metadata->'labels', 'created_at', scs.created_at, 'updated_at', scs.updated_at)), // 'subject_sets', scs.condition // ) AS subject_condition_set, -// JSON_BUILD_OBJECT('id', av.id,'value', av.value,'active', av.active) AS attribute_value +// JSON_BUILD_OBJECT('id', av.id,'value', av.value,'active', av.active) AS attribute_value, +// counted.total // FROM subject_mappings sm +// CROSS JOIN counted // LEFT JOIN attribute_values av ON sm.attribute_value_id = av.id // LEFT JOIN subject_condition_set scs ON scs.id = sm.subject_condition_set_id -// GROUP BY av.id, sm.id, scs.id -func (q *Queries) ListSubjectMappings(ctx context.Context) ([]ListSubjectMappingsRow, error) { - rows, err := q.db.Query(ctx, listSubjectMappings) +// GROUP BY av.id, sm.id, scs.id, counted.total +// LIMIT $2 +// OFFSET $1 +func (q *Queries) ListSubjectMappings(ctx context.Context, arg ListSubjectMappingsParams) ([]ListSubjectMappingsRow, error) { + rows, err := q.db.Query(ctx, listSubjectMappings, arg.Offset, arg.Limit) if err != nil { return nil, err } @@ -2050,6 +2314,7 @@ func (q *Queries) ListSubjectMappings(ctx context.Context) ([]ListSubjectMapping &i.Metadata, &i.SubjectConditionSet, &i.AttributeValue, + &i.Total, ); err != nil { return nil, err } diff --git a/service/policy/db/resource_mapping.go b/service/policy/db/resource_mapping.go index b8d8c0f7a..210aea57f 100644 --- a/service/policy/db/resource_mapping.go +++ b/service/policy/db/resource_mapping.go @@ -17,8 +17,19 @@ import ( Resource Mapping CRUD */ -func (c PolicyDBClient) ListResourceMappingGroups(ctx context.Context, r *resourcemapping.ListResourceMappingGroupsRequest) ([]*policy.ResourceMappingGroup, error) { - list, err := c.Queries.ListResourceMappingGroups(ctx, r.GetNamespaceId()) +func (c PolicyDBClient) ListResourceMappingGroups(ctx context.Context, r *resourcemapping.ListResourceMappingGroupsRequest) (*resourcemapping.ListResourceMappingGroupsResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + + list, err := c.Queries.ListResourceMappingGroups(ctx, ListResourceMappingGroupsParams{ + NamespaceID: r.GetNamespaceId(), + Limit: limit, + Offset: offset, + }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) } @@ -39,7 +50,21 @@ func (c PolicyDBClient) ListResourceMappingGroups(ctx context.Context, r *resour } } - return rmGroups, nil + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + + return &resourcemapping.ListResourceMappingGroupsResponse{ + ResourceMappingGroups: rmGroups, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } func (c PolicyDBClient) GetResourceMappingGroup(ctx context.Context, id string) (*policy.ResourceMappingGroup, error) { @@ -141,8 +166,19 @@ func (c PolicyDBClient) DeleteResourceMappingGroup(ctx context.Context, id strin Resource Mapping CRUD */ -func (c PolicyDBClient) ListResourceMappings(ctx context.Context, r *resourcemapping.ListResourceMappingsRequest) ([]*policy.ResourceMapping, error) { - list, err := c.Queries.ListResourceMappings(ctx, r.GetGroupId()) +func (c PolicyDBClient) ListResourceMappings(ctx context.Context, r *resourcemapping.ListResourceMappingsRequest) (*resourcemapping.ListResourceMappingsResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + + list, err := c.Queries.ListResourceMappings(ctx, ListResourceMappingsParams{ + GroupID: r.GetGroupId(), + Limit: limit, + Offset: offset, + }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) } @@ -177,7 +213,21 @@ func (c PolicyDBClient) ListResourceMappings(ctx context.Context, r *resourcemap mappings[i] = mapping } - return mappings, nil + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + + return &resourcemapping.ListResourceMappingsResponse{ + ResourceMappings: mappings, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } func (c PolicyDBClient) ListResourceMappingsByGroupFqns(ctx context.Context, fqns []string) (map[string]*resourcemapping.ResourceMappingsByGroup, error) { diff --git a/service/policy/db/subject_mappings.go b/service/policy/db/subject_mappings.go index 10257e842..15b8f392a 100644 --- a/service/policy/db/subject_mappings.go +++ b/service/policy/db/subject_mappings.go @@ -136,8 +136,18 @@ func (c PolicyDBClient) GetSubjectConditionSet(ctx context.Context, id string) ( }, nil } -func (c PolicyDBClient) ListSubjectConditionSets(ctx context.Context) ([]*policy.SubjectConditionSet, error) { - list, err := c.Queries.ListSubjectConditionSets(ctx) +func (c PolicyDBClient) ListSubjectConditionSets(ctx context.Context, r *subjectmapping.ListSubjectConditionSetsRequest) (*subjectmapping.ListSubjectConditionSetsResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + + list, err := c.Queries.ListSubjectConditionSets(ctx, ListSubjectConditionSetsParams{ + Limit: limit, + Offset: offset, + }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) } @@ -161,7 +171,21 @@ func (c PolicyDBClient) ListSubjectConditionSets(ctx context.Context) ([]*policy } } - return setList, nil + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + + return &subjectmapping.ListSubjectConditionSetsResponse{ + SubjectConditionSets: setList, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } // Mutates provided fields and returns the updated subject condition set @@ -339,8 +363,18 @@ func (c PolicyDBClient) GetSubjectMapping(ctx context.Context, id string) (*poli }, nil } -func (c PolicyDBClient) ListSubjectMappings(ctx context.Context) ([]*policy.SubjectMapping, error) { - list, err := c.Queries.ListSubjectMappings(ctx) +func (c PolicyDBClient) ListSubjectMappings(ctx context.Context, r *subjectmapping.ListSubjectMappingsRequest) (*subjectmapping.ListSubjectMappingsResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, db.ErrListLimitTooLarge + } + + list, err := c.Queries.ListSubjectMappings(ctx, ListSubjectMappingsParams{ + Limit: limit, + Offset: offset, + }) if err != nil { return nil, db.WrapIfKnownInvalidQueryErr(err) } @@ -376,7 +410,21 @@ func (c PolicyDBClient) ListSubjectMappings(ctx context.Context) ([]*policy.Subj } } - return mappings, nil + var total int32 + var nextOffset int32 + if len(list) > 0 { + total = int32(list[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + + return &subjectmapping.ListSubjectMappingsResponse{ + SubjectMappings: mappings, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + }, nil } // Mutates provided fields and returns the updated subject mapping diff --git a/service/policy/db/utils.go b/service/policy/db/utils.go index 30864bb90..c0f4784c1 100644 --- a/service/policy/db/utils.go +++ b/service/policy/db/utils.go @@ -10,6 +10,27 @@ import ( "google.golang.org/protobuf/encoding/protojson" ) +// Gathers request pagination limit/offset or configured default +func (c PolicyDBClient) getRequestedLimitOffset(page *policy.PageRequest) (int32, int32) { + return getListLimit(page.GetLimit(), c.listCfg.limitDefault), page.GetOffset() +} + +func getListLimit(limit int32, fallback int32) int32 { + if limit > 0 { + return limit + } + return fallback +} + +// Returns next page's offset if has not yet reached total, or else returns 0 +func getNextOffset(currentOffset, limit, total int32) int32 { + next := currentOffset + limit + if next < total { + return next + } + return 0 +} + func unmarshalMetadata(metadataJSON []byte, m *common.Metadata) error { if metadataJSON != nil { if err := protojson.Unmarshal(metadataJSON, m); err != nil { diff --git a/service/policy/db/utils_test.go b/service/policy/db/utils_test.go new file mode 100644 index 000000000..8cc720a42 --- /dev/null +++ b/service/policy/db/utils_test.go @@ -0,0 +1,92 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_GetListLimit(t *testing.T) { + var defaultListLimit int32 = 1000 + cases := []struct { + limit int32 + expected int32 + }{ + { + 0, + 1000, + }, + { + 1, + 1, + }, + { + 10000, + 10000, + }, + } + + for _, test := range cases { + result := getListLimit(test.limit, defaultListLimit) + assert.Equal(t, test.expected, result) + } +} + +func Test_GetNextOffset(t *testing.T) { + var defaultTestListLimit int32 = 250 + cases := []struct { + currOffset int32 + limit int32 + total int32 + expected int32 + scenario string + }{ + { + currOffset: 0, + limit: defaultTestListLimit, + total: 1000, + expected: defaultTestListLimit, + scenario: "defaulted limit with many remaining", + }, + { + currOffset: 100, + limit: 100, + total: 1000, + expected: 200, + scenario: "custom limit with many remaining", + }, + { + currOffset: 100, + limit: 100, + total: 200, + expected: 0, + scenario: "custom limit with none remaining", + }, + { + currOffset: 100, + limit: defaultTestListLimit, + total: 200, + expected: 0, + scenario: "default limit with none remaining", + }, + { + currOffset: 350 - defaultTestListLimit - 1, + limit: defaultTestListLimit, + total: 350, + expected: 349, + scenario: "default limit with exactly one remaining", + }, + { + currOffset: 1000 - 500 - 1, + limit: 500, + total: 1000, + expected: 1000 - 1, + scenario: "custom limit with exactly one remaining", + }, + } + + for _, test := range cases { + result := getNextOffset(test.currOffset, test.limit, test.total) + assert.Equal(t, test.expected, result, test.scenario) + } +} diff --git a/service/policy/kasregistry/key_access_server_registry.go b/service/policy/kasregistry/key_access_server_registry.go index 088996f21..b42922a9f 100644 --- a/service/policy/kasregistry/key_access_server_registry.go +++ b/service/policy/kasregistry/key_access_server_registry.go @@ -12,12 +12,15 @@ import ( "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/opentdf/platform/service/pkg/serviceregistry" + policyconfig "github.com/opentdf/platform/service/policy/config" + policydb "github.com/opentdf/platform/service/policy/db" ) type KeyAccessServerRegistry struct { dbClient policydb.PolicyDBClient logger *logger.Logger + config *policyconfig.Config } func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[kasregistryconnect.KeyAccessServerRegistryServiceHandler] { @@ -29,8 +32,12 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer ConnectRPCFunc: kasregistryconnect.NewKeyAccessServerRegistryServiceHandler, GRPCGateayFunc: kasr.RegisterKeyAccessServerRegistryServiceHandlerFromEndpoint, RegisterFunc: func(srp serviceregistry.RegistrationParams) (kasregistryconnect.KeyAccessServerRegistryServiceHandler, serviceregistry.HandlerServer) { - ksr := &KeyAccessServerRegistry{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} - return ksr, nil + cfg := policyconfig.GetSharedPolicyConfig(srp) + return &KeyAccessServerRegistry{ + dbClient: policydb.NewClient(srp.DBClient, srp.Logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)), + logger: srp.Logger, + config: cfg, + }, nil }, }, } @@ -64,17 +71,13 @@ func (s KeyAccessServerRegistry) CreateKeyAccessServer(ctx context.Context, } func (s KeyAccessServerRegistry) ListKeyAccessServers(ctx context.Context, - _ *connect.Request[kasr.ListKeyAccessServersRequest], + req *connect.Request[kasr.ListKeyAccessServersRequest], ) (*connect.Response[kasr.ListKeyAccessServersResponse], error) { - rsp := &kasr.ListKeyAccessServersResponse{} - - keyAccessServers, err := s.dbClient.ListKeyAccessServers(ctx) + rsp, err := s.dbClient.ListKeyAccessServers(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } - rsp.KeyAccessServers = keyAccessServers - return connect.NewResponse(rsp), nil } @@ -158,14 +161,10 @@ func (s KeyAccessServerRegistry) DeleteKeyAccessServer(ctx context.Context, func (s KeyAccessServerRegistry) ListKeyAccessServerGrants(ctx context.Context, req *connect.Request[kasr.ListKeyAccessServerGrantsRequest], ) (*connect.Response[kasr.ListKeyAccessServerGrantsResponse], error) { - rsp := &kasr.ListKeyAccessServerGrantsResponse{} - - keyAccessServerGrants, err := s.dbClient.ListKeyAccessServerGrants(ctx, req.Msg.GetKasId(), req.Msg.GetKasUri(), req.Msg.GetKasName()) + rsp, err := s.dbClient.ListKeyAccessServerGrants(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } - rsp.Grants = keyAccessServerGrants - return connect.NewResponse(rsp), nil } diff --git a/service/policy/namespaces/namespaces.go b/service/policy/namespaces/namespaces.go index 37cb53970..7ae3d7181 100644 --- a/service/policy/namespaces/namespaces.go +++ b/service/policy/namespaces/namespaces.go @@ -13,12 +13,14 @@ import ( "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/opentdf/platform/service/pkg/serviceregistry" + policyconfig "github.com/opentdf/platform/service/policy/config" policydb "github.com/opentdf/platform/service/policy/db" ) type NamespacesService struct { //nolint:revive // NamespacesService is a valid name dbClient policydb.PolicyDBClient logger *logger.Logger + config *policyconfig.Config } func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[namespacesconnect.NamespaceServiceHandler] { @@ -30,7 +32,12 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer ConnectRPCFunc: namespacesconnect.NewNamespaceServiceHandler, GRPCGateayFunc: namespaces.RegisterNamespaceServiceHandlerFromEndpoint, RegisterFunc: func(srp serviceregistry.RegistrationParams) (namespacesconnect.NamespaceServiceHandler, serviceregistry.HandlerServer) { - ns := &NamespacesService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} + cfg := policyconfig.GetSharedPolicyConfig(srp) + ns := &NamespacesService{ + dbClient: policydb.NewClient(srp.DBClient, srp.Logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)), + logger: srp.Logger, + config: cfg, + } if err := srp.RegisterReadinessCheck("policy", ns.IsReady); err != nil { srp.Logger.Error("failed to register policy readiness check", slog.String("error", err.Error())) @@ -54,17 +61,15 @@ func (ns NamespacesService) IsReady(ctx context.Context) error { } func (ns NamespacesService) ListNamespaces(ctx context.Context, req *connect.Request[namespaces.ListNamespacesRequest]) (*connect.Response[namespaces.ListNamespacesResponse], error) { - state := policydb.GetDBStateTypeTransformedEnum(req.Msg.GetState()) + state := req.Msg.GetState().String() ns.logger.Debug("listing namespaces", slog.String("state", state)) - rsp := &namespaces.ListNamespacesResponse{} - list, err := ns.dbClient.ListNamespaces(ctx, state) + rsp, err := ns.dbClient.ListNamespaces(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } ns.logger.Debug("listed namespaces") - rsp.Namespaces = list return connect.NewResponse(rsp), nil } diff --git a/service/policy/resourcemapping/resource_mapping.go b/service/policy/resourcemapping/resource_mapping.go index ba41489ce..ec385e107 100644 --- a/service/policy/resourcemapping/resource_mapping.go +++ b/service/policy/resourcemapping/resource_mapping.go @@ -12,12 +12,14 @@ import ( "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/opentdf/platform/service/pkg/serviceregistry" + policyconfig "github.com/opentdf/platform/service/policy/config" policydb "github.com/opentdf/platform/service/policy/db" ) type ResourceMappingService struct { //nolint:revive // ResourceMappingService is a valid name for this struct dbClient policydb.PolicyDBClient logger *logger.Logger + config *policyconfig.Config } func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[resourcemappingconnect.ResourceMappingServiceHandler] { @@ -29,8 +31,12 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer ConnectRPCFunc: resourcemappingconnect.NewResourceMappingServiceHandler, GRPCGateayFunc: resourcemapping.RegisterResourceMappingServiceHandlerFromEndpoint, RegisterFunc: func(srp serviceregistry.RegistrationParams) (resourcemappingconnect.ResourceMappingServiceHandler, serviceregistry.HandlerServer) { - rm := &ResourceMappingService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} - return rm, nil + cfg := policyconfig.GetSharedPolicyConfig(srp) + return &ResourceMappingService{ + dbClient: policydb.NewClient(srp.DBClient, srp.Logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)), + logger: srp.Logger, + config: cfg, + }, nil }, }, } @@ -41,15 +47,11 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer */ func (s ResourceMappingService) ListResourceMappingGroups(ctx context.Context, req *connect.Request[resourcemapping.ListResourceMappingGroupsRequest]) (*connect.Response[resourcemapping.ListResourceMappingGroupsResponse], error) { - rsp := &resourcemapping.ListResourceMappingGroupsResponse{} - - rmGroups, err := s.dbClient.ListResourceMappingGroups(ctx, req.Msg) + rsp, err := s.dbClient.ListResourceMappingGroups(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } - rsp.ResourceMappingGroups = rmGroups - return connect.NewResponse(rsp), nil } @@ -157,15 +159,11 @@ func (s ResourceMappingService) DeleteResourceMappingGroup(ctx context.Context, func (s ResourceMappingService) ListResourceMappings(ctx context.Context, req *connect.Request[resourcemapping.ListResourceMappingsRequest], ) (*connect.Response[resourcemapping.ListResourceMappingsResponse], error) { - rsp := &resourcemapping.ListResourceMappingsResponse{} - - resourceMappings, err := s.dbClient.ListResourceMappings(ctx, req.Msg) + rsp, err := s.dbClient.ListResourceMappings(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } - rsp.ResourceMappings = resourceMappings - return connect.NewResponse(rsp), nil } diff --git a/service/policy/subjectmapping/subject_mapping.go b/service/policy/subjectmapping/subject_mapping.go index d089fe290..a2cc54090 100644 --- a/service/policy/subjectmapping/subject_mapping.go +++ b/service/policy/subjectmapping/subject_mapping.go @@ -12,12 +12,14 @@ import ( "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/opentdf/platform/service/pkg/serviceregistry" + policyconfig "github.com/opentdf/platform/service/policy/config" policydb "github.com/opentdf/platform/service/policy/db" ) type SubjectMappingService struct { //nolint:revive // SubjectMappingService is a valid name for this struct dbClient policydb.PolicyDBClient logger *logger.Logger + config *policyconfig.Config } func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[subjectmappingconnect.SubjectMappingServiceHandler] { @@ -29,8 +31,12 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer ConnectRPCFunc: subjectmappingconnect.NewSubjectMappingServiceHandler, GRPCGateayFunc: sm.RegisterSubjectMappingServiceHandlerFromEndpoint, RegisterFunc: func(srp serviceregistry.RegistrationParams) (subjectmappingconnect.SubjectMappingServiceHandler, serviceregistry.HandlerServer) { - smSvc := &SubjectMappingService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} - return smSvc, nil + cfg := policyconfig.GetSharedPolicyConfig(srp) + return &SubjectMappingService{ + dbClient: policydb.NewClient(srp.DBClient, srp.Logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)), + logger: srp.Logger, + config: cfg, + }, nil }, }, } @@ -66,17 +72,15 @@ func (s SubjectMappingService) CreateSubjectMapping(ctx context.Context, } func (s SubjectMappingService) ListSubjectMappings(ctx context.Context, - _ *connect.Request[sm.ListSubjectMappingsRequest], + req *connect.Request[sm.ListSubjectMappingsRequest], ) (*connect.Response[sm.ListSubjectMappingsResponse], error) { - rsp := &sm.ListSubjectMappingsResponse{} s.logger.Debug("listing subject mappings") - mappings, err := s.dbClient.ListSubjectMappings(ctx) + rsp, err := s.dbClient.ListSubjectMappings(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } - rsp.SubjectMappings = mappings return connect.NewResponse(rsp), nil } @@ -193,17 +197,15 @@ func (s SubjectMappingService) GetSubjectConditionSet(ctx context.Context, } func (s SubjectMappingService) ListSubjectConditionSets(ctx context.Context, - _ *connect.Request[sm.ListSubjectConditionSetsRequest], + req *connect.Request[sm.ListSubjectConditionSetsRequest], ) (*connect.Response[sm.ListSubjectConditionSetsResponse], error) { - rsp := &sm.ListSubjectConditionSetsResponse{} s.logger.Debug("listing subject condition sets") - conditionSets, err := s.dbClient.ListSubjectConditionSets(ctx) + rsp, err := s.dbClient.ListSubjectConditionSets(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) } - rsp.SubjectConditionSets = conditionSets return connect.NewResponse(rsp), nil } diff --git a/service/policy/unsafe/unsafe.go b/service/policy/unsafe/unsafe.go index 065866b0b..8c78cfa0c 100644 --- a/service/policy/unsafe/unsafe.go +++ b/service/policy/unsafe/unsafe.go @@ -12,12 +12,14 @@ import ( "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/opentdf/platform/service/pkg/serviceregistry" + policyconfig "github.com/opentdf/platform/service/policy/config" policydb "github.com/opentdf/platform/service/policy/db" ) type UnsafeService struct { //nolint:revive // UnsafeService is a valid name for this struct dbClient policydb.PolicyDBClient logger *logger.Logger + config *policyconfig.Config } func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[unsafeconnect.UnsafeServiceHandler] { @@ -28,8 +30,12 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer ServiceDesc: &unsafe.UnsafeService_ServiceDesc, ConnectRPCFunc: unsafeconnect.NewUnsafeServiceHandler, RegisterFunc: func(srp serviceregistry.RegistrationParams) (unsafeconnect.UnsafeServiceHandler, serviceregistry.HandlerServer) { - unsafeSvc := &UnsafeService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} - return unsafeSvc, nil + cfg := policyconfig.GetSharedPolicyConfig(srp) + return &UnsafeService{ + dbClient: policydb.NewClient(srp.DBClient, srp.Logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)), + logger: srp.Logger, + config: cfg, + }, nil }, }, }