1extern crate alloc;
6
7use super::device::PciDeviceInfo;
8use crate::device::{DeviceDriver, DeviceInfo};
9use alloc::vec::Vec;
10
/// One entry of a PCI driver's ID table.
///
/// Each 16-bit ID field may hold the wildcard [`PciDeviceId::ANY`] to match
/// any value; the class code is only compared under `class_mask`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PciDeviceId {
    /// Vendor ID to match, or [`PciDeviceId::ANY`].
    pub vendor: u16,
    /// Device ID to match, or [`PciDeviceId::ANY`].
    pub device: u16,
    /// Subsystem vendor ID to match, or [`PciDeviceId::ANY`].
    pub subvendor: u16,
    /// Subsystem device ID to match, or [`PciDeviceId::ANY`].
    pub subdevice: u16,
    /// Class code to compare against, restricted to the bits in `class_mask`.
    pub class: u32,
    /// Bits of the class code that must agree; `0` disables class matching.
    pub class_mask: u32,
}
27
28impl PciDeviceId {
29 pub const ANY: u16 = 0xFFFF;
31
32 pub const ANY_CLASS: u32 = 0xFFFFFF;
34
35 pub const fn new(vendor: u16, device: u16) -> Self {
37 Self {
38 vendor,
39 device,
40 subvendor: Self::ANY,
41 subdevice: Self::ANY,
42 class: Self::ANY_CLASS,
43 class_mask: 0,
44 }
45 }
46
47 pub const fn from_class(class: u32, mask: u32) -> Self {
49 Self {
50 vendor: Self::ANY,
51 device: Self::ANY,
52 subvendor: Self::ANY,
53 subdevice: Self::ANY,
54 class,
55 class_mask: mask,
56 }
57 }
58
59 pub fn matches(&self, device: &PciDeviceInfo) -> bool {
61 if self.vendor != Self::ANY && self.vendor != device.vendor_id() {
63 return false;
64 }
65 if self.device != Self::ANY && self.device != device.device_id() {
66 return false;
67 }
68
69 if self.subvendor != Self::ANY && self.subvendor != device.subsystem_vendor_id() {
71 return false;
72 }
73 if self.subdevice != Self::ANY && self.subdevice != device.subsystem_id() {
74 return false;
75 }
76
77 if self.class_mask != 0 {
79 let device_class = device.class_code();
80 if (device_class & self.class_mask) != (self.class & self.class_mask) {
81 return false;
82 }
83 }
84
85 true
86 }
87}
88
89pub struct PciDeviceDriver {
93 name: &'static str,
95 id_table: Vec<PciDeviceId>,
97 probe_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
99 remove_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
101}
102
103impl PciDeviceDriver {
104 pub fn new(
113 name: &'static str,
114 id_table: Vec<PciDeviceId>,
115 probe_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
116 remove_fn: fn(&PciDeviceInfo) -> Result<(), &'static str>,
117 ) -> Self {
118 Self {
119 name,
120 id_table,
121 probe_fn,
122 remove_fn,
123 }
124 }
125
126 pub fn id_table(&self) -> &[PciDeviceId] {
128 &self.id_table
129 }
130
131 pub fn matches_device(&self, device: &PciDeviceInfo) -> bool {
133 self.id_table.iter().any(|id| id.matches(device))
134 }
135}
136
137impl DeviceDriver for PciDeviceDriver {
138 fn name(&self) -> &'static str {
139 self.name
140 }
141
142 fn match_table(&self) -> Vec<&'static str> {
143 Vec::new()
148 }
149
150 fn probe(&self, device: &dyn DeviceInfo) -> Result<(), &'static str> {
151 let pci_device = device
153 .as_any()
154 .downcast_ref::<PciDeviceInfo>()
155 .ok_or("Failed to downcast to PciDeviceInfo")?;
156
157 if !self.matches_device(pci_device) {
159 return Err("Device does not match driver");
160 }
161
162 (self.probe_fn)(pci_device)
164 }
165
166 fn remove(&self, device: &dyn DeviceInfo) -> Result<(), &'static str> {
167 let pci_device = device
169 .as_any()
170 .downcast_ref::<PciDeviceInfo>()
171 .ok_or("Failed to downcast to PciDeviceInfo")?;
172
173 (self.remove_fn)(pci_device)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::pci::PciAddress;
    // Atomic flags let the fn-pointer callbacks record that they ran without
    // `static mut` and the `unsafe` blocks it would otherwise require.
    use core::sync::atomic::{AtomicBool, Ordering};

    /// Builds the Intel 0x8086:0x1234 test device (class 0x020000) at
    /// address 0:0:1.0, which several tests share.
    fn intel_test_device() -> PciDeviceInfo {
        PciDeviceInfo::new(
            PciAddress::new(0, 0, 1, 0),
            0x8086,
            0x1234,
            0x020000,
            0x01,
            0x0000,
            0x0000,
            0x0B,
            0x01,
            "pci_device",
            1,
        )
    }

    #[test_case]
    fn test_pci_device_id_matching() {
        let device = intel_test_device();

        // Exact vendor/device pair.
        assert!(PciDeviceId::new(0x8086, 0x1234).matches(&device));
        // Wrong vendor.
        assert!(!PciDeviceId::new(0x1234, 0x1234).matches(&device));
        // Wrong device.
        assert!(!PciDeviceId::new(0x8086, 0x5678).matches(&device));
        // Wildcard vendor.
        assert!(PciDeviceId::new(PciDeviceId::ANY, 0x1234).matches(&device));

        // Class-only matching on the base-class byte.
        assert!(PciDeviceId::from_class(0x020000, 0xFF0000).matches(&device));
        assert!(!PciDeviceId::from_class(0x030000, 0xFF0000).matches(&device));
    }

    #[test_case]
    fn test_pci_driver_matching() {
        let device = intel_test_device();

        let id_table = alloc::vec![
            PciDeviceId::new(0x8086, 0x1234),
            PciDeviceId::new(0x8086, 0x5678),
        ];

        let driver =
            PciDeviceDriver::new("test_driver", id_table, |_device| Ok(()), |_device| Ok(()));

        assert!(driver.matches_device(&device));
    }

    #[test_case]
    fn test_pci_driver_probe() {
        static PROBE_CALLED: AtomicBool = AtomicBool::new(false);

        let device = intel_test_device();
        let id_table = alloc::vec![PciDeviceId::new(0x8086, 0x1234)];

        let driver = PciDeviceDriver::new(
            "test_driver",
            id_table,
            |dev| {
                PROBE_CALLED.store(true, Ordering::SeqCst);
                assert_eq!(dev.vendor_id(), 0x8086);
                Ok(())
            },
            |_device| Ok(()),
        );

        let result = driver.probe(&device);
        assert!(result.is_ok());
        assert!(PROBE_CALLED.load(Ordering::SeqCst));
    }

    #[test_case]
    fn test_virtio_pci_stub_driver_probe() {
        static VIRTIO_NET_PROBED: AtomicBool = AtomicBool::new(false);
        static VIRTIO_BLK_PROBED: AtomicBool = AtomicBool::new(false);

        // Transitional (0x1000 / 0x1001) and modern (0x1041 / 0x1042)
        // virtio-net / virtio-blk device IDs under the virtio vendor 0x1AF4.
        let id_table = alloc::vec![
            PciDeviceId::new(0x1AF4, 0x1000),
            PciDeviceId::new(0x1AF4, 0x1001),
            PciDeviceId::new(0x1AF4, 0x1041),
            PciDeviceId::new(0x1AF4, 0x1042),
        ];

        let driver = PciDeviceDriver::new(
            "virtio-pci-stub",
            id_table,
            |device| {
                match device.device_id() {
                    0x1000 | 0x1041 => {
                        VIRTIO_NET_PROBED.store(true, Ordering::SeqCst);
                        assert_eq!(device.vendor_id(), 0x1AF4);
                        assert_eq!(device.base_class(), 0x02);
                    }
                    0x1001 | 0x1042 => {
                        VIRTIO_BLK_PROBED.store(true, Ordering::SeqCst);
                        assert_eq!(device.vendor_id(), 0x1AF4);
                        assert_eq!(device.base_class(), 0x01);
                    }
                    _ => return Err("Unknown device"),
                }
                Ok(())
            },
            |_device| Ok(()),
        );

        let virtio_net = PciDeviceInfo::new(
            PciAddress::new(0, 0, 1, 0),
            0x1AF4,
            0x1000,
            0x020000,
            0x00,
            0x1AF4,
            0x0001,
            0x0B,
            0x01,
            "virtio_pci_device",
            1,
        );

        assert!(driver.matches_device(&virtio_net));
        assert!(driver.probe(&virtio_net).is_ok());
        assert!(VIRTIO_NET_PROBED.load(Ordering::SeqCst));

        let virtio_blk = PciDeviceInfo::new(
            PciAddress::new(0, 0, 2, 0),
            0x1AF4,
            0x1001,
            0x010000,
            0x00,
            0x1AF4,
            0x0002,
            0x0B,
            0x01,
            "virtio_pci_device",
            2,
        );

        assert!(driver.matches_device(&virtio_blk));
        assert!(driver.probe(&virtio_blk).is_ok());
        assert!(VIRTIO_BLK_PROBED.load(Ordering::SeqCst));

        // A non-virtio device must not be claimed by the stub driver.
        let intel_device = PciDeviceInfo::new(
            PciAddress::new(0, 0, 3, 0),
            0x8086,
            0x1234,
            0x020000,
            0x00,
            0x0000,
            0x0000,
            0x0B,
            0x01,
            "intel_device",
            3,
        );

        assert!(!driver.matches_device(&intel_device));
    }

    #[test_case]
    fn test_virtio_pci_class_based_matching() {
        static MATCHED: AtomicBool = AtomicBool::new(false);

        // Match any device whose base class byte is 0x02.
        let id_table = alloc::vec![PciDeviceId::from_class(0x020000, 0xFF0000)];

        let driver = PciDeviceDriver::new(
            "network-stub",
            id_table,
            |device| {
                MATCHED.store(true, Ordering::SeqCst);
                assert_eq!(device.base_class(), 0x02);
                Ok(())
            },
            |_device| Ok(()),
        );

        let virtio_net = PciDeviceInfo::new(
            PciAddress::new(0, 0, 1, 0),
            0x1AF4,
            0x1000,
            0x020000,
            0x00,
            0x1AF4,
            0x0001,
            0x0B,
            0x01,
            "virtio_net",
            1,
        );

        assert!(driver.matches_device(&virtio_net));
        assert!(driver.probe(&virtio_net).is_ok());
        assert!(MATCHED.load(Ordering::SeqCst));
    }
}