kernel/device/pci/
driver.rs

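//! PCI device driver support.
//!
//! This module defines `PciDeviceId`, the ID pattern used to select PCI devices, and
//! `PciDeviceDriver`, a driver that matches PCI devices against a table of such patterns
//! and probes them.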
4
5extern crate alloc;
6
7use super::device::PciDeviceInfo;
8use crate::device::{DeviceDriver, DeviceInfo};
9use alloc::vec::Vec;
10
/// PCI device ID pattern used for driver matching.
///
/// The four 16-bit ID fields accept `PciDeviceId::ANY` (`0xFFFF`) as a wildcard.
/// The class code is compared only on the bits selected by `class_mask`, so a
/// mask of `0` means the class is not checked at all.
///
/// `Hash` is derived alongside `Eq` so that patterns can be used as keys in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PciDeviceId {
    /// Vendor ID (or 0xFFFF for any)
    pub vendor: u16,
    /// Device ID (or 0xFFFF for any)
    pub device: u16,
    /// Subvendor ID (or 0xFFFF for any)
    pub subvendor: u16,
    /// Subdevice ID (or 0xFFFF for any)
    pub subdevice: u16,
    /// Class code to compare against; only the bits set in `class_mask` take part
    /// in the comparison. `PciDeviceId::new` stores `ANY_CLASS` (0xFFFFFF) here.
    pub class: u32,
    /// Class mask (which bits of `class` to match); 0 disables class matching
    pub class_mask: u32,
}
27
28impl PciDeviceId {
29    /// Match any vendor and device
30    pub const ANY: u16 = 0xFFFF;
31
32    /// Match any class
33    pub const ANY_CLASS: u32 = 0xFFFFFF;
34
35    /// Create a new PCI device ID matcher
36    pub const fn new(vendor: u16, device: u16) -> Self {
37        Self {
38            vendor,
39            device,
40            subvendor: Self::ANY,
41            subdevice: Self::ANY,
42            class: Self::ANY_CLASS,
43            class_mask: 0,
44        }
45    }
46
47    /// Create a matcher for a specific class
48    pub const fn from_class(class: u32, mask: u32) -> Self {
49        Self {
50            vendor: Self::ANY,
51            device: Self::ANY,
52            subvendor: Self::ANY,
53            subdevice: Self::ANY,
54            class,
55            class_mask: mask,
56        }
57    }
58
59    /// Check if this ID matches a PCI device
60    pub fn matches(&self, device: &PciDeviceInfo) -> bool {
61        // Check vendor/device ID
62        if self.vendor != Self::ANY && self.vendor != device.vendor_id() {
63            return false;
64        }
65        if self.device != Self::ANY && self.device != device.device_id() {
66            return false;
67        }
68
69        // Check subsystem vendor/device ID
70        if self.subvendor != Self::ANY && self.subvendor != device.subsystem_vendor_id() {
71            return false;
72        }
73        if self.subdevice != Self::ANY && self.subdevice != device.subsystem_id() {
74            return false;
75        }
76
77        // Check class code with mask
78        if self.class_mask != 0 {
79            let device_class = device.class_code();
80            if (device_class & self.class_mask) != (self.class & self.class_mask) {
81                return false;
82            }
83        }
84
85        true
86    }
87}
88
89/// PCI device driver
90///
91/// Implements the DeviceDriver trait for PCI devices.
92pub struct PciDeviceDriver {
93    /// Driver name
94    name: &'static str,
95    /// List of PCI device IDs this driver supports
96    id_table: Vec<PciDeviceId>,
97    /// Probe function
98    probe_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
99    /// Remove function
100    remove_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
101}
102
103impl PciDeviceDriver {
104    /// Create a new PCI device driver
105    ///
106    /// # Arguments
107    ///
108    /// * `name` - Driver name
109    /// * `id_table` - List of supported PCI device IDs
110    /// * `probe_fn` - Function to probe devices
111    /// * `remove_fn` - Function to remove devices
112    pub fn new(
113        name: &'static str,
114        id_table: Vec<PciDeviceId>,
115        probe_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
116        remove_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
117    ) -> Self {
118        Self {
119            name,
120            id_table,
121            probe_fn,
122            remove_fn,
123        }
124    }
125
126    /// Get the device ID table
127    pub fn id_table(&self) -> &[PciDeviceId] {
128        &self.id_table
129    }
130
131    /// Check if this driver matches a PCI device
132    pub fn matches_device(&self, device: &PciDeviceInfo) -> bool {
133        self.id_table.iter().any(|id| id.matches(device))
134    }
135}
136
137impl DeviceDriver for PciDeviceDriver {
138    fn name(&self) -> &'static str {
139        self.name
140    }
141
142    fn match_table(&self) -> Vec<&'static str> {
143        // PCI drivers use device IDs (PciDeviceId) rather than string matching.
144        // The actual matching is done by the matches_device() method which checks
145        // vendor/device IDs and class codes. This returns an empty vector as
146        // string-based matching is not used for PCI devices.
147        Vec::new()
148    }
149
150    fn probe(&self, device: &dyn DeviceInfo) -> Result<(), &'static str> {
151        // Downcast to PciDeviceInfo
152        let pci_device = device
153            .as_any()
154            .downcast_ref::<PciDeviceInfo>()
155            .ok_or("Failed to downcast to PciDeviceInfo")?;
156
157        // Check if this driver matches the device
158        if !self.matches_device(pci_device) {
159            return Err("Device does not match driver");
160        }
161
162        // Call the probe function
163        (self.probe_fn)(pci_device)
164    }
165
166    fn remove(&self, device: &dyn DeviceInfo) -> Result<(), &'static str> {
167        // Downcast to PciDeviceInfo
168        let pci_device = device
169            .as_any()
170            .downcast_ref::<PciDeviceInfo>()
171            .ok_or("Failed to downcast to PciDeviceInfo")?;
172
173        // Call the remove function
174        (self.remove_fn)(pci_device)
175    }
176}
177
178#[cfg(test)]
179mod tests {
180    use super::*;
181    use crate::device::pci::PciAddress;
182
183    #[test_case]
184    fn test_pci_device_id_matching() {
185        let addr = PciAddress::new(0, 0, 1, 0);
186        let device = PciDeviceInfo::new(
187            addr,
188            0x8086, // Intel
189            0x1234,
190            0x020000, // Network controller
191            0x01,
192            0x0000,
193            0x0000,
194            0x0B,
195            0x01,
196            "pci_device",
197            1,
198        );
199
200        // Test exact match
201        let id = PciDeviceId::new(0x8086, 0x1234);
202        assert!(id.matches(&device));
203
204        // Test vendor mismatch
205        let id = PciDeviceId::new(0x1234, 0x1234);
206        assert!(!id.matches(&device));
207
208        // Test device mismatch
209        let id = PciDeviceId::new(0x8086, 0x5678);
210        assert!(!id.matches(&device));
211
212        // Test ANY vendor
213        let id = PciDeviceId::new(PciDeviceId::ANY, 0x1234);
214        assert!(id.matches(&device));
215
216        // Test class matching
217        let id = PciDeviceId::from_class(0x020000, 0xFF0000); // Match network class
218        assert!(id.matches(&device));
219
220        // Test class mismatch
221        let id = PciDeviceId::from_class(0x030000, 0xFF0000); // Display class
222        assert!(!id.matches(&device));
223    }
224
225    #[test_case]
226    fn test_pci_driver_matching() {
227        let addr = PciAddress::new(0, 0, 1, 0);
228        let device = PciDeviceInfo::new(
229            addr,
230            0x8086,
231            0x1234,
232            0x020000,
233            0x01,
234            0x0000,
235            0x0000,
236            0x0B,
237            0x01,
238            "pci_device",
239            1,
240        );
241
242        let id_table = alloc::vec![
243            PciDeviceId::new(0x8086, 0x1234),
244            PciDeviceId::new(0x8086, 0x5678),
245        ];
246
247        let driver =
248            PciDeviceDriver::new("test_driver", id_table, |_device| Ok(()), |_device| Ok(()));
249
250        assert!(driver.matches_device(&device));
251    }
252
253    #[test_case]
254    fn test_pci_driver_probe() {
255        let addr = PciAddress::new(0, 0, 1, 0);
256        let device = PciDeviceInfo::new(
257            addr,
258            0x8086,
259            0x1234,
260            0x020000,
261            0x01,
262            0x0000,
263            0x0000,
264            0x0B,
265            0x01,
266            "pci_device",
267            1,
268        );
269
270        static mut PROBE_CALLED: bool = false;
271
272        let id_table = alloc::vec![PciDeviceId::new(0x8086, 0x1234)];
273
274        let driver = PciDeviceDriver::new(
275            "test_driver",
276            id_table,
277            |dev| {
278                unsafe {
279                    PROBE_CALLED = true;
280                }
281                assert_eq!(dev.vendor_id(), 0x8086);
282                Ok(())
283            },
284            |_device| Ok(()),
285        );
286
287        let result = driver.probe(&device);
288        assert!(result.is_ok());
289        assert!(unsafe { PROBE_CALLED });
290    }
291
292    #[test_case]
293    fn test_virtio_pci_stub_driver_probe() {
294        // Test simulating virtio-pci devices (Red Hat vendor ID 0x1AF4)
295        // This verifies that PCI probe works with real-world device IDs
296
297        static mut VIRTIO_NET_PROBED: bool = false;
298        static mut VIRTIO_BLK_PROBED: bool = false;
299
300        // Create stub virtio-pci driver that supports common virtio devices
301        let id_table = alloc::vec![
302            PciDeviceId::new(0x1AF4, 0x1000), // VirtIO net (legacy)
303            PciDeviceId::new(0x1AF4, 0x1001), // VirtIO block (legacy)
304            PciDeviceId::new(0x1AF4, 0x1041), // VirtIO net (modern)
305            PciDeviceId::new(0x1AF4, 0x1042), // VirtIO block (modern)
306        ];
307
308        let driver = PciDeviceDriver::new(
309            "virtio-pci-stub",
310            id_table,
311            |device| {
312                // Stub probe function - just verify device info
313                match device.device_id() {
314                    0x1000 | 0x1041 => {
315                        // VirtIO net
316                        unsafe {
317                            VIRTIO_NET_PROBED = true;
318                        }
319                        assert_eq!(device.vendor_id(), 0x1AF4);
320                        assert_eq!(device.base_class(), 0x02); // Network
321                    }
322                    0x1001 | 0x1042 => {
323                        // VirtIO block
324                        unsafe {
325                            VIRTIO_BLK_PROBED = true;
326                        }
327                        assert_eq!(device.vendor_id(), 0x1AF4);
328                        assert_eq!(device.base_class(), 0x01); // Storage
329                    }
330                    _ => return Err("Unknown device"),
331                }
332                Ok(())
333            },
334            |_device| Ok(()),
335        );
336
337        // Test 1: VirtIO net device (legacy)
338        let addr = PciAddress::new(0, 0, 1, 0);
339        let virtio_net = PciDeviceInfo::new(
340            addr,
341            0x1AF4,   // Red Hat vendor
342            0x1000,   // VirtIO net (legacy)
343            0x020000, // Network controller
344            0x00,
345            0x1AF4,
346            0x0001,
347            0x0B,
348            0x01,
349            "virtio_pci_device",
350            1,
351        );
352
353        assert!(driver.matches_device(&virtio_net));
354        let result = driver.probe(&virtio_net);
355        assert!(result.is_ok());
356        assert!(unsafe { VIRTIO_NET_PROBED });
357
358        // Test 2: VirtIO block device (legacy)
359        let addr = PciAddress::new(0, 0, 2, 0);
360        let virtio_blk = PciDeviceInfo::new(
361            addr,
362            0x1AF4,   // Red Hat vendor
363            0x1001,   // VirtIO block (legacy)
364            0x010000, // Storage controller
365            0x00,
366            0x1AF4,
367            0x0002,
368            0x0B,
369            0x01,
370            "virtio_pci_device",
371            2,
372        );
373
374        assert!(driver.matches_device(&virtio_blk));
375        let result = driver.probe(&virtio_blk);
376        assert!(result.is_ok());
377        assert!(unsafe { VIRTIO_BLK_PROBED });
378
379        // Test 3: Non-matching device should not be probed
380        let addr = PciAddress::new(0, 0, 3, 0);
381        let intel_device = PciDeviceInfo::new(
382            addr,
383            0x8086, // Intel vendor
384            0x1234,
385            0x020000,
386            0x00,
387            0x0000,
388            0x0000,
389            0x0B,
390            0x01,
391            "intel_device",
392            3,
393        );
394
395        assert!(!driver.matches_device(&intel_device));
396    }
397
398    #[test_case]
399    fn test_virtio_pci_class_based_matching() {
400        // Test class-based matching for virtio devices
401        // This is useful when you want to match all devices of a certain class
402        // regardless of vendor
403
404        static mut MATCHED: bool = false;
405
406        // Create driver that matches all network controllers
407        let id_table = alloc::vec![
408            PciDeviceId::from_class(0x020000, 0xFF0000), // Network class
409        ];
410
411        let driver = PciDeviceDriver::new(
412            "network-stub",
413            id_table,
414            |device| {
415                unsafe {
416                    MATCHED = true;
417                }
418                assert_eq!(device.base_class(), 0x02);
419                Ok(())
420            },
421            |_device| Ok(()),
422        );
423
424        // Should match virtio-net
425        let addr = PciAddress::new(0, 0, 1, 0);
426        let virtio_net = PciDeviceInfo::new(
427            addr,
428            0x1AF4,
429            0x1000,
430            0x020000, // Network controller
431            0x00,
432            0x1AF4,
433            0x0001,
434            0x0B,
435            0x01,
436            "virtio_net",
437            1,
438        );
439
440        assert!(driver.matches_device(&virtio_net));
441        let result = driver.probe(&virtio_net);
442        assert!(result.is_ok());
443        assert!(unsafe { MATCHED });
444    }
445}