//! VirtIO RNG (entropy) device driver — kernel/drivers/virtio_rng.rs
use alloc::{boxed::Box, collections::VecDeque, vec};
22use spin::{Mutex, RwLock};
23
24use crate::drivers::virtio::{
25 device::VirtioDevice,
26 queue::{DescriptorFlag, VirtQueue},
27};
28use crate::random::EntropySource;
29
/// Number of random bytes requested from the device per `fill_buffer()` call,
/// and the initial capacity of the internal entropy buffer.
const RNG_BUFFER_SIZE: usize = 256;

/// VirtIO entropy (RNG) device driver.
///
/// Pulls batches of random bytes from the device's single request virtqueue,
/// buffers them, and serves them to the kernel through the `EntropySource`
/// trait.
pub struct VirtioRngDevice {
    /// MMIO base address of the device's registers.
    base_addr: usize,
    /// The single request virtqueue (virtio-rng defines exactly one).
    virtqueues: Mutex<[VirtQueue<'static>; 1]>,
    /// Random bytes received from the device but not yet consumed.
    buffer: Mutex<VecDeque<u8>>,
    /// Feature bits negotiated during device initialization.
    features: RwLock<u32>,
    /// Whether the VirtIO init handshake completed successfully.
    initialized: RwLock<bool>,
}
49
impl VirtioRngDevice {
    /// Create and initialize a VirtIO RNG device at the given MMIO base
    /// address.
    ///
    /// Runs the VirtIO init handshake (`init()`, presumably provided by the
    /// `VirtioDevice` trait — confirm). On failure the device is still
    /// constructed, but marked uninitialized and with zero features, so
    /// `is_available()` reports `false`.
    pub fn new(base_addr: usize) -> Self {
        let mut device = Self {
            base_addr,
            // Queue of 8 descriptors; buffer pre-sized to one refill batch.
            virtqueues: Mutex::new([VirtQueue::new(8)]), buffer: Mutex::new(VecDeque::with_capacity(RNG_BUFFER_SIZE)),
            features: RwLock::new(0),
            initialized: RwLock::new(false),
        };

        let negotiated_features = match device.init() {
            Ok(features) => {
                *device.initialized.write() = true;
                features
            }
            Err(e) => {
                crate::early_println!("[VirtIO RNG] Failed to initialize: {}", e);
                // Fall through with no features; the device stays unavailable.
                0
            }
        };

        *device.features.write() = negotiated_features;

        // NOTE(review): this "initialized" message is printed even when init()
        // failed (with features 0x0) — consider gating it on `initialized`.
        crate::early_println!(
            "[VirtIO RNG] Device initialized with features: 0x{:x}",
            negotiated_features
        );

        device
    }

    /// Request one batch of random bytes from the device and append the
    /// received bytes to the internal buffer.
    ///
    /// Spin-waits for the device to complete the request while holding the
    /// virtqueue lock. Returns the number of bytes reported received, or an
    /// error string if the request could not be submitted or completed.
    fn fill_buffer(&self) -> Result<usize, &'static str> {
        let mut virtqueues = self.virtqueues.lock();
        let queue = &mut virtqueues[0];

        // Heap-allocated DMA target the device writes random bytes into.
        // It outlives the request: we wait for completion before returning.
        let mut data_buffer: Box<[u8]> = vec![0u8; RNG_BUFFER_SIZE].into_boxed_slice();
        let data_ptr = data_buffer.as_mut_ptr();

        // The device needs a physical address, not a kernel virtual one.
        let data_phys = crate::vm::get_kernel_vm_manager()
            .translate_vaddr(data_ptr as usize)
            .ok_or("Failed to translate data vaddr")?;

        let desc_idx = queue.alloc_desc().ok_or("No available descriptors")?;

        // Single device-writable descriptor covering the whole buffer.
        queue.desc[desc_idx].addr = data_phys as u64;
        queue.desc[desc_idx].len = RNG_BUFFER_SIZE as u32;
        queue.desc[desc_idx].flags = DescriptorFlag::Write as u16;
        queue.desc[desc_idx].next = 0;

        if let Err(e) = queue.push(desc_idx) {
            // Don't leak the descriptor if the avail ring rejected it.
            queue.free_desc(desc_idx);
            return Err(e);
        }

        self.notify(0);

        // Busy-wait for completion. Acceptable during early boot, but note
        // this holds the virtqueue lock for the entire wait.
        while queue.is_busy() {
            core::hint::spin_loop();
        }

        let completed_desc = match queue.pop() {
            Some(idx) => idx,
            None => {
                queue.free_desc(desc_idx);
                return Err("No response from device");
            }
        };

        // With a single in-flight request the completed descriptor must be
        // the one we submitted.
        if completed_desc != desc_idx {
            queue.free_desc(desc_idx);
            return Err("Invalid descriptor index");
        }

        // NOTE(review): this reads the descriptor's `len` field, which is the
        // length we programmed above; per the VirtIO spec the bytes actually
        // written are reported in the used-ring element. Confirm that
        // `queue.pop()` writes that length back into `desc[].len`, otherwise
        // this always reports RNG_BUFFER_SIZE regardless of actual output.
        let bytes_received = queue.desc[desc_idx].len as usize;
        let mut buffer = self.buffer.lock();
        for i in 0..bytes_received.min(RNG_BUFFER_SIZE) {
            buffer.push_back(data_buffer[i]);
        }

        queue.free_desc(desc_idx);

        Ok(bytes_received)
    }

    /// Pop one byte of entropy from the buffer, refilling from the device
    /// when the buffer is empty. Returns `None` if the refill fails or
    /// (rarely) if another consumer drained the refilled bytes first.
    fn read_byte_internal(&self) -> Option<u8> {
        let mut buffer = self.buffer.lock();

        if buffer.is_empty() {
            // Release the buffer lock before fill_buffer(), which re-acquires
            // it internally — holding it here would deadlock.
            drop(buffer); if let Err(e) = self.fill_buffer() {
                crate::early_println!("[VirtIO RNG] Failed to fill buffer: {}", e);
                return None;
            }
            buffer = self.buffer.lock();
        }

        // Re-check: between the drop and the re-lock another consumer may
        // have drained the bytes we just fetched.
        if !buffer.is_empty() {
            buffer.pop_front()
        } else {
            None
        }
    }
}
192
193impl EntropySource for VirtioRngDevice {
194 fn name(&self) -> &'static str {
195 "virtio-rng"
196 }
197
198 fn read_entropy(&self, buffer: &mut [u8]) -> usize {
199 let mut bytes_read = 0;
200
201 for i in 0..buffer.len() {
202 if let Some(byte) = self.read_byte_internal() {
203 buffer[i] = byte;
204 bytes_read += 1;
205 } else {
206 break;
207 }
208 }
209
210 bytes_read
211 }
212
213 fn is_available(&self) -> bool {
214 *self.initialized.read()
216 }
217}
218
219impl VirtioDevice for VirtioRngDevice {
220 fn get_base_addr(&self) -> usize {
221 self.base_addr
222 }
223
224 fn get_virtqueue_count(&self) -> usize {
225 1 }
227
228 fn get_virtqueue_size(&self, queue_idx: usize) -> usize {
229 if queue_idx >= 1 {
230 panic!("Invalid queue index for VirtIO RNG device: {}", queue_idx);
231 }
232 let virtqueues = self.virtqueues.lock();
233 virtqueues[queue_idx].get_queue_size()
234 }
235
236 fn get_queue_desc_addr(&self, queue_idx: usize) -> Option<u64> {
237 if queue_idx >= self.get_virtqueue_count() {
238 return None;
239 }
240 let virtqueues = self.virtqueues.lock();
241 Some(virtqueues[queue_idx].get_raw_ptr() as u64)
242 }
243
244 fn get_queue_driver_addr(&self, queue_idx: usize) -> Option<u64> {
245 if queue_idx >= self.get_virtqueue_count() {
246 return None;
247 }
248 let virtqueues = self.virtqueues.lock();
249 Some(virtqueues[queue_idx].avail.flags as *const _ as u64)
250 }
251
252 fn get_queue_device_addr(&self, queue_idx: usize) -> Option<u64> {
253 if queue_idx >= self.get_virtqueue_count() {
254 return None;
255 }
256 let virtqueues = self.virtqueues.lock();
257 Some(virtqueues[queue_idx].used.flags as *const _ as u64)
258 }
259
260 fn get_supported_features(&self, _device_features: u32) -> u32 {
261 0
264 }
265}