From 96cd1f4fb9b8e87f66d8d565e0407b1a4db53f06 Mon Sep 17 00:00:00 2001 From: Maria Kuklina <101095419+kmd-fl@users.noreply.github.com> Date: Tue, 10 Sep 2024 10:23:59 +0200 Subject: [PATCH] fix(vm): allow only nvidia and amd gpus (#2362) --- crates/gpu-utils/src/lib.rs | 14 +++++++++++++- crates/vm-utils/src/vm_utils.rs | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/gpu-utils/src/lib.rs b/crates/gpu-utils/src/lib.rs index b0eea235af..50ae728cd0 100644 --- a/crates/gpu-utils/src/lib.rs +++ b/crates/gpu-utils/src/lib.rs @@ -18,6 +18,9 @@ pub enum PciError { UnsupportedProperty, } +const AMD_VENDOR_ID: u16 = 0x1002; +const NVIDIA_VENDOR_ID: u16 = 0x10de; + pub fn get_gpu_pci() -> Result, PciError> { let info = PciInfo::enumerate_pci()?; // List of GPU devices @@ -29,12 +32,14 @@ pub fn get_gpu_pci() -> Result, PciError> { let device = device?; let device_class = process_property_result(device.device_class())?; let device_location = process_property_result(device.location())?; - if device_class == DisplayController { + if device_class == DisplayController && is_vendor_allowed(device.vendor_id()) { gpu_devices.insert(device_location); } pci_devices.insert(device_location, device_class); } + tracing::info!(target: "gpu-utils", "Found GPU devices: {:?}", gpu_devices); + let result = match get_iommu_groups() { Ok(iommu_groups) => { // Find all devices that are in the same IOMMU group as the GPU devices @@ -54,9 +59,16 @@ pub fn get_gpu_pci() -> Result, PciError> { gpu_devices } }; + + tracing::info!(target: "gpu-utils", "Importing PCI devices: {:?}", result); + Ok(result) } +fn is_vendor_allowed(vendor_id: u16) -> bool { + vendor_id == AMD_VENDOR_ID || vendor_id == NVIDIA_VENDOR_ID +} + // AFAIK the bridge devices are the only non-endpoint devices // May require to update this function if there are other non-endpoint devices fn is_endpoint_device(device_class: &PciDeviceClass) -> bool { diff --git a/crates/vm-utils/src/vm_utils.rs b/crates/vm-utils/src/vm_utils.rs index 6301c5d01e..c67f6af1f7 100644 --- a/crates/vm-utils/src/vm_utils.rs +++ b/crates/vm-utils/src/vm_utils.rs @@ -167,6 +167,7 @@ pub fn create_domain(uri: &str, params: &CreateVMDomainParams) -> Result<(), VmE tracing::info!(target: "vm-utils","Domain with name {} doesn't exists. Creating", params.name); // There's certainly better places to do this, but RN it doesn't really matter let gpu_pci_locations = if params.allow_gpu { + tracing::info!(target: "gpu-utils", "Collecting info about GPU devices..."); gpu_utils::get_gpu_pci()?.into_iter().collect::>() } else { vec![]