mirror of
https://github.com/moby/moby.git
synced 2026-06-24 08:48:23 +00:00
Use CDI for GPU injection for AMD devices for --gpus
Signed-off-by: Shiv Tyagi <Shiv.Tyagi@amd.com>
This commit is contained in:
@@ -268,6 +268,8 @@ func (cli *daemonCLI) start(ctx context.Context) (err error) {
|
||||
cdiCache = daemon.RegisterCDIDriver(cli.Config.CDISpecDirs...)
|
||||
}
|
||||
|
||||
daemon.RegisterGPUDeviceDrivers(cdiCache)
|
||||
|
||||
var apiServer apiserver.Server
|
||||
cli.authzMiddleware, err = initMiddlewares(ctx, &apiServer, cli.Config, pluginStore)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/containerd/log"
|
||||
"github.com/moby/moby/api/types/container"
|
||||
@@ -9,10 +10,13 @@ import (
|
||||
"github.com/moby/moby/v2/daemon/config"
|
||||
"github.com/moby/moby/v2/daemon/internal/capabilities"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
)
|
||||
|
||||
var deviceDrivers = map[string]*deviceDriver{}
|
||||
|
||||
// RegisterGPUDeviceDrivers registers the GPU device drivers, given the CDI
// cache (which may be nil). It is a variable so that platform-specific code can
// substitute the real implementation from an init function; the default
// declared here is a no-op.
var RegisterGPUDeviceDrivers = func(_ *cdi.Cache) {}
|
||||
|
||||
type deviceListing struct {
|
||||
Devices []system.DeviceInfo
|
||||
Warnings []string
|
||||
@@ -38,6 +42,18 @@ func registerDeviceDriver(name string, d *deviceDriver) {
|
||||
deviceDrivers[name] = d
|
||||
}
|
||||
|
||||
// getFirstAvailableVendor selects a GPU vendor from vendorList.
//
// Known vendors are considered in a fixed preference order ("nvidia.com", then
// "amd.com"); the first known vendor that appears anywhere in vendorList is
// returned, regardless of its position in vendorList. An error is returned when
// vendorList is empty or contains none of the known vendors.
func getFirstAvailableVendor(vendorList []string) (string, error) {
	present := make(map[string]struct{}, len(vendorList))
	for _, v := range vendorList {
		present[v] = struct{}{}
	}

	// Earlier entries win when more than one known vendor is available.
	for _, candidate := range [...]string{"nvidia.com", "amd.com"} {
		if _, ok := present[candidate]; ok {
			return candidate, nil
		}
	}

	return "", errors.New("no known GPU vendor found")
}
|
||||
|
||||
func (daemon *Daemon) handleDevice(req container.DeviceRequest, spec *specs.Spec) error {
|
||||
if req.Driver == "" {
|
||||
// If no driver is explicitly requested, we iterate over the registered
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/moby/moby/v2/daemon/internal/capabilities"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
)
|
||||
|
||||
// amdContainerRuntimeExecutableName is the name of the AMD helper binary that
// is looked up on PATH to decide whether the runtime-based AMD GPU updater can
// be used.
const amdContainerRuntimeExecutableName = "amd-container-runtime"
|
||||
|
||||
func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
|
||||
@@ -25,3 +34,46 @@ func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createAMDCDIUpdater(cdiCache *cdi.Cache) func(*specs.Spec, *deviceInstance) error {
|
||||
return func(s *specs.Spec, dev *deviceInstance) error {
|
||||
vendor, err := getFirstAvailableVendor(cdiCache.ListVendors())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to discover GPU vendor from CDI: %w", err)
|
||||
}
|
||||
|
||||
if vendor != "amd.com" {
|
||||
return errors.New("AMD CDI spec not found")
|
||||
}
|
||||
|
||||
injector := &cdiDeviceInjector{
|
||||
defaultCDIDeviceKind: "amd.com/gpu",
|
||||
}
|
||||
return injector.injectDevices(s, dev)
|
||||
}
|
||||
}
|
||||
|
||||
func getAMDDeviceDrivers(cdiCache *cdi.Cache) *deviceDriver {
|
||||
var composite firstSuccessfulUpdater
|
||||
|
||||
if cdiCache != nil {
|
||||
composite = append(composite, createAMDCDIUpdater(cdiCache))
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
|
||||
composite = append(composite, setAMDGPUs)
|
||||
}
|
||||
|
||||
if len(composite) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We do not support specifying driver with device requests for AMD GPUs.
|
||||
// Hence only use the composite updater and try cdi and runtime driver in sequence
|
||||
// based on availability.
|
||||
capset := capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}}
|
||||
return &deviceDriver{
|
||||
capset: capset,
|
||||
updateSpec: composite.updateSpec,
|
||||
}
|
||||
}
|
||||
|
||||
26
daemon/devices_linux.go
Normal file
26
daemon/devices_linux.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package daemon
|
||||
|
||||
import "tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
|
||||
func init() {
|
||||
RegisterGPUDeviceDrivers = registerGPUDeviceDrivers
|
||||
}
|
||||
|
||||
// registerGPUDeviceDrivers registers GPU device drivers.
|
||||
// If the cdiCache is provided, it is used to detect presence of CDI specs for AMD GPUs.
|
||||
// For NVIDIA GPUs, presence of CDI specs is detected by checking for the nvidia-cdi-hook binary.
|
||||
func registerGPUDeviceDrivers(cdiCache *cdi.Cache) {
|
||||
// Register NVIDIA device drivers.
|
||||
if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
|
||||
for name, driver := range nvidiaDrivers {
|
||||
registerDeviceDriver(name, driver)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Register AMD driver if AMD CDI spec or helper binary is present.
|
||||
if amdDriver := getAMDDeviceDrivers(cdiCache); amdDriver != nil {
|
||||
registerDeviceDriver("amd", amdDriver)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,6 @@ var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs
|
||||
const (
|
||||
nvidiaContainerRuntimeHookExecutableName = "nvidia-container-runtime-hook"
|
||||
nvidiaCDIHookExecutableName = "nvidia-cdi-hook"
|
||||
amdContainerRuntimeExecutableName = "amd-container-runtime"
|
||||
)
|
||||
|
||||
// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
|
||||
@@ -36,27 +35,6 @@ var allNvidiaCaps = map[string]struct{}{
|
||||
"display": {},
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register NVIDIA device drivers.
|
||||
if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
|
||||
for name, driver := range nvidiaDrivers {
|
||||
registerDeviceDriver(name, driver)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Register AMD driver if AMD helper binary is present.
|
||||
if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
|
||||
registerDeviceDriver("amd", &deviceDriver{
|
||||
capset: capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
|
||||
updateSpec: setAMDGPUs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// No "gpu" capability
|
||||
}
|
||||
|
||||
func getNVIDIADeviceDrivers() map[string]*deviceDriver {
|
||||
var composite firstSuccessfulUpdater
|
||||
nvidiaDrivers := make(map[string]*deviceDriver)
|
||||
|
||||
55
daemon/devices_test.go
Normal file
55
daemon/devices_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestGetFirstAvailableVendor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vendors []string
|
||||
expectVendor string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "NVIDIA vendor",
|
||||
vendors: []string{"nvidia.com"},
|
||||
expectVendor: "nvidia.com",
|
||||
},
|
||||
{
|
||||
name: "AMD vendor",
|
||||
vendors: []string{"amd.com"},
|
||||
expectVendor: "amd.com",
|
||||
},
|
||||
{
|
||||
name: "No vendors",
|
||||
vendors: nil,
|
||||
expectError: "no known GPU vendor found",
|
||||
},
|
||||
{
|
||||
name: "Unknown vendor",
|
||||
vendors: []string{"unknown.com"},
|
||||
expectError: "no known GPU vendor found",
|
||||
},
|
||||
{
|
||||
name: "Mixed vendor",
|
||||
vendors: []string{"amd.com", "nvidia.com"},
|
||||
expectVendor: "nvidia.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
vendor, err := getFirstAvailableVendor(tt.vendors)
|
||||
|
||||
if tt.expectError != "" {
|
||||
assert.Error(t, err, tt.expectError)
|
||||
} else {
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, tt.expectVendor, vendor)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user