Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): remove hex encoding for segment hash #1805

Merged
merged 23 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ae8144a
feat(kas): remove hex encoding for segment hash
sujankota Dec 2, 2024
27daabe
fix the tests
sujankota Dec 2, 2024
81af2b1
Merge branch 'main' into feat/remove-hex-encoding-tdf3
sujankota Dec 2, 2024
8f17eff
fix assertion hash
sujankota Dec 7, 2024
ddda847
handle assertion for legacy tdfs
sujankota Dec 19, 2024
f7171d9
Merge branch 'main' into feat/remove-hex-encoding-tdf3
sujankota Jan 2, 2025
4cf0aba
Update the intial version of the sdk to 4.3.0
sujankota Jan 2, 2025
ae6ce48
Merge branch 'main' into feat/remove-hex-encoding-tdf3
sujankota Jan 7, 2025
ac4073d
Address code review comments
sujankota Jan 7, 2025
e0ba076
add option to add build meta data
sujankota Jan 8, 2025
aec5fb1
Add schema version to key access object
sujankota Jan 12, 2025
a66fc8f
Merge branch 'main' into feat/remove-hex-encoding-tdf3
sujankota Jan 12, 2025
14025e9
Remove the build meta version fromt his PR
sujankota Jan 16, 2025
4a8689e
Merge branch 'main' into feat/remove-hex-encoding-tdf3
sujankota Jan 16, 2025
51591fd
Update version.go
dmihalcik-virtru Jan 17, 2025
16392a2
Merge branch 'main' into feat/remove-hex-encoding-tdf3
dmihalcik-virtru Jan 17, 2025
30446d0
Merge branch 'main' into feat/remove-hex-encoding-tdf3
dmihalcik-virtru Jan 17, 2025
edf2f18
Update tdf_test.go
dmihalcik-virtru Jan 21, 2025
1601efc
Merge branch 'main' into feat/remove-hex-encoding-tdf3
dmihalcik-virtru Jan 21, 2025
327b117
Update version.go
dmihalcik-virtru Jan 21, 2025
feccdaa
Merge branch 'feat/remove-hex-encoding-tdf3' of https://github.com/op…
dmihalcik-virtru Jan 21, 2025
9550b9d
Update tdf.go
dmihalcik-virtru Jan 21, 2025
4fb0fe8
Merge branch 'main' into feat/remove-hex-encoding-tdf3
dmihalcik-virtru Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type Manifest struct {
EncryptionInformation `json:"encryptionInformation"`
Payload `json:"payload"`
Assertions []Assertion `json:"assertions,omitempty"`
TDFVersion string `json:"tdf_spec_version,omitempty"`
}

type attributeObject struct {
Expand Down
62 changes: 50 additions & 12 deletions sdk/tdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
)

const (
sdkVersion = "4.3.0"
maxFileSizeSupported = 68719476736 // 64gb
defaultMimeType = "application/octet-stream"
tdfAsZip = "zip"
Expand Down Expand Up @@ -197,7 +198,8 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
return nil, fmt.Errorf("io.writer.Write failed: %w", err)
}

segmentSig, err := calculateSignature(cipherData, tdfObject.payloadKey[:], tdfConfig.segmentIntegrityAlgorithm)
segmentSig, err := calculateSignature(cipherData, tdfObject.payloadKey[:],
tdfConfig.segmentIntegrityAlgorithm, false)
if err != nil {
return nil, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand All @@ -216,7 +218,8 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
readPos += readSize
}

rootSignature, err := calculateSignature([]byte(aggregateHash), tdfObject.payloadKey[:], tdfConfig.integrityAlgorithm)
rootSignature, err := calculateSignature([]byte(aggregateHash), tdfObject.payloadKey[:],
tdfConfig.integrityAlgorithm, false)
if err != nil {
return nil, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand Down Expand Up @@ -263,11 +266,17 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
tmpAssertion.Statement = assertion.Statement
tmpAssertion.AppliesToState = assertion.AppliesToState

hashOfAssertion, err := tmpAssertion.GetHash()
hashOfAssertionAsHex, err := tmpAssertion.GetHash()
if err != nil {
return nil, err
}

hashOfAssertion := make([]byte, hex.DecodedLen(len(hashOfAssertionAsHex)))
_, err = hex.Decode(hashOfAssertion, hashOfAssertionAsHex)
if err != nil {
return nil, fmt.Errorf("error decoding hex string: %w", err)
}

var completeHashBuilder strings.Builder
completeHashBuilder.WriteString(aggregateHash)
completeHashBuilder.Write(hashOfAssertion)
Expand All @@ -284,7 +293,7 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
assertionSigningKey = assertion.SigningKey
}

if err := tmpAssertion.Sign(string(hashOfAssertion), string(encoded), assertionSigningKey); err != nil {
if err := tmpAssertion.Sign(string(hashOfAssertionAsHex), string(encoded), assertionSigningKey); err != nil {
return nil, fmt.Errorf("failed to sign assertion: %w", err)
}

Expand Down Expand Up @@ -322,6 +331,14 @@ func (r *Reader) Manifest() Manifest {
// prepare the manifest for TDF
func (s SDK) prepareManifest(ctx context.Context, t *TDFObject, tdfConfig TDFConfig) error { //nolint:funlen,gocognit // Better readability keeping it as is
manifest := Manifest{}

version, err := ParseVersion(sdkVersion)
if err != nil {
return fmt.Errorf("ReadVersion failed:%w", err)
}

manifest.TDFVersion = version.String()

if len(tdfConfig.splitPlan) == 0 && len(tdfConfig.kasInfoList) == 0 {
return fmt.Errorf("%w: no key access template specified or inferred", errInvalidKasInfo)
}
Expand Down Expand Up @@ -567,6 +584,8 @@ func (r *Reader) WriteTo(writer io.Writer) (int64, error) {
}
}

isLegacyTDF := r.manifest.TDFVersion == ""

var totalBytes int64
var payloadReadOffset int64
for _, seg := range r.manifest.EncryptionInformation.IntegrityInformation.Segments {
Expand All @@ -585,7 +604,7 @@ func (r *Reader) WriteTo(writer io.Writer) (int64, error) {
sigAlg = GMAC
}

payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg)
payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg, isLegacyTDF)
if err != nil {
return totalBytes, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand Down Expand Up @@ -646,6 +665,7 @@ func (r *Reader) ReadAt(buf []byte, offset int64) (int, error) { //nolint:funlen
return 0, ErrTDFPayloadReadFail
}

isLegacyTDF := r.manifest.TDFVersion == ""
var decryptedBuf bytes.Buffer
var payloadReadOffset int64
for index, seg := range r.manifest.EncryptionInformation.IntegrityInformation.Segments {
Expand All @@ -669,7 +689,7 @@ func (r *Reader) ReadAt(buf []byte, offset int64) (int, error) { //nolint:funlen
sigAlg = GMAC
}

payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg)
payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg, isLegacyTDF)
if err != nil {
return 0, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand Down Expand Up @@ -933,18 +953,29 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn
}

// Get the hash of the assertion
hashOfAssertion, err := assertion.GetHash()
hashOfAssertionAsHex, err := assertion.GetHash()
if err != nil {
return fmt.Errorf("%w: failed to get hash of assertion: %w", ErrAssertionFailure{ID: assertion.ID}, err)
}

hashOfAssertion := make([]byte, hex.DecodedLen(len(hashOfAssertionAsHex)))
_, err = hex.Decode(hashOfAssertion, hashOfAssertionAsHex)
if err != nil {
return fmt.Errorf("error decoding hex string: %w", err)
}

isLegacyTDF := r.manifest.TDFVersion == ""
if isLegacyTDF {
hashOfAssertion = hashOfAssertionAsHex
}

var completeHashBuilder bytes.Buffer
completeHashBuilder.Write(aggregateHash.Bytes())
completeHashBuilder.Write(hashOfAssertion)

base64Hash := ocrypto.Base64Encode(completeHashBuilder.Bytes())

if string(hashOfAssertion) != assertionHash {
if string(hashOfAssertionAsHex) != assertionHash {
return fmt.Errorf("%w: assertion hash missmatch", ErrAssertionFailure{ID: assertion.ID})
}

Expand Down Expand Up @@ -972,29 +1003,36 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn
}

// calculateSignature calculate signature of data of the given algorithm.
func calculateSignature(data []byte, secret []byte, alg IntegrityAlgorithm) (string, error) {
func calculateSignature(data []byte, secret []byte, alg IntegrityAlgorithm, isLegacyTDF bool) (string, error) {
if alg == HS256 {
hmac := ocrypto.CalculateSHA256Hmac(secret, data)
return hex.EncodeToString(hmac), nil
if isLegacyTDF {
return hex.EncodeToString(hmac), nil
}
return string(hmac), nil
}
if kGMACPayloadLength > len(data) {
return "", fmt.Errorf("fail to create gmac signature")
}

return hex.EncodeToString(data[len(data)-kGMACPayloadLength:]), nil
if isLegacyTDF {
return hex.EncodeToString(data[len(data)-kGMACPayloadLength:]), nil
}
return string(data[len(data)-kGMACPayloadLength:]), nil
}

// validate the root signature
func validateRootSignature(manifest Manifest, aggregateHash, secret []byte) (bool, error) {
rootSigAlg := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm
rootSigValue := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature
isLegacyTDF := manifest.TDFVersion == ""

sigAlg := HS256
if strings.EqualFold(gmacIntegrityAlgorithm, rootSigAlg) {
sigAlg = GMAC
}

sig, err := calculateSignature(aggregateHash, secret, sigAlg)
sig, err := calculateSignature(aggregateHash, secret, sigAlg, isLegacyTDF)
if err != nil {
return false, fmt.Errorf("splitkey.getSignature failed:%w", err)
}
Expand Down
47 changes: 38 additions & 9 deletions sdk/tdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/lib/ocrypto"
kaspb "github.com/opentdf/platform/protocol/go/kas"
Expand Down Expand Up @@ -264,7 +266,7 @@ func (s *TDFSuite) Test_SimpleTDF() {
"https://example.com/attr/Classification/value/X",
}

expectedTdfSize := int64(2095)
expectedTdfSize := int64(2058)
tdfFilename := "secure-text.tdf"
plainText := "Virtru"
{
Expand Down Expand Up @@ -394,7 +396,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
assertionVerificationKeys: nil,
disableAssertionVerification: false,
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -427,7 +429,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
DefaultKey: defaultKey,
},
disableAssertionVerification: false,
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -476,7 +478,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: false,
expectedSize: 3195,
expectedSize: 2988,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -516,7 +518,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: false,
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand All @@ -533,7 +535,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: true,
expectedSize: 2302,
expectedSize: 2180,
},
} {
expectedTdfSize := test.expectedSize
Expand Down Expand Up @@ -642,7 +644,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
SigningKey: defaultKey,
},
},
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -690,7 +692,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
},
},
},
expectedSize: 3195,
expectedSize: 2988,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -724,7 +726,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
assertionVerificationKeys: &AssertionVerificationKeys{
DefaultKey: defaultKey,
},
expectedSize: 2896,
expectedSize: 2689,
},
} {
expectedTdfSize := test.expectedSize
Expand Down Expand Up @@ -1479,3 +1481,30 @@ func (s *TDFSuite) checkIdentical(file, checksum string) bool {
c := h.Sum(nil)
return checksum == fmt.Sprintf("%x", c)
}

func TestParseVersion(t *testing.T) {
tests := []struct {
input string
expected Version
hasError bool
}{
{"1.2.3", Version{Major: 1, Minor: 2, Patch: 3}, false},
{"1.2.3+p1", Version{Major: 1, Minor: 2, Patch: 3, Preview: 1}, false},
{"1.2.3+p1.2", Version{Major: 1, Minor: 2, Patch: 3, Preview: 1, Revision: 2}, false},
{"1.2", Version{}, true},
{"1.2.3+p", Version{}, true},
{"1.2.3+p1.", Version{}, true},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result, err := ParseVersion(tt.input)
if tt.hasError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, *result)
}
})
}
}
76 changes: 76 additions & 0 deletions sdk/version-parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package sdk

import (
"fmt"
"os"
"strings"
)

type Version struct {
Major int
Minor int
Patch int
Preview int
Revision int
}

func ReadVersion() (*Version, error) {
content, err := os.ReadFile("VERSION")
if err != nil {
return nil, fmt.Errorf("reading VERSION file: %w", err)
}

return ParseVersion(strings.TrimSpace(string(content)))
}

func ParseVersion(v string) (*Version, error) {
const maxParts = 2
var ver Version
var preview, revision string

parts := strings.SplitN(v, "+p", maxParts)
mainVersion := parts[0]

if len(parts) > 1 {
if parts[1] == "" {
return nil, fmt.Errorf("invalid preview format")
}
previewParts := strings.SplitN(parts[1], ".", maxParts)
preview = previewParts[0]
if len(previewParts) > 1 {
if previewParts[1] == "" {
return nil, fmt.Errorf("invalid revision format")
}
revision = previewParts[1]
}
}

if _, err := fmt.Sscanf(mainVersion, "%d.%d.%d", &ver.Major, &ver.Minor, &ver.Patch); err != nil {
return nil, fmt.Errorf("parsing version: %w", err)
}

if preview != "" {
if _, err := fmt.Sscanf(preview, "%d", &ver.Preview); err != nil {
return nil, fmt.Errorf("parsing preview version: %w", err)
}
}

if revision != "" {
if _, err := fmt.Sscanf(revision, "%d", &ver.Revision); err != nil {
return nil, fmt.Errorf("parsing revision: %w", err)
}
}

return &ver, nil
}

func (v *Version) String() string {
base := fmt.Sprintf("%d.%d.%d", v.Major, v.Minor, v.Patch)
if v.Preview > 0 {
if v.Revision > 0 {
return fmt.Sprintf("%s+p%d.%d", base, v.Preview, v.Revision)
}
return fmt.Sprintf("%s+p%d", base, v.Preview)
}
return base
}
Loading