Skip to content

Commit 7b5fc31

Browse files
committed
verify: use context for resource fetching
1 parent 43f7581 commit 7b5fc31

File tree

5 files changed

+269
-57
lines changed

5 files changed

+269
-57
lines changed

testing/mocks.go

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package testing
1616

1717
import (
18+
"context"
1819
"encoding/hex"
1920
"errors"
2021
"fmt"
@@ -246,6 +247,17 @@ func (g *Getter) Get(url string) ([]byte, error) {
246247
return body, err
247248
}
248249

250+
// GetContext checks whether the context expired, returns the context error if that's the case and
251+
// calls Get otherwise.
252+
func (g *Getter) GetContext(ctx context.Context, url string) ([]byte, error) {
253+
select {
254+
case <-ctx.Done():
255+
return nil, ctx.Err()
256+
default:
257+
return g.Get(url)
258+
}
259+
}
260+
249261
// Done checks that all configured responses have been consumed, and errors
250262
// otherwise.
251263
func (g *Getter) Done(t testing.TB) {

verify/trust/trust.go

+54-5
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ type HTTPSGetter interface {
130130
Get(url string) ([]byte, error)
131131
}
132132

133+
// ContextHTTPSGetter is an HTTPSGetter that accepts a context.Context.
134+
type ContextHTTPSGetter interface {
135+
GetContext(ctx context.Context, url string) ([]byte, error)
136+
}
137+
138+
// GetWith gets a resource from a URL using an HTTPSGetter.
139+
// If the HTTPSGetter implements ContextHTTPSGetter, the GetContext method will be used.
140+
func GetWith(ctx context.Context, getter HTTPSGetter, url string) ([]byte, error) {
141+
if contextGetter, ok := getter.(ContextHTTPSGetter); ok {
142+
return contextGetter.GetContext(ctx, url)
143+
}
144+
return getter.Get(url)
145+
}
146+
133147
// AttestationRecreationErr represents a problem with fetching or interpreting associated
134148
// certificates for a given attestation report. This is typically due to network unreliability.
135149
type AttestationRecreationErr struct {
@@ -145,7 +159,16 @@ type SimpleHTTPSGetter struct{}
145159

146160
// Get uses http.Get to return the HTTPS response body as a byte array.
147161
func (n *SimpleHTTPSGetter) Get(url string) ([]byte, error) {
148-
resp, err := http.Get(url)
162+
return n.GetContext(context.Background(), url)
163+
}
164+
165+
// GetContext behaves like get, but forwards the context to the http package.
166+
func (n *SimpleHTTPSGetter) GetContext(ctx context.Context, url string) ([]byte, error) {
167+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
168+
if err != nil {
169+
return nil, err
170+
}
171+
resp, err := http.DefaultClient.Do(req)
149172
if err != nil {
150173
return nil, err
151174
} else if resp.StatusCode >= 300 {
@@ -160,9 +183,16 @@ func (n *SimpleHTTPSGetter) Get(url string) ([]byte, error) {
160183
return body, nil
161184
}
162185

186+
var (
187+
_ = HTTPSGetter(&SimpleHTTPSGetter{})
188+
_ = ContextHTTPSGetter(&SimpleHTTPSGetter{})
189+
)
190+
163191
// RetryHTTPSGetter is a meta-HTTPS getter that will retry on failure a given number of times.
164192
type RetryHTTPSGetter struct {
165193
// Timeout is how long to retry before failure.
194+
// If Timeout is zero, the Get method will retry indefinitely and the GetContext method will
195+
// retry until the input context expires.
166196
Timeout time.Duration
167197
// MaxRetryDelay is the maximum amount of time to wait between retries.
168198
MaxRetryDelay time.Duration
@@ -172,11 +202,20 @@ type RetryHTTPSGetter struct {
172202

173203
// Get fetches the body of the URL, retrying a given amount of times on failure.
174204
func (n *RetryHTTPSGetter) Get(url string) ([]byte, error) {
205+
return n.GetContext(context.Background(), url)
206+
}
207+
208+
// GetContext behaves like get, but forwards the context to the Getter and stops retrying when the
209+
// context expired.
210+
func (n *RetryHTTPSGetter) GetContext(ctx context.Context, url string) ([]byte, error) {
175211
delay := initialDelay
176-
ctx, cancel := context.WithTimeout(context.Background(), n.Timeout)
212+
cancel := func() {}
213+
if n.Timeout > 0 {
214+
ctx, cancel = context.WithTimeout(ctx, n.Timeout)
215+
}
177216
var returnedError error
178217
for {
179-
body, err := n.Getter.Get(url)
218+
body, err := GetWith(ctx, n.Getter, url)
180219
if err == nil {
181220
cancel()
182221
return body, nil
@@ -189,12 +228,17 @@ func (n *RetryHTTPSGetter) Get(url string) ([]byte, error) {
189228
select {
190229
case <-ctx.Done():
191230
cancel()
192-
return nil, multierr.Append(returnedError, fmt.Errorf("timeout")) // context cancelled
231+
return nil, multierr.Append(returnedError, ctx.Err())
193232
case <-time.After(delay): // wait to retry
194233
}
195234
}
196235
}
197236

237+
var (
238+
_ = HTTPSGetter(&RetryHTTPSGetter{})
239+
_ = ContextHTTPSGetter(&RetryHTTPSGetter{})
240+
)
241+
198242
// DefaultHTTPSGetter returns the library's default getter implementation. It will
199243
// retry slowly due to the AMD KDS's rate limiting.
200244
func DefaultHTTPSGetter() HTTPSGetter {
@@ -311,14 +355,19 @@ func ClearProductCertCache() {
311355
// GetProductChain returns the ASK and ARK certificates of the given product line, either from getter
312356
// or from a cache of the results from the last successful call.
313357
func GetProductChain(productLine string, s abi.ReportSigner, getter HTTPSGetter) (*ProductCerts, error) {
358+
return GetProductChainContext(context.Background(), productLine, s, getter)
359+
}
360+
361+
// GetProductChainContext behaves like GetProductChain but forwards the context to the HTTPSGetter.
362+
func GetProductChainContext(ctx context.Context, productLine string, s abi.ReportSigner, getter HTTPSGetter) (*ProductCerts, error) {
314363
if productLineCertCache == nil {
315364
prodCacheMu.Lock()
316365
productLineCertCache = make(map[string]*ProductCerts)
317366
prodCacheMu.Unlock()
318367
}
319368
result, ok := productLineCertCache[productLine]
320369
if !ok {
321-
askark, err := getter.Get(kds.ProductCertChainURL(s, productLine))
370+
askark, err := GetWith(ctx, getter, kds.ProductCertChainURL(s, productLine))
322371
if err != nil {
323372
return nil, &AttestationRecreationErr{
324373
Msg: fmt.Sprintf("could not download ASK and ARK certificates: %v", err),

verify/trust/trust_test.go

+78-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package trust_test
1616

1717
import (
1818
"bytes"
19+
"context"
1920
"errors"
2021
"testing"
2122
"time"
@@ -127,10 +128,86 @@ func TestRetryHTTPSGetterAllFail(t *testing.T) {
127128

128129
body, err := r.Get("https://fetch.me")
129130
if !bytes.Equal(body, []byte("")) {
130-
t.Errorf("expected '%s' but got '%s'", "content", body)
131+
t.Errorf("expected empty body but got %q", body)
131132
}
132133
if err == nil {
133134
t.Errorf("expected error, but got none")
134135
}
135136
testGetter.Done(t)
136137
}
138+
139+
func TestRetryHTTPSGetterContext(t *testing.T) {
140+
testGetter := &test.Getter{
141+
Responses: map[string][]test.GetResponse{
142+
"https://fetch.me": {
143+
{
144+
Occurrences: 1,
145+
Body: []byte("content"),
146+
Error: nil,
147+
},
148+
},
149+
},
150+
}
151+
r := &trust.RetryHTTPSGetter{
152+
MaxRetryDelay: 1 * time.Millisecond,
153+
Getter: testGetter,
154+
}
155+
156+
ctx, cancel := context.WithCancel(context.Background())
157+
cancel()
158+
body, err := r.GetContext(ctx, "https://fetch.me")
159+
if !bytes.Equal(body, []byte("")) {
160+
t.Errorf("expected empty body but got %q", body)
161+
}
162+
if !errors.Is(err, context.Canceled) {
163+
t.Errorf("expected error %q, but got %q", context.Canceled, err)
164+
}
165+
}
166+
167+
type recordingGetter struct {
168+
getCalls int
169+
}
170+
171+
func (r *recordingGetter) Get(url string) ([]byte, error) {
172+
r.getCalls++
173+
return []byte{}, nil
174+
}
175+
176+
type recordingContextGetter struct {
177+
recordingGetter
178+
getContextCalls int
179+
}
180+
181+
func (r *recordingContextGetter) GetContext(ctx context.Context, url string) ([]byte, error) {
182+
r.getContextCalls++
183+
return []byte{}, nil
184+
}
185+
186+
func TestGetWith(t *testing.T) {
187+
url := ""
188+
t.Run("HTTPSGetter uses Get", func(t *testing.T) {
189+
contextGetter := recordingContextGetter{}
190+
if _, err := trust.GetWith(context.Background(), &contextGetter.recordingGetter, url); err != nil {
191+
t.Fatalf("trust.GetWith returned an unexpected error: %v", err)
192+
}
193+
if contextGetter.getContextCalls != 0 {
194+
t.Errorf("wrong number of calls to GetContext: got %d, want 0", contextGetter.getContextCalls)
195+
}
196+
if contextGetter.recordingGetter.getCalls != 1 {
197+
t.Errorf("wrong number of calls to Get: got %d, want 1", contextGetter.getCalls)
198+
}
199+
})
200+
t.Run("ContextHTTPSGetter uses GetContext", func(t *testing.T) {
201+
contextGetter := recordingContextGetter{}
202+
if _, err := trust.GetWith(context.Background(), &contextGetter, url); err != nil {
203+
t.Fatalf("trust.GetWith returned an unexpected error: %v", err)
204+
}
205+
if contextGetter.getContextCalls != 1 {
206+
t.Errorf("wrong number of calls to GetContext: got %d, want 1", contextGetter.getContextCalls)
207+
}
208+
if contextGetter.recordingGetter.getCalls != 0 {
209+
t.Errorf("wrong number of calls to Get: got %d, want 0", contextGetter.getCalls)
210+
}
211+
})
212+
213+
}

verify/verify.go

+45-12
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package verify
1717

1818
import (
19+
"context"
1920
"crypto/ecdsa"
2021
"crypto/rsa"
2122
"crypto/x509"
@@ -294,6 +295,12 @@ type CRLUnavailableErr struct {
294295
// GetCrlAndCheckRoot downloads the given cert's CRL from one of the distribution points and
295296
// verifies that the CRL is valid and doesn't revoke an intermediate key.
296297
func GetCrlAndCheckRoot(r *trust.AMDRootCerts, opts *Options) (*x509.RevocationList, error) {
298+
return GetCrlAndCheckRootContext(context.Background(), r, opts)
299+
}
300+
301+
// GetCrlAndCheckRootContext behaves like GetCrlAndCheckRoot but forwards the context to the
302+
// HTTPSGetter.
303+
func GetCrlAndCheckRootContext(ctx context.Context, r *trust.AMDRootCerts, opts *Options) (*x509.RevocationList, error) {
297304
r.Mu.Lock()
298305
defer r.Mu.Unlock()
299306
getter := opts.Getter
@@ -308,7 +315,7 @@ func GetCrlAndCheckRoot(r *trust.AMDRootCerts, opts *Options) (*x509.RevocationL
308315
}
309316
var errs error
310317
for _, url := range r.ProductCerts.Ask.CRLDistributionPoints {
311-
bytes, err := getter.Get(url)
318+
bytes, err := trust.GetWith(ctx, getter, url)
312319
if err != nil {
313320
errs = multierr.Append(errs, err)
314321
continue
@@ -354,8 +361,13 @@ func verifyCRL(r *trust.AMDRootCerts) error {
354361

355362
// VcekNotRevoked will consult the online CRL listed in the VCEK certificate for whether this cert
356363
// has been revoked. Returns nil if not revoked, error on any problem.
357-
func VcekNotRevoked(r *trust.AMDRootCerts, _ *x509.Certificate, options *Options) error {
358-
_, err := GetCrlAndCheckRoot(r, options)
364+
func VcekNotRevoked(r *trust.AMDRootCerts, cert *x509.Certificate, options *Options) error {
365+
return VcekNotRevokedContext(context.Background(), r, cert, options)
366+
}
367+
368+
// VcekNotRevokedContext behaves like VcekNotRevoked but forwards the context to the HTTPSGetter.
369+
func VcekNotRevokedContext(ctx context.Context, r *trust.AMDRootCerts, _ *x509.Certificate, options *Options) error {
370+
_, err := GetCrlAndCheckRootContext(ctx, r, options)
359371
return err
360372
}
361373

@@ -571,7 +583,7 @@ type Options struct {
571583
// any missing certificates in an attestation's certificate chain. Uses Getter if false.
572584
DisableCertFetching bool
573585
// Getter takes a URL and returns the body of its contents. By default uses http.Get and returns
574-
// the body.
586+
// the body. If Getter implements trust.ContextHTTPSGetter, GetContext will be preferred over Get.
575587
Getter trust.HTTPSGetter
576588
// Now is the time at which to verify the validity of certificates. If unset, uses time.Now().
577589
Now time.Time
@@ -662,6 +674,11 @@ func updateProductExpectation(product **spb.SevProduct, reportProduct *spb.SevPr
662674
// SnpAttestation verifies the protobuf representation of an attestation report's signature based
663675
// on the report's SignatureAlgo, provided the certificate chain is valid.
664676
func SnpAttestation(attestation *spb.Attestation, options *Options) error {
677+
return SnpAttestationContext(context.Background(), attestation, options)
678+
}
679+
680+
// SnpAttestationContext behaves like SnpAttestation but forwards the context to the HTTPSGetter.
681+
func SnpAttestationContext(ctx context.Context, attestation *spb.Attestation, options *Options) error {
665682
if options == nil {
666683
return fmt.Errorf("options cannot be nil")
667684
}
@@ -670,7 +687,7 @@ func SnpAttestation(attestation *spb.Attestation, options *Options) error {
670687
}
671688
// Make sure we have the whole certificate chain, or at least the product
672689
// info.
673-
if err := fillInAttestation(attestation, options); err != nil {
690+
if err := fillInAttestation(ctx, attestation, options); err != nil {
674691
return err
675692
}
676693

@@ -786,7 +803,7 @@ func cpuidWorkaround(attestation *spb.Attestation, options *Options) (string, fu
786803

787804
// fillInAttestation uses AMD's KDS to populate any empty certificate field in the attestation's
788805
// certificate chain.
789-
func fillInAttestation(attestation *spb.Attestation, options *Options) error {
806+
func fillInAttestation(ctx context.Context, attestation *spb.Attestation, options *Options) error {
790807
if options.DisableCertFetching {
791808
return nil
792809
}
@@ -810,7 +827,7 @@ func fillInAttestation(attestation *spb.Attestation, options *Options) error {
810827
attestation.CertificateChain = chain
811828
}
812829
if len(chain.GetAskCert()) == 0 || len(chain.GetArkCert()) == 0 {
813-
askark, err := trust.GetProductChain(productLine, info.SigningKey, getter)
830+
askark, err := trust.GetProductChainContext(ctx, productLine, info.SigningKey, getter)
814831
if err != nil {
815832
return err
816833
}
@@ -826,7 +843,7 @@ func fillInAttestation(attestation *spb.Attestation, options *Options) error {
826843
case abi.VcekReportSigner:
827844
if len(chain.GetVcekCert()) == 0 {
828845
vcekURL := kds.VCEKCertURL(productLine, report.GetChipId(), kds.TCBVersion(report.GetReportedTcb()))
829-
vcek, err := getter.Get(vcekURL)
846+
vcek, err := trust.GetWith(ctx, getter, vcekURL)
830847
if err != nil {
831848
return &trust.AttestationRecreationErr{
832849
Msg: fmt.Sprintf("could not download VCEK certificate: %v", err),
@@ -854,11 +871,17 @@ func fillInAttestation(attestation *spb.Attestation, options *Options) error {
854871
// chain for the VCEK that supposedly signed the given report, and returns the Attestation
855872
// representation of their combination. If getter is nil, uses Golang's http.Get.
856873
func GetAttestationFromReport(report *spb.Report, options *Options) (*spb.Attestation, error) {
874+
return GetAttestationFromReportContext(context.Background(), report, options)
875+
}
876+
877+
// GetAttestationFromReportContext behaves like GetAttestationFromReport but forwards the context
878+
// to the HTTPSGetter.
879+
func GetAttestationFromReportContext(ctx context.Context, report *spb.Report, options *Options) (*spb.Attestation, error) {
857880
result := &spb.Attestation{
858881
Report: report,
859882
CertificateChain: &spb.CertificateChain{Extras: map[string][]byte{}},
860883
}
861-
if err := fillInAttestation(result, options); err != nil {
884+
if err := fillInAttestation(ctx, result, options); err != nil {
862885
return nil, err
863886
}
864887
// Attempt to fill in the product field of the attestation. Don't error at this
@@ -886,23 +909,33 @@ func GetAttestationFromReport(report *spb.Report, options *Options) (*spb.Attest
886909
// on the report's SignatureAlgo and uses the AMD Key Distribution Service to download the
887910
// report's corresponding VCEK certificate.
888911
func SnpReport(report *spb.Report, options *Options) error {
912+
return SnpReportContext(context.Background(), report, options)
913+
}
914+
915+
// SnpReportContext behaves like SnpReport but forwards the context to the HTTPSGetter.
916+
func SnpReportContext(ctx context.Context, report *spb.Report, options *Options) error {
889917
if options.DisableCertFetching {
890918
return errors.New("cannot verify attestation report without fetching certificates")
891919
}
892-
attestation, err := GetAttestationFromReport(report, options)
920+
attestation, err := GetAttestationFromReportContext(ctx, report, options)
893921
if err != nil {
894922
return fmt.Errorf("could not recreate attestation from report: %w", err)
895923
}
896-
return SnpAttestation(attestation, options)
924+
return SnpAttestationContext(ctx, attestation, options)
897925
}
898926

899927
// RawSnpReport verifies the raw bytes representation of an attestation report's signature
900928
// based on the report's SignatureAlgo and uses the AMD Key Distribution Service to download
901929
// the report's corresponding VCEK certificate.
902930
func RawSnpReport(rawReport []byte, options *Options) error {
931+
return RawSnpReportContext(context.Background(), rawReport, options)
932+
}
933+
934+
// RawSnpReportContext behaves like RawSnpReport but forwards the context to the HTTPSGetter.
935+
func RawSnpReportContext(ctx context.Context, rawReport []byte, options *Options) error {
903936
report, err := abi.ReportToProto(rawReport)
904937
if err != nil {
905938
return fmt.Errorf("could not interpret report bytes: %v", err)
906939
}
907-
return SnpReport(report, options)
940+
return SnpReportContext(ctx, report, options)
908941
}

0 commit comments

Comments
 (0)