Use CDI for GPU injection on AMD devices when --gpus is used

Signed-off-by: Shiv Tyagi <Shiv.Tyagi@amd.com>
This commit is contained in:
Shiv Tyagi
2026-03-04 13:07:26 +00:00
parent 6bc6209b88
commit 561a5a9b36
6 changed files with 151 additions and 22 deletions

View File

@@ -268,6 +268,8 @@ func (cli *daemonCLI) start(ctx context.Context) (err error) {
cdiCache = daemon.RegisterCDIDriver(cli.Config.CDISpecDirs...)
}
daemon.RegisterGPUDeviceDrivers(cdiCache)
var apiServer apiserver.Server
cli.authzMiddleware, err = initMiddlewares(ctx, &apiServer, cli.Config, pluginStore)
if err != nil {

View File

@@ -2,6 +2,7 @@ package daemon
import (
"context"
"errors"
"github.com/containerd/log"
"github.com/moby/moby/api/types/container"
@@ -9,10 +10,13 @@ import (
"github.com/moby/moby/v2/daemon/config"
"github.com/moby/moby/v2/daemon/internal/capabilities"
"github.com/opencontainers/runtime-spec/specs-go"
"tags.cncf.io/container-device-interface/pkg/cdi"
)
// deviceDrivers holds the registered device drivers, keyed by the name they
// were registered under (see registerDeviceDriver).
var deviceDrivers = map[string]*deviceDriver{}

// RegisterGPUDeviceDrivers registers the GPU device drivers. It accepts the
// CDI cache, which implementations may use to detect CDI specs. The default
// value is a no-op that ignores its argument; devices_linux.go replaces it
// with the real implementation from its init function.
var RegisterGPUDeviceDrivers = func(_ *cdi.Cache) {}
type deviceListing struct {
Devices []system.DeviceInfo
Warnings []string
@@ -38,6 +42,18 @@ func registerDeviceDriver(name string, d *deviceDriver) {
deviceDrivers[name] = d
}
// getFirstAvailableVendor returns the highest-priority known GPU vendor that
// is present in vendorList.
//
// Known vendors are probed in a fixed priority order ("nvidia.com" first, then
// "amd.com"), so the ordering of vendorList itself never influences the
// result. Unknown entries in vendorList are ignored. When vendorList contains
// none of the known vendors (including when it is nil or empty), an empty
// string and an error are returned.
func getFirstAvailableVendor(vendorList []string) (string, error) {
	// Index the input once so that each known vendor is a single lookup.
	present := make(map[string]struct{}, len(vendorList))
	for _, v := range vendorList {
		present[v] = struct{}{}
	}

	for _, candidate := range [...]string{"nvidia.com", "amd.com"} {
		if _, ok := present[candidate]; ok {
			return candidate, nil
		}
	}
	return "", errors.New("no known GPU vendor found")
}
func (daemon *Daemon) handleDevice(req container.DeviceRequest, spec *specs.Spec) error {
if req.Driver == "" {
// If no driver is explicitly requested, we iterate over the registered

View File

@@ -1,9 +1,18 @@
package daemon
import (
"errors"
"fmt"
"os/exec"
"strings"
"github.com/moby/moby/v2/daemon/internal/capabilities"
"github.com/opencontainers/runtime-spec/specs-go"
"tags.cncf.io/container-device-interface/pkg/cdi"
)
const (
// amdContainerRuntimeExecutableName is the executable that getAMDDeviceDrivers
// looks up on PATH to decide whether the setAMDGPUs updater can be offered.
amdContainerRuntimeExecutableName = "amd-container-runtime"
)
func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
@@ -25,3 +34,46 @@ func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
return nil
}
// createAMDCDIUpdater returns a spec updater that injects AMD GPUs via CDI.
//
// Each time the returned updater runs, it lists the vendors known to cdiCache
// and selects one with getFirstAvailableVendor. It proceeds only when the
// selected vendor is "amd.com", in which case the requested devices are
// injected using "amd.com/gpu" as the default CDI device kind. In every other
// case an error is returned and the spec is not touched by this function.
func createAMDCDIUpdater(cdiCache *cdi.Cache) func(*specs.Spec, *deviceInstance) error {
	updater := func(s *specs.Spec, dev *deviceInstance) error {
		// The vendor list is read on every call rather than once up front.
		vendor, err := getFirstAvailableVendor(cdiCache.ListVendors())
		switch {
		case err != nil:
			return fmt.Errorf("failed to discover GPU vendor from CDI: %w", err)
		case vendor != "amd.com":
			return errors.New("AMD CDI spec not found")
		}

		amdInjector := &cdiDeviceInjector{defaultCDIDeviceKind: "amd.com/gpu"}
		return amdInjector.injectDevices(s, dev)
	}
	return updater
}
// getAMDDeviceDrivers builds the device driver used for AMD GPUs, or returns
// nil when no AMD injection mechanism is available.
//
// Up to two updaters are collected, in this order:
//  1. the CDI updater, when cdiCache is non-nil;
//  2. setAMDGPUs, when the amd-container-runtime executable is found on PATH.
//
// Specifying a driver in a device request is not supported for AMD GPUs, so
// only a single composite driver is returned; it tries the CDI and runtime
// updaters in sequence, based on availability. The driver advertises the
// "gpu" and "amd" capabilities.
func getAMDDeviceDrivers(cdiCache *cdi.Cache) *deviceDriver {
	var updaters firstSuccessfulUpdater

	if cdiCache != nil {
		updaters = append(updaters, createAMDCDIUpdater(cdiCache))
	}

	_, lookupErr := exec.LookPath(amdContainerRuntimeExecutableName)
	if lookupErr == nil {
		updaters = append(updaters, setAMDGPUs)
	}

	// Neither mechanism is available: there is nothing to register.
	if len(updaters) == 0 {
		return nil
	}

	return &deviceDriver{
		capset:     capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
		updateSpec: updaters.updateSpec,
	}
}

26
daemon/devices_linux.go Normal file
View File

@@ -0,0 +1,26 @@
package daemon
import "tags.cncf.io/container-device-interface/pkg/cdi"
// init installs registerGPUDeviceDrivers as the implementation behind the
// package-level RegisterGPUDeviceDrivers hook, replacing its default value.
func init() {
RegisterGPUDeviceDrivers = registerGPUDeviceDrivers
}
// registerGPUDeviceDrivers registers the GPU device drivers for at most one
// vendor, preferring NVIDIA over AMD.
//
// If getNVIDIADeviceDrivers reports any drivers, all of them are registered
// under their own names and AMD is not considered. NVIDIA detection does not
// use cdiCache (presumably it relies on the NVIDIA helper binaries such as
// nvidia-cdi-hook — confirm in getNVIDIADeviceDrivers).
//
// Otherwise the AMD driver is registered under the name "amd" when
// getAMDDeviceDrivers returns one, i.e. when an AMD CDI updater or the AMD
// helper binary is available. cdiCache may be nil.
func registerGPUDeviceDrivers(cdiCache *cdi.Cache) {
	nvidiaDrivers := getNVIDIADeviceDrivers()
	if len(nvidiaDrivers) == 0 {
		// No NVIDIA drivers: fall back to AMD, if present.
		if amdDriver := getAMDDeviceDrivers(cdiCache); amdDriver != nil {
			registerDeviceDriver("amd", amdDriver)
		}
		return
	}

	for name, driver := range nvidiaDrivers {
		registerDeviceDriver(name, driver)
	}
}

View File

@@ -23,7 +23,6 @@ var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs
// Executable names of the vendor-specific GPU helper binaries.
const (
// NOTE(review): the two NVIDIA names are presumably looked up on PATH by
// getNVIDIADeviceDrivers, whose body is not visible here — confirm.
nvidiaContainerRuntimeHookExecutableName = "nvidia-container-runtime-hook"
nvidiaCDIHookExecutableName = "nvidia-cdi-hook"
// amdContainerRuntimeExecutableName is looked up on PATH to decide whether
// the AMD driver can be registered.
amdContainerRuntimeExecutableName = "amd-container-runtime"
)
// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
@@ -36,27 +35,6 @@ var allNvidiaCaps = map[string]struct{}{
"display": {},
}
// init registers GPU device drivers when the package is loaded.
//
// NVIDIA drivers take precedence: when getNVIDIADeviceDrivers returns any,
// they are all registered and nothing else is done. Otherwise the "amd"
// driver, backed by setAMDGPUs, is registered if the AMD helper binary is
// found on PATH. When neither applies, no driver is registered.
func init() {
	nvidiaDrivers := getNVIDIADeviceDrivers()
	if len(nvidiaDrivers) > 0 {
		// Register NVIDIA device drivers.
		for name, driver := range nvidiaDrivers {
			registerDeviceDriver(name, driver)
		}
		return
	}

	if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err != nil {
		// No "gpu" capability
		return
	}

	// Register AMD driver: the AMD helper binary is present.
	amdCaps := capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}}
	registerDeviceDriver("amd", &deviceDriver{
		capset:     amdCaps,
		updateSpec: setAMDGPUs,
	})
}
func getNVIDIADeviceDrivers() map[string]*deviceDriver {
var composite firstSuccessfulUpdater
nvidiaDrivers := make(map[string]*deviceDriver)

55
daemon/devices_test.go Normal file
View File

@@ -0,0 +1,55 @@
package daemon
import (
"testing"
"gotest.tools/v3/assert"
)
// TestGetFirstAvailableVendor is a table-driven test for
// getFirstAvailableVendor. It covers a single known vendor, an empty and an
// unknown vendor list (both must fail), and a list containing both known
// vendors, where "nvidia.com" must win regardless of input order.
func TestGetFirstAvailableVendor(t *testing.T) {
	testCases := []struct {
		name         string
		vendors      []string
		expectVendor string
		expectError  string
	}{
		{name: "NVIDIA vendor", vendors: []string{"nvidia.com"}, expectVendor: "nvidia.com"},
		{name: "AMD vendor", vendors: []string{"amd.com"}, expectVendor: "amd.com"},
		{name: "No vendors", vendors: nil, expectError: "no known GPU vendor found"},
		{name: "Unknown vendor", vendors: []string{"unknown.com"}, expectError: "no known GPU vendor found"},
		{name: "Mixed vendor", vendors: []string{"amd.com", "nvidia.com"}, expectVendor: "nvidia.com"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			vendor, err := getFirstAvailableVendor(tc.vendors)

			// Success cases: no error and the expected vendor.
			if tc.expectError == "" {
				assert.NilError(t, err)
				assert.Equal(t, tc.expectVendor, vendor)
				return
			}

			// Failure cases: only the error message is checked.
			assert.Error(t, err, tc.expectError)
		})
	}
}