Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tpuproxy: resolve FIXMEs added by cl/723723714 #11453

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pkg/abi/nvgpu/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ type NVOS02_PARAMETERS struct {
Pad1 [4]byte
}

// Bitfields in NVOS02Parameters.Flags:
// Bitfields in NVOS02_PARAMETERS.Flags:
const (
NVOS02_FLAGS_ALLOC_SHIFT = 16
NVOS02_FLAGS_ALLOC_MASK = 0x3
Expand Down Expand Up @@ -470,6 +470,18 @@ type NVOS33_PARAMETERS struct {
Flags uint32
}

// Bitfields in NVOS33_PARAMETERS.Flags:
const (
NVOS33_FLAGS_CACHING_TYPE_SHIFT = 23
NVOS33_FLAGS_CACHING_TYPE_MASK = 0x7
NVOS33_FLAGS_CACHING_TYPE_CACHED = 0
NVOS33_FLAGS_CACHING_TYPE_UNCACHED = 1
NVOS33_FLAGS_CACHING_TYPE_WRITECOMBINED = 2
NVOS33_FLAGS_CACHING_TYPE_WRITEBACK = 5
NVOS33_FLAGS_CACHING_TYPE_DEFAULT = 6
NVOS33_FLAGS_CACHING_TYPE_UNCACHED_WEAK = 7
)

// NVOS34_PARAMETERS is the parameter type for NV_ESC_RM_UNMAP_MEMORY.
//
// +marshal
Expand Down
1 change: 1 addition & 0 deletions pkg/hostarch/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ go_library(
"hostarch.go",
"hostarch_arm64.go",
"hostarch_x86.go",
"memory_type.go",
"sizes_util.go",
],
visibility = ["//:sandbox"],
Expand Down
84 changes: 84 additions & 0 deletions pkg/hostarch/memory_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2025 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package hostarch

import "fmt"

// MemoryType specifies CPU memory access behavior.
type MemoryType uint8

const (
// MemoryTypeWriteBack is equivalent to Linux's default pgprot, or the
// following architectural memory types:
//
// - x86: Write-back (WB)
//
// - ARM64: Normal write-back cacheable
//
// This memory type is appropriate for typical application memory and must
// be the zero value for MemoryType.
MemoryTypeWriteBack MemoryType = iota

// MemoryTypeWriteCombine is equivalent to Linux's pgprot_writecombine(),
// or the following architectural memory types:
//
// - x86: Write-combining (WC)
//
// - ARM64: Normal non-cacheable
MemoryTypeWriteCombine

// MemoryTypeUncached is equivalent to Linux's pgprot_noncached(), or the
// following architectural memory types:
//
// - x86: Strong Uncacheable (UC) or Uncacheable (UC-); these differ in
// that UC- may be "downgraded" to WC by a setting of WC or (Intel only) WP
// in MTRR or EPT/NPT, but gVisor does not use MTRRs and KVM never sets WC
// or WP in EPT/NPT.
//
// - ARM64: Device-nGnRnE
MemoryTypeUncached

// NumMemoryTypes is the number of memory types.
NumMemoryTypes
)

// String implements fmt.Stringer.String.
func (mt MemoryType) String() string {
switch mt {
case MemoryTypeWriteBack:
return "WriteBack"
case MemoryTypeWriteCombine:
return "WriteCombine"
case MemoryTypeUncached:
return "Uncached"
default:
return fmt.Sprintf("%d", mt)
}
}

// ShortString returns a two-character string compactly representing the
// MemoryType.
func (mt MemoryType) ShortString() string {
switch mt {
case MemoryTypeWriteBack:
return "WB"
case MemoryTypeWriteCombine:
return "WC"
case MemoryTypeUncached:
return "UC"
default:
return fmt.Sprintf("%02d", mt)
}
}
59 changes: 30 additions & 29 deletions pkg/ring0/pagetables/pagetables_aarch64.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,29 +52,26 @@ func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {

// Bits in page table entries.
const (
typeTable = 0x3 << 0
typeSect = 0x1 << 0
typePage = 0x3 << 0
pteValid = 0x1 << 0
pteTableBit = 0x1 << 1
pteTypeMask = 0x3 << 0
present = pteValid | pteTableBit
user = 0x1 << 6 /* AP[1] */
readOnly = 0x1 << 7 /* AP[2] */
accessed = 0x1 << 10
dbm = 0x1 << 51
writable = dbm
cont = 0x1 << 52
pxn = 0x1 << 53
xn = 0x1 << 54
dirty = 0x1 << 55
nG = 0x1 << 11
shared = 0x3 << 8
)

const (
mtDevicenGnRE = 0x1 << 2
mtNormal = 0x4 << 2
typeTable = 0x3 << 0
typeSect = 0x1 << 0
typePage = 0x3 << 0
pteValid = 0x1 << 0
pteTableBit = 0x1 << 1
pteTypeMask = 0x3 << 0
present = pteValid | pteTableBit
attrIndxShift = 2
attrIndxMask = 0x7
user = 0x1 << 6 /* AP[1] */
readOnly = 0x1 << 7 /* AP[2] */
accessed = 0x1 << 10
dbm = 0x1 << 51
writable = dbm
cont = 0x1 << 52
pxn = 0x1 << 53
xn = 0x1 << 54
dirty = 0x1 << 55
nG = 0x1 << 11
shared = 0x3 << 8
)

const (
Expand All @@ -93,6 +90,9 @@ type MapOpts struct {

// User indicates the page is a user page.
User bool

// MemoryType is the memory type.
MemoryType hostarch.MemoryType
}

// PTE is a page table entry.
Expand All @@ -119,15 +119,15 @@ func (p *PTE) Valid() bool {
//go:nosplit
func (p *PTE) Opts() MapOpts {
v := atomic.LoadUintptr((*uintptr)(p))

return MapOpts{
AccessType: hostarch.AccessType{
Read: true,
Write: v&readOnly == 0,
Execute: v&xn == 0,
},
Global: v&nG == 0,
User: v&user != 0,
Global: v&nG == 0,
User: v&user != 0,
MemoryType: hostarch.MemoryType((v >> attrIndxShift) & attrIndxMask),
}
}

Expand Down Expand Up @@ -191,11 +191,12 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) {

if opts.User {
v |= user
v |= mtNormal
} else {
v = v &^ user
v |= mtNormal
}

v |= uintptr(opts.MemoryType&attrIndxMask) << attrIndxShift

atomic.StoreUintptr((*uintptr)(p), v)
}

Expand All @@ -209,7 +210,7 @@ func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) {
// This should never happen.
panic("unaligned physical address!")
}
v := addr | typeTable | protDefault | mtNormal
v := addr | typeTable | protDefault | (uintptr(hostarch.MemoryTypeWriteBack) << attrIndxShift)
atomic.StoreUintptr((*uintptr)(p), v)
}

Expand Down
10 changes: 10 additions & 0 deletions pkg/ring0/pagetables/pagetables_amd64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,13 @@ func TestSplit2MPage(t *testing.T) {
{0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read}},
})
}

func TestNumMemoryTypes(t *testing.T) {
// The PAT accommodates up to 8 entries. However, PTE.Set() currently
// assumes that NumMemoryTypes <= 4, since the location of the most
// significant bit of the PAT index in page table entries varies depending
// on page size (and is never bit 5 == writeThroughShift + 2).
if hostarch.NumMemoryTypes > 4 {
t.Errorf("PTE.Set() and PTE.Opts() must be altered to handle %d MemoryTypes", hostarch.NumMemoryTypes)
}
}
7 changes: 7 additions & 0 deletions pkg/ring0/pagetables/pagetables_arm64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ func TestSplit2MPage(t *testing.T) {
{0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}},
})
}

func TestNumMemoryTypes(t *testing.T) {
// MAIR accommodates up to 8 entries.
if hostarch.NumMemoryTypes > 8 {
t.Errorf("PTE.Set() and PTE.Opts() must be altered to map %d MemoryTypes to a smaller set of MAIR entries", hostarch.NumMemoryTypes)
}
}
32 changes: 19 additions & 13 deletions pkg/ring0/pagetables/pagetables_x86.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,17 @@ func (p *PageTables) CR3(noFlush bool, pcid uint16) uint64 {

// Bits in page table entries.
const (
present = 0x001
writable = 0x002
user = 0x004
writeThrough = 0x008
cacheDisable = 0x010
accessed = 0x020
dirty = 0x040
super = 0x080
global = 0x100
optionMask = executeDisable | 0xfff
present = 0x001
writable = 0x002
user = 0x004
accessed = 0x020
dirty = 0x040
super = 0x080
global = 0x100
optionMask = executeDisable | 0xfff

writeThroughShift = 3
patIndexMask = 0x3
)

// MapOpts are x86 options.
Expand All @@ -71,6 +72,9 @@ type MapOpts struct {

// User indicates the page is a user page.
User bool

// MemoryType is the memory type.
MemoryType hostarch.MemoryType
}

// PTE is a page table entry.
Expand Down Expand Up @@ -103,8 +107,9 @@ func (p *PTE) Opts() MapOpts {
Write: v&writable != 0,
Execute: v&executeDisable == 0,
},
Global: v&global != 0,
User: v&user != 0,
Global: v&global != 0,
User: v&user != 0,
MemoryType: hostarch.MemoryType((v >> writeThroughShift) & patIndexMask),
}
}

Expand Down Expand Up @@ -154,6 +159,7 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) {
if opts.AccessType.Write {
v |= writable | dirty
}
v |= uintptr(opts.MemoryType&patIndexMask) << writeThroughShift
if p.IsSuper() {
// Note that this is inherited from the previous instance. Set
// does not change the value of Super. See above.
Expand All @@ -172,7 +178,7 @@ func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) {
// This should never happen.
panic("unaligned physical address!")
}
v := addr | present | user | writable | accessed | dirty
v := addr | present | user | writable | accessed | dirty | (uintptr(hostarch.MemoryTypeWriteBack) << writeThroughShift)
atomic.StoreUintptr((*uintptr)(p), v)
}

Expand Down
21 changes: 18 additions & 3 deletions pkg/sentry/devices/nvproxy/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ type frontendDevice struct {
minor uint32
}

func (dev *frontendDevice) isCtlDevice() bool {
return dev.minor == nvgpu.NV_CONTROL_DEVICE_MINOR
}

func (dev *frontendDevice) basename() string {
if dev.minor == nvgpu.NV_CONTROL_DEVICE_MINOR {
if dev.isCtlDevice() {
return "nvidiactl"
}
return fmt.Sprintf("nvidia%d", dev.minor)
Expand Down Expand Up @@ -134,8 +138,9 @@ type frontendFD struct {
// These fields are marked nosave since we do not automatically reinvoke
// NV_ESC_RM_MAP_MEMORY after restore, so restored FDs have no
// mmap_context.
mmapLength uint64 `state:"nosave"`
mmapInternal uintptr `state:"nosave"`
mmapLength uint64 `state:"nosave"`
mmapInternal uintptr `state:"nosave"`
mmapMemType hostarch.MemoryType `state:"nosave"`

// clients are handles of clients owned by this frontendFD. clients is
// protected by dev.nvp.objsMu.
Expand Down Expand Up @@ -493,6 +498,7 @@ func rmAllocMemorySystem(fi *frontendIoctlState, ioctlParams *nvgpu.IoctlNVOS02P
fi.fd.dev.nvp.objAdd(fi.ctx, ioctlParams.Params.HRoot, ioctlParams.Params.HObjectNew, ioctlParams.Params.HClass, &miscObject{}, ioctlParams.Params.HObjectParent)
if createMmapCtx {
mapFile.mmapLength = ioctlParams.Params.Limit + 1
mapFile.mmapMemType = getMemoryType(fi.ctx, mapFile.dev, nvgpu.NVOS33_FLAGS_CACHING_TYPE_DEFAULT)
}
}
fi.fd.dev.nvp.objsUnlock()
Expand Down Expand Up @@ -1343,6 +1349,15 @@ func rmMapMemory(fi *frontendIoctlState) (uintptr, error) {
}
if ioctlParams.Params.Status == nvgpu.NV_OK {
mapFile.mmapLength = ioctlParams.Params.Length
// src/nvidia/arch/nvalloc/unix/src/escape.c:RmIoctl() forces
// NVOS33_FLAGS_CACHING_TYPE_DEFAULT, but resMap implementations may
// override the "caching type", so in general the memory type depends
// on the mapped object. Conveniently, when this occurs, the caching
// type in pParms->flags must be updated for the call to
// rm_create_mmap_context(), and pParms is subsequently copied back out
// by kernel-open/nvidia/nv.c:nvidia_ioctl(), so we can get the final
// caching type from the updated ioctl params.
mapFile.mmapMemType = getMemoryType(fi.ctx, mapFile.dev, (ioctlParams.Params.Flags>>nvgpu.NVOS33_FLAGS_CACHING_TYPE_SHIFT)&nvgpu.NVOS33_FLAGS_CACHING_TYPE_MASK)
}

ioctlParams.FD = origFD
Expand Down
Loading