Skip to content

Instantly share code, notes, and snippets.

@ShoGinn
Forked from guruevi/nvidia_allocator.py
Last active August 27, 2024 00:22
Show Gist options
  • Save ShoGinn/5c2834a87760b9a56e9b793b0b9f9fa7 to your computer and use it in GitHub Desktop.
Save ShoGinn/5c2834a87760b9a56e9b793b0b9f9fa7 to your computer and use it in GitHub Desktop.
Proxmox vGPU Hook Script
#!/usr/bin/env python3
import logging
import os
import re
import sys
from pathlib import Path
# Thanks to: https://gist.github.com/guruevi/7d9673c6f44f49b1841eaf49bbd727f9
# This script sets the vGPU type for NVIDIA vGPU passthrough in Proxmox VE.
# Proxmox VE runs it as a VM hook script (attach it with `qm set <vmid> --hookscript ...`).
# On pre-start it reads the vGPU type ID from the VM config file and writes it to the vGPU device.
# Run with the `get_command` phase, it instead finds an NVIDIA GPU device that supports the specified vGPU type.
# Root of the sysfs tree whose entries are listed as PCI devices.
PCI_DEVICES_PATH = Path("/sys/bus/pci/devices/")
# File inside a device's nvidia/ directory: read to see the active vGPU type, written to change it.
CURRENT_VGPU_FILE = "current_vgpu_type"
# File inside a device's nvidia/ directory whose lines are searched for the requested vGPU type.
CREATABLE_VGPU_FILE = "creatable_vgpu_types"
# Per-device subdirectory that must exist for a device to be considered.
NVIDIA_DIR = Path("nvidia")
# File name of this script, used in the usage messages.
SCRIPT_NAME = Path(__file__).name
# Minimum len(sys.argv) for the get_command phase (script, vmid, phase, vgpu_name).
MAX_ARGS = 4
# Minimum len(sys.argv) for any invocation (script, vmid, phase).
NECESSARY_ARGS = 3
def handle_error(message: str) -> None:
    """Log *message* at ERROR level, then terminate the process with exit status 1."""
    logging.error(message)
    raise SystemExit(1)
def info_message(message: str) -> None:
    """Emit *message* through the root logger at INFO level."""
    logging.log(logging.INFO, message)
def sort_pci_devices(device_ids: list[str]) -> list[str]:
    """Return *device_ids* ordered by their numeric PCI address.

    Each id is parsed as ``domain:bus:device.function`` with every field read
    as hexadecimal (for example ``0000:3b:00.4``), so ``0000:9:00.0`` sorts
    before ``0000:10:00.0`` even though a plain string sort would not.

    Raises:
        IndexError: an id has fewer than three ``:`` fields or no ``.``.
        ValueError: a field is not valid hexadecimal.
    """

    def numeric_address(address: str) -> tuple[int, ...]:
        colon_fields = address.split(":")
        slot_fields = colon_fields[2].split(".")
        hex_fields = (colon_fields[0], colon_fields[1], slot_fields[0], slot_fields[1])
        return tuple(int(field, 16) for field in hex_fields)

    ordered = list(device_ids)
    ordered.sort(key=numeric_address)
    return ordered
def get_available_gpu(vgpu_type: str) -> tuple[str, str]:
    """Find the first idle NVIDIA device able to create *vgpu_type*.

    Entries under ``PCI_DEVICES_PATH`` are visited in PCI-address order. An
    entry is a candidate when it has an ``nvidia`` directory and its current
    vGPU type file reads ``"0"``. The first line of its creatable-types file
    that contains *vgpu_type* as a substring selects the device.

    Returns:
        ``(device, vgpu_id)``: the PCI directory name and the text before the
        first ``" : "`` on the matching line, stripped of whitespace.

    Raises:
        ValueError: no candidate device lists *vgpu_type*.
    """
    for pci_address in sort_pci_devices(os.listdir(PCI_DEVICES_PATH)):
        pci_dir = PCI_DEVICES_PATH / pci_address
        vendor_dir = pci_dir / NVIDIA_DIR
        if not vendor_dir.is_dir():
            continue
        # Only devices whose current type reads "0" are considered free.
        if (vendor_dir / CURRENT_VGPU_FILE).read_text().strip() != "0":
            continue
        creatable = (vendor_dir / CREATABLE_VGPU_FILE).read_text()
        matching_line = next((row for row in creatable.splitlines() if vgpu_type in row), None)
        if matching_line is None:
            continue
        info_message(f"Found available: {pci_dir}")
        info_message(f"nVIDIA ID, type: {matching_line}")
        return pci_address, matching_line.split(" : ")[0].strip()
    raise ValueError("No available NVIDIA vGPU found, are virtual functions enabled? (systemctl start nvidia-sriov)")
def parse_vgpu_type_id(config: str) -> str:
    """Extract the digits of an ``nvidia-<digits>`` tag from the config text.

    The search looks for ``tags:`` followed, on the same line, by
    ``nvidia-<digits>``. When that line carries several such tags, the last
    one is used.

    Returns:
        The digits as a string, or ``""`` when nothing matches.
    """
    found = re.search(r"tags:(.*)nvidia-(\d+)", config)
    if found is None:
        return ""
    return found.group(2)
def parse_vgpu_bus_id(config: str) -> Path:
    """Return the sysfs device path named in the VM's ``args:`` line.

    The config text is searched for ``args:`` followed, on the same line, by
    ``-device vfio-pci,sysfsdev=<PCI_DEVICES_PATH>/<address>``.

    Returns:
        ``PCI_DEVICES_PATH / <address>``.

    Raises:
        FileNotFoundError: the config contains no such ``args:`` entry.
    """
    sysfs_root = re.escape(str(PCI_DEVICES_PATH))
    found = re.search(rf"args:(.*)-device vfio-pci,sysfsdev={sysfs_root}/([0-9a-fA-F:.]+)", config)
    if found is None:
        raise FileNotFoundError("vGPU bus ID not found in config file")
    return PCI_DEVICES_PATH / found.group(2)
def parse_vm_config(vmid: str, from_node: str | None) -> dict:
config_file = f"/etc/pve/qemu-server/{vmid}.conf"
if from_node:
config_file = f"/etc/pve/nodes/{from_node}/qemu-server/{vmid}.conf"
config = Path(config_file).read_text()
config_dict = {}
for line in config.splitlines():
stripped_line = line.strip() # remove leading and trailing whitespace
if not stripped_line: # skip empty lines
continue
if ":" not in stripped_line: # skip lines without a colon
continue
key, value = stripped_line.split(":", 1) # split on the first colon
config_dict[key.strip()] = value.strip()
return config_dict
def parse_line_config(config_line: str, item: str) -> str | None:
line_dict = {}
for line in config_line.split(","):
key, value = line.split("=")
line_dict[key] = value
return line_dict.get(item)
def write_current_vgpu_type(vgpu_path: Path, vgpu_type_id: str) -> None:
    """Write *vgpu_type_id* to the device's current-vGPU-type file if it differs.

    The target is ``vgpu_path / NVIDIA_DIR / CURRENT_VGPU_FILE``. The file is
    read first and left untouched when it already holds *vgpu_type_id*.

    Any failure while reading or writing is reported through ``handle_error``,
    which logs the message and exits the process with status 1.
    """
    target = vgpu_path / NVIDIA_DIR / CURRENT_VGPU_FILE
    info_message(f"Writing vGPU type to {target}")
    try:
        if target.read_text().strip() == vgpu_type_id:
            info_message(f"vGPU type already set to {vgpu_type_id}")
            return
        target.write_text(vgpu_type_id)
        info_message(f"vGPU type set to {vgpu_type_id}")
    except PermissionError:
        handle_error(f"Permission denied: Unable to write to {target}")
    except Exception as err:
        handle_error(f"An error occurred: {err!s}")
def generate_command(vmid: str, vgpu_name: str) -> None:
    """Log the ``qm set`` commands that attach a free vGPU to VM *vmid*.

    Three commands are logged at INFO level: one setting the hook script, one
    setting ``--args`` with the chosen sysfs device and the VM's smbios1 uuid,
    and one setting ``--tags`` with an ``nvidia-<id>`` tag added.

    Any ``nvidia-<digits>`` tag already on the VM is dropped before the new
    one is added, because ``parse_vgpu_type_id`` reads a single type ID from
    the tag line and two such tags would make the result ambiguous. Tags are
    emitted sorted so the output is reproducible.

    Args:
        vmid: VM identifier whose config is read.
        vgpu_name: text to look for in a device's creatable vGPU types.

    Raises:
        ValueError: no device offers *vgpu_name*, or the VM config has no
            smbios1 uuid to pass to ``-uuid``.
    """
    from_node = os.environ.get("PVE_MIGRATED_FROM")
    available_vgpu, gpu_id = get_available_gpu(vgpu_name)
    config_dict = parse_vm_config(vmid, from_node)

    smbios1 = config_dict.get("smbios1")
    uuid = parse_line_config(smbios1, "uuid") if smbios1 else None
    if not uuid:
        raise ValueError(f"No smbios1 uuid found in the config of VM {vmid}")

    # Built from PCI_DEVICES_PATH so the emitted path always matches what
    # parse_vgpu_bus_id later searches for.
    sysfs_device = PCI_DEVICES_PATH / available_vgpu

    existing_tags = (tag.strip() for tag in config_dict.get("tags", "").split(";"))
    tags = {tag for tag in existing_tags if tag and not re.fullmatch(r"nvidia-\d+", tag)}
    tags.add(f"nvidia-{gpu_id}")
    tag_list = ";".join(sorted(tags))

    info_message(f"qm set {vmid} --hookscript local:snippets/nvidia_allocator.py")
    info_message(f'qm set {vmid} --args "-device vfio-pci,sysfsdev={sysfs_device} -uuid {uuid}"')
    info_message(f'qm set {vmid} --tags "{tag_list}"')
def print_usage_and_exit() -> None:
    """Report the command-line synopsis through ``handle_error``, which exits."""
    usage_lines = (
        "",
        f"Usage:\t{SCRIPT_NAME} <vmid> <phase>",
        f"\t{SCRIPT_NAME} <vmid> get_command <vgpu_name>",
        "",
    )
    handle_error("\n".join(usage_lines))
def get_config(vmid: str) -> str:
    """Return the raw text of the config file for VM *vmid*.

    When the ``PVE_MIGRATED_FROM`` environment variable is set, that node's
    directory under ``/etc/pve/nodes/`` is tried first, mirroring how
    ``parse_vm_config`` picks its file. The local ``/etc/pve/qemu-server/``
    path is always tried as well, so behaviour is unchanged when the variable
    is absent.

    The file is read directly instead of being checked with ``exists()``
    first, which removes the window between the check and the open.

    Raises:
        ValueError: none of the candidate files exists.
    """
    file_name = f"{vmid}.conf"
    candidates = [Path("/etc/pve/qemu-server/") / file_name]
    from_node = os.environ.get("PVE_MIGRATED_FROM")
    if from_node:
        candidates.insert(0, Path("/etc/pve/nodes") / from_node / "qemu-server" / file_name)
    for config_path in candidates:
        try:
            return config_path.read_text()
        except (FileNotFoundError, NotADirectoryError):
            continue
    tried = ", ".join(str(path) for path in candidates)
    raise ValueError(f"Config file not found: {tried}")
def get_command(vmid: str) -> None:
    """Handle the ``get_command`` phase: log the setup commands for *vmid*.

    The vGPU name is taken from the fourth command-line argument.

    Raises:
        ValueError: fewer than ``MAX_ARGS`` command-line arguments were given.
    """
    if len(sys.argv) >= MAX_ARGS:
        generate_command(vmid, sys.argv[3])
        return
    raise ValueError(f"Usage: {SCRIPT_NAME} <vmid> get_command <vgpu_name>")
def get_vgpu_path(config: str) -> Path:
    """Return the vGPU device path named in *config*, exiting if there is none.

    A ``FileNotFoundError`` from ``parse_vgpu_bus_id`` is passed to
    ``handle_error``, which logs it and exits the process with status 1.
    """
    try:
        device_dir = parse_vgpu_bus_id(config)
    except FileNotFoundError as missing:
        handle_error(str(missing))
    else:
        return device_dir
def pre_start(vmid: str) -> None:
    """Handle the ``pre-start`` phase: set the vGPU type recorded in the VM config.

    The device path comes from the config's ``args:`` line and the type ID from
    its ``nvidia-<id>`` tag; the ID is then written to the device.

    Raises:
        ValueError: the config has no vGPU type tag, or the device path named
            in the config does not exist.
    """
    config = get_config(vmid)
    device_dir = get_vgpu_path(config)
    if not (type_id := parse_vgpu_type_id(config)):
        raise ValueError("vGPU type ID not found in config")
    if device_dir.exists():
        write_current_vgpu_type(device_dir, type_id)
        return
    raise ValueError(
        f"Specified vGPU not found: rerun the nvidia_allocator get_command or check the drivers: {device_dir}"
    )
def post_stop(vmid: str) -> None:
    """Handle the ``post-stop`` phase: write vGPU type ``"0"`` to the VM's device."""
    write_current_vgpu_type(get_vgpu_path(get_config(vmid)), "0")
def post_start() -> None:
return
def pre_stop() -> None:
return
def main() -> None:
    """Dispatch ``<vmid> <phase>`` from the command line to its phase handler.

    Exits with status 1 (through ``handle_error``) when too few arguments are
    given, when the phase is unknown, or when the handler raises. A non-zero
    status is what lets Proxmox abort a VM start whose ``pre-start`` hook
    failed; previously handler errors were only logged and the script still
    exited 0.
    """
    if len(sys.argv) < NECESSARY_ARGS:
        print_usage_and_exit()
    vmid, phase = sys.argv[1], sys.argv[2]
    phases = {
        "get_command": get_command,
        "pre-start": pre_start,
        # These two handlers do not need the VM id; wrapping them keeps the
        # uniform handler(vmid) call below from raising TypeError.
        "post-start": lambda _vmid: post_start(),
        "pre-stop": lambda _vmid: pre_stop(),
        "post-stop": post_stop,
    }
    handler = phases.get(phase)
    if handler is None:
        info_message(f"Invalid phase: {phase}")
        print_usage_and_exit()
        return
    try:
        handler(vmid)
    except Exception as e:
        # Top-level boundary: report the failure and signal it to the caller.
        handle_error(str(e))
if __name__ == "__main__":
    # Log bare messages (no level or timestamp prefix) at INFO and above.
    logging.basicConfig(format="%(message)s", level=logging.INFO)
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment