contrib(nvidia): match right apt repo based on os release

Signed-off-by: CrazyMax <1951866+crazy-max@users.noreply.github.com>
Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
This commit is contained in:
CrazyMax
2025-10-23 09:32:33 +02:00
committed by Tonis Tiigi
parent 5ee9c96bad
commit f890b2e451

View File

@@ -90,8 +90,10 @@ func (s *setup) Run(ctx context.Context) (err error) {
closeProgress(err)
}()
isDistro, _ := isDebianOrUbuntu()
if !isDistro {
osr, err := getOSRelease()
if err != nil {
return err
} else if osr.ID != "debian" && osr.ID != "ubuntu" {
return errors.Errorf("NVIDIA setup is currently only supported on Debian/Ubuntu")
}
@@ -131,7 +133,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
return err
}
if err := installPackages(ctx, dv, pw, dgst); err != nil {
if err := installPackages(ctx, osr, dv, pw, dgst); err != nil {
return err
}
@@ -167,8 +169,20 @@ func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Dig
return cmd.Run()
}
func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
const aptDistro = "ubuntu2404"
func installPackages(ctx context.Context, osr *osrelease, dv string, pw progress.Writer, dgst digest.Digest) error {
aptDistro := "ubuntu2404"
switch osr.ID {
case "debian":
if osr.VersionID == "" {
aptDistro = "debian12"
} else {
aptDistro = "debian" + osr.VersionID
}
case "ubuntu":
if osr.VersionID != "" {
aptDistro = "ubuntu" + strings.ReplaceAll(osr.VersionID, ".", "")
}
}
var arch string
switch runtime.GOARCH {
@@ -274,35 +288,38 @@ func hasNvidiaDevices() (bool, error) {
return found, nil
}
func getOSID() (string, error) {
type osrelease struct {
ID string
VersionID string
}
func getOSRelease() (*osrelease, error) {
file, err := os.Open("/etc/os-release")
if err != nil {
return "", err
return nil, err
}
defer file.Close()
var id, versionID string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if id, ok := strings.CutPrefix(line, "ID="); ok {
return strings.Trim(id, `"`), nil // Remove potential quotes
if v, ok := strings.CutPrefix(line, "ID="); ok {
id = strings.Trim(v, `"`)
} else if v, ok := strings.CutPrefix(line, "VERSION_ID="); ok {
versionID = strings.Trim(v, `"`)
}
}
if err := scanner.Err(); err != nil {
return "", err
return nil, err
}
return "", errors.Errorf("ID not found in /etc/os-release")
}
func isDebianOrUbuntu() (bool, error) {
id, err := getOSID()
if err != nil {
return false, err
if id == "" {
return nil, errors.Errorf("ID not found in /etc/os-release")
}
return id == "debian" || id == "ubuntu", nil
return &osrelease{ID: id, VersionID: versionID}, nil
}
func hasWSLGPU() bool {