merge branch 'pr-360'
Aleksa Sarai (3):
  oci: walk: return error from Close if applicable
  oci: unpack: slurp up raw layer stream before Close()
  pkg: hardening: slurp trailing bytes in VerifiedReadCloser.Close()

LGTMs: @cyphar @tych0
Closes #360
cyphar committed Mar 30, 2021
2 parents 3d09b87 + 0976fbb commit 07fa845
Showing 4 changed files with 130 additions and 6 deletions.
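The three commits above share two Go patterns: propagating a Close() error through a named return value in a deferred function, and draining any unread trailing bytes before verification so the digest covers the whole stream rather than just the prefix the caller happened to read. The following is a minimal, self-contained sketch of both patterns; the verifiedReadCloser type and everything in main() are hypothetical illustrations, not code from this commit (umoci's real implementation is in pkg/hardening/verified_reader.go, shown in the diff below).

package main

import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"hash"
	"io"
	"io/ioutil"
)

// verifiedReadCloser is a toy stand-in for pkg/hardening.VerifiedReadCloser:
// it hashes everything read through it and checks the digest on Close().
type verifiedReadCloser struct {
	r        io.ReadCloser
	h        hash.Hash
	expected string
}

func (v *verifiedReadCloser) Read(p []byte) (int, error) {
	n, err := v.r.Read(p)
	v.h.Write(p[:n]) // hash.Hash.Write never returns an error
	return n, err
}

func (v *verifiedReadCloser) Close() (Err error) {
	// Propagate the underlying Close() error through the named return, but
	// don't let it mask an earlier verification error -- the same pattern the
	// walk.go hunk uses for blob.Close().
	defer func() {
		if err := v.r.Close(); err != nil && Err == nil {
			Err = err
		}
	}()
	// Slurp up any trailing bytes the caller didn't read, so the digest
	// check sees the entire stream.
	if _, err := io.Copy(ioutil.Discard, v); err != nil {
		return fmt.Errorf("consume remaining unverified stream: %w", err)
	}
	if got := hex.EncodeToString(v.h.Sum(nil)); got != v.expected {
		return fmt.Errorf("digest mismatch: got %s, expected %s", got, v.expected)
	}
	return nil
}

func main() {
	data := []byte("some blob contents")
	sum := sha256.Sum256(data)
	v := &verifiedReadCloser{
		r:        ioutil.NopCloser(bytes.NewReader(data)),
		h:        sha256.New(),
		expected: hex.EncodeToString(sum[:]),
	}
	// Read only part of the stream; Close() still verifies the whole blob.
	if _, err := io.CopyN(ioutil.Discard, v, 4); err != nil {
		panic(err)
	}
	fmt.Println("close error:", v.Close()) // <nil> -- the blob is intact
}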
oci/casext/walk.go (9 additions & 2 deletions)
@@ -101,7 +101,7 @@ var ErrSkipDescriptor = errors.New("[internal] do not recurse into descriptor")
// more than once. This is quite important for remote CAS implementations.
type WalkFunc func(descriptorPath DescriptorPath) error

func (ws *walkState) recurse(ctx context.Context, descriptorPath DescriptorPath) error {
func (ws *walkState) recurse(ctx context.Context, descriptorPath DescriptorPath) (Err error) {
log.WithFields(log.Fields{
"digest": descriptorPath.Descriptor().Digest,
}).Debugf("-> ws.recurse")
@@ -129,7 +129,14 @@ func (ws *walkState) recurse(ctx context.Context, descriptorPath DescriptorPath)
}
return err
}
defer blob.Close()
defer func() {
if err := blob.Close(); err != nil {
log.Warnf("during recursion blob %v had error on Close: %v", descriptor.Digest, err)
if Err == nil {
Err = err
}
}
}()

// Recurse into children.
for _, child := range childDescriptors(blob.Data) {
oci/layer/unpack.go (15 additions & 1 deletion)
@@ -274,8 +274,22 @@ func UnpackRootfs(ctx context.Context, engine cas.Engine, rootfsPath string, man
// in the later diff_id check failing because the digester didn't get
// the whole uncompressed stream). Just blindly consume anything left
// in the layer.
if _, err = io.Copy(ioutil.Discard, layer); err != nil {
if n, err := io.Copy(ioutil.Discard, layer); err != nil {
return errors.Wrap(err, "discard trailing archive bits")
} else if n != 0 {
log.Debugf("unpack manifest: layer %s: ignoring %d trailing 'junk' bytes in the tar stream -- probably from GNU tar", layerDescriptor.Digest, n)
}
// Same goes for compressed layers -- it seems like some gzip
// implementations add trailing NUL bytes, which Go doesn't slurp up.
// Just eat up the rest of the remaining bytes and discard them.
//
// FIXME: We use layerData here because pgzip returns io.EOF from
// WriteTo, which causes havoc with io.Copy. Ideally we would use
// layerRaw. See <https://github.com/klauspost/pgzip/issues/38>.
if n, err := io.Copy(ioutil.Discard, layerData); err != nil {
return errors.Wrap(err, "discard trailing raw bits")
} else if n != 0 {
log.Warnf("unpack manifest: layer %s: ignoring %d trailing 'junk' bytes in the blob stream -- this may indicate a bug in the tool which built this image", layerDescriptor.Digest, n)
}
if err := layerData.Close(); err != nil {
return errors.Wrap(err, "close layer data")
pkg/hardening/verified_reader.go (33 additions & 3 deletions)
@@ -18,7 +18,10 @@
package hardening

import (
"fmt"
"io"
"io/ioutil"
"os"

"github.com/apex/log"
"github.com/opencontainers/go-digest"
@@ -94,11 +97,9 @@ func (v *VerifiedReadCloser) verify(nilErr error) error {
// Not enough bytes in the stream.
case v.currentSize < v.ExpectedSize:
return errors.Wrapf(ErrSizeMismatch, "expected %d bytes (only %d bytes in stream)", v.ExpectedSize, v.currentSize)

// We don't read the entire blob, so the message needs to be slightly adjusted.
case v.currentSize > v.ExpectedSize:
return errors.Wrapf(ErrSizeMismatch, "expected %d bytes (extra bytes in stream)", v.ExpectedSize)

}
}
// Forward the provided error.
@@ -131,7 +132,9 @@ func (v *VerifiedReadCloser) Read(p []byte) (n int, err error) {
// anything left by doing a 1-byte read (Go doesn't allow for zero-length
// Read()s to give EOFs).
case left == 0:
// We just want to know whether we read something (n>0). #nosec G104
// We just want to know whether we read something (n>0). Whatever we
// read is irrelevant, because if we read anything at all the reader
// will fail to verify. #nosec G104
nTmp, _ := v.Reader.Read(make([]byte, 1))
v.currentSize += int64(nTmp)
}
@@ -157,9 +160,36 @@ func (v *VerifiedReadCloser) Read(p []byte) (n int, err error) {
return n, err
}

// sourceName returns a debugging-friendly string to indicate to the user what
// the source reader is for this verified reader.
func (v *VerifiedReadCloser) sourceName() string {
switch inner := v.Reader.(type) {
case *VerifiedReadCloser:
return fmt.Sprintf("vrdr[%s]", inner.sourceName())
case *os.File:
return inner.Name()
case fmt.Stringer:
return inner.String()
// TODO: Maybe handle things like ioutil.NopCloser by using reflection?
default:
return fmt.Sprintf("%#v", inner)
}
}

// Close is a wrapper around VerifiedReadCloser.Reader, but with a digest check
// which will return an error if the underlying Close() didn't.
func (v *VerifiedReadCloser) Close() error {
// Consume any remaining bytes to make sure that we've actually read to the
// end of the stream. VerifiedReadCloser.Read will not read past
// ExpectedSize+1, so we don't need to add a limit here.
if n, err := io.Copy(ioutil.Discard, v); err != nil {
return errors.Wrap(err, "consume remaining unverified stream")
} else if n != 0 {
// If there are trailing bytes being discarded at this point, that
// indicates that whatever was used to generate this blob added
// trailing gunk.
log.Infof("verified reader: %d bytes of trailing data discarded from %s", n, v.sourceName())
}
// Piped to underlying close.
err := v.Reader.Close()
if err != nil {
pkg/hardening/verified_reader_test.go (73 additions & 0 deletions)
@@ -92,6 +92,37 @@ func TestValidIgnoreLength(t *testing.T) {
}
}

func TestValidTrailing(t *testing.T) {
for size := 1; size <= 16384; size *= 2 {
t.Run(fmt.Sprintf("size:%d", size), func(t *testing.T) {
// Fill buffer with random data.
buffer := new(bytes.Buffer)
if _, err := io.CopyN(buffer, rand.Reader, int64(size)); err != nil {
t.Fatalf("getting random data for buffer failed: %v", err)
}

// Get expected hash.
expectedDigest := digest.SHA256.FromBytes(buffer.Bytes())
verifiedReader := &VerifiedReadCloser{
Reader: ioutil.NopCloser(buffer),
ExpectedDigest: expectedDigest,
ExpectedSize: int64(-1),
}

// Read *half* of the bytes, leaving some remaining. We should get
// no errors.
if _, err := io.CopyN(ioutil.Discard, verifiedReader, int64(size/2)); err != nil {
t.Errorf("expected no error after reading only %d bytes: got an error: %v", size/2, err)
}

// And on close we shouldn't get an error either.
if err := verifiedReader.Close(); err != nil {
t.Errorf("expected digest+size to be correct on Close: got an error: %v", err)
}
})
}
}

func TestInvalidDigest(t *testing.T) {
for size := 1; size <= 16384; size *= 2 {
t.Run(fmt.Sprintf("size:%d", size), func(t *testing.T) {
@@ -123,6 +154,48 @@ func TestInvalidDigest(t *testing.T) {
}
}

func TestInvalidDigest_Trailing(t *testing.T) {
for size := 1; size <= 16384; size *= 2 {
for delta := 1; delta-1 <= size/2; delta *= 2 {
t.Run(fmt.Sprintf("size:%d_delta:%d", size, delta), func(t *testing.T) {
// Fill buffer with random data.
buffer := new(bytes.Buffer)
if _, err := io.CopyN(buffer, rand.Reader, int64(size)); err != nil {
t.Fatalf("getting random data for buffer failed: %v", err)
}

// Generate a correct hash (for a shorter buffer), but don't
// verify the size -- this is to make sure that we actually
// read all the bytes.
shortBuffer := buffer.Bytes()[:size-delta]
expectedDigest := digest.SHA256.FromBytes(shortBuffer)
verifiedReader := &VerifiedReadCloser{
Reader: ioutil.NopCloser(buffer),
ExpectedDigest: expectedDigest,
ExpectedSize: -1,
}

// Read only the bytes covered by the short digest -- we shouldn't see any errors yet.
if _, err := io.CopyN(ioutil.Discard, verifiedReader, int64(size-delta)); err != nil {
t.Errorf("expected no errors after reading N bytes: got error: %v", err)
}

// Check that the digest does actually match right now.
verifiedReader.init()
if err := verifiedReader.verify(nil); err != nil {
t.Errorf("expected no errors in verify before Close: got error: %v", err)
}

// And on close we should get the error.
if err := verifiedReader.Close(); errors.Cause(err) != ErrDigestMismatch {
t.Errorf("expected digest to be invalid on Close: got wrong error: %v", err)
}
})

}
}
}

func TestInvalidSize_Short(t *testing.T) {
for size := 1; size <= 16384; size *= 2 {
for delta := 1; delta-1 <= size/2; delta *= 2 {
