Skip to content

Commit

Permalink
Merge pull request #139 from deeglaze/defaultvmpl
Browse files Browse the repository at this point in the history
Add flag for default VMPL in GetRawQuote.
  • Loading branch information
deeglaze authored Nov 5, 2024
2 parents f94d851 + c72e34c commit 6ad451f
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion client/client_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package client
import (
"flag"
"fmt"
"strconv"
"time"

"github.com/google/go-configfs-tsm/configfs/linuxtsm"
Expand All @@ -41,6 +42,7 @@ const (
var (
throttleDuration = flag.Duration("self_throttle_duration", 2*time.Second, "Rate-limit library-initiated device commands to this duration")
burstMax = flag.Int("self_throttle_burst", 1, "Rate-limit library-initiated device commands to this many commands per duration")
defaultVMPL = flag.String("default_vmpl", "", "Default VMPL to use for attestation (empty for driver default)")
)

// LinuxDevice implements the Device interface with Linux ioctls.
Expand Down Expand Up @@ -167,7 +169,14 @@ func (p *LinuxIoctlQuoteProvider) GetRawQuoteAtLevel(reportData [64]byte, level

// GetRawQuote returns byte format attestation plus certificate table via /dev/sev-guest ioctl.
func (p *LinuxIoctlQuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) {
return p.GetRawQuoteAtLevel(reportData, 0)
if *defaultVMPL == "" {
return p.GetRawQuoteAtLevel(reportData, 0)
}
vmpl, err := strconv.ParseUint(*defaultVMPL, 10, 32)
if err != nil {
return nil, fmt.Errorf("bad default_vmpl: %q", *defaultVMPL)
}
return p.GetRawQuoteAtLevel(reportData, uint(vmpl))
}

// Product returns the current CPU's associated AMD SEV product information.
Expand Down Expand Up @@ -222,6 +231,15 @@ func (p *LinuxConfigFsQuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8,
InBlob: reportData[:],
GetAuxBlob: true,
}
if *defaultVMPL != "" {
vmpl, err := strconv.ParseUint(*defaultVMPL, 10, 32)
if err != nil {
return nil, fmt.Errorf("bad default_vmpl: %q", *defaultVMPL)
}
req.Privilege = &report.Privilege{
Level: uint(vmpl),
}
}
resp, err := linuxtsm.GetReport(req)
if err != nil {
return nil, err
Expand Down

0 comments on commit 6ad451f

Please sign in to comment.