kernel/arch/riscv64/fpu.rs
1//! Floating-Point Unit and Vector context for RISC-V 64-bit
2//!
3//! This module provides the FPU and Vector context structures for saving and restoring
4//! floating-point and vector register state during context switches.
5//!
6//! ## FPU (F/D Extensions)
7//! RISC-V uses the F (single-precision) and D (double-precision) extensions with
8//! 32 floating-point registers (f0-f31, each 64-bit for D extension) and fcsr
9//! control/status register.
10//!
11//! ## Vector (V Extension)
12//! RISC-V Vector extension provides 32 vector registers (v0-v31) with configurable
13//! VLEN (vector length). The actual size depends on the implementation. This module
14//! supports VLEN up to 256 bits (32 bytes per register, vlenb=32).
15
16use core::arch::asm;
17
18mod fpu_switch;
19
20pub use fpu_switch::{
21 kernel_switch_in_user_fpu, kernel_switch_out_user_fpu, kernel_switch_out_user_vector,
22};
23
/// Saved F/D-extension state for one RISC-V 64-bit task.
///
/// Captures the 32 floating-point data registers together with `fcsr`, so a
/// task's floating-point state can be preserved across a context switch.
#[derive(Clone, Debug)]
#[repr(C, align(16))]
pub struct FpuContext {
    /// f0-f31, each held as a 64-bit value (the D-extension register width).
    pub f: [u64; 32],
    /// The floating-point control and status register (`fcsr`).
    pub fcsr: u32,
}
36
37impl FpuContext {
38 /// Create a new zeroed FPU context
39 pub const fn new() -> Self {
40 Self {
41 f: [0; 32],
42 fcsr: 0,
43 }
44 }
45
46 /// Save the current FPU state to this context
47 ///
48 /// # Safety
49 /// This function directly accesses FPU registers. The FPU must be enabled
50 /// (sstatus.FS != Off) before calling this function.
51 #[inline]
52 pub unsafe fn save(&mut self) {
53 let ptr = self.f.as_mut_ptr();
54 unsafe {
55 asm!(
56 ".option push",
57 ".option arch, +f, +d",
58 // Save all 32 floating-point registers
59 "fsd f0, 0*8({0})",
60 "fsd f1, 1*8({0})",
61 "fsd f2, 2*8({0})",
62 "fsd f3, 3*8({0})",
63 "fsd f4, 4*8({0})",
64 "fsd f5, 5*8({0})",
65 "fsd f6, 6*8({0})",
66 "fsd f7, 7*8({0})",
67 "fsd f8, 8*8({0})",
68 "fsd f9, 9*8({0})",
69 "fsd f10, 10*8({0})",
70 "fsd f11, 11*8({0})",
71 "fsd f12, 12*8({0})",
72 "fsd f13, 13*8({0})",
73 "fsd f14, 14*8({0})",
74 "fsd f15, 15*8({0})",
75 "fsd f16, 16*8({0})",
76 "fsd f17, 17*8({0})",
77 "fsd f18, 18*8({0})",
78 "fsd f19, 19*8({0})",
79 "fsd f20, 20*8({0})",
80 "fsd f21, 21*8({0})",
81 "fsd f22, 22*8({0})",
82 "fsd f23, 23*8({0})",
83 "fsd f24, 24*8({0})",
84 "fsd f25, 25*8({0})",
85 "fsd f26, 26*8({0})",
86 "fsd f27, 27*8({0})",
87 "fsd f28, 28*8({0})",
88 "fsd f29, 29*8({0})",
89 "fsd f30, 30*8({0})",
90 "fsd f31, 31*8({0})",
91 ".option pop",
92 in(reg) ptr,
93 options(nostack),
94 );
95 }
96 // Save fcsr
97 let fcsr: u32;
98 unsafe {
99 asm!(
100 ".option push",
101 ".option arch, +f, +d",
102 "frcsr {0}",
103 ".option pop",
104 out(reg) fcsr,
105 options(nomem, nostack),
106 );
107 }
108 self.fcsr = fcsr;
109 }
110
111 /// Restore the FPU state from this context
112 ///
113 /// # Safety
114 /// This function directly accesses FPU registers. The FPU must be enabled
115 /// (sstatus.FS != Off) before calling this function.
116 #[inline]
117 pub unsafe fn restore(&self) {
118 // Restore fcsr first
119 unsafe {
120 asm!(
121 ".option push",
122 ".option arch, +f, +d",
123 "fscsr {0}",
124 ".option pop",
125 in(reg) self.fcsr,
126 options(nomem, nostack),
127 );
128 }
129 let ptr = self.f.as_ptr();
130 unsafe {
131 asm!(
132 ".option push",
133 ".option arch, +f, +d",
134 // Restore all 32 floating-point registers
135 "fld f0, 0*8({0})",
136 "fld f1, 1*8({0})",
137 "fld f2, 2*8({0})",
138 "fld f3, 3*8({0})",
139 "fld f4, 4*8({0})",
140 "fld f5, 5*8({0})",
141 "fld f6, 6*8({0})",
142 "fld f7, 7*8({0})",
143 "fld f8, 8*8({0})",
144 "fld f9, 9*8({0})",
145 "fld f10, 10*8({0})",
146 "fld f11, 11*8({0})",
147 "fld f12, 12*8({0})",
148 "fld f13, 13*8({0})",
149 "fld f14, 14*8({0})",
150 "fld f15, 15*8({0})",
151 "fld f16, 16*8({0})",
152 "fld f17, 17*8({0})",
153 "fld f18, 18*8({0})",
154 "fld f19, 19*8({0})",
155 "fld f20, 20*8({0})",
156 "fld f21, 21*8({0})",
157 "fld f22, 22*8({0})",
158 "fld f23, 23*8({0})",
159 "fld f24, 24*8({0})",
160 "fld f25, 25*8({0})",
161 "fld f26, 26*8({0})",
162 "fld f27, 27*8({0})",
163 "fld f28, 28*8({0})",
164 "fld f29, 29*8({0})",
165 "fld f30, 30*8({0})",
166 "fld f31, 31*8({0})",
167 ".option pop",
168 in(reg) ptr,
169 options(nostack),
170 );
171 }
172 }
173}
174
175impl Default for FpuContext {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
/// Largest `vlenb` (VLEN / 8, in bytes) that a `VectorContext` has room for.
///
/// 32 bytes corresponds to VLEN = 256 bits. QEMU's virt machine typically
/// runs with VLEN = 128 (vlenb = 16).
pub const MAX_VLENB: usize = 256 / 8;
185
186/// Vector context for RISC-V 64-bit (V extension)
187///
188/// Contains all vector registers and vector CSRs.
189/// The actual number of bytes used per register depends on the implementation's VLEN.
190/// This structure reserves space for VLEN up to 256 bits.
191#[repr(C, align(16))]
192#[derive(Debug, Clone)]
193pub struct VectorContext {
194 /// Vector registers v0-v31 (up to 256 bits = 32 bytes each)
195 /// Stored as arrays of u64 for alignment
196 pub v: [[u64; MAX_VLENB / 8]; 32],
197 /// Vector type register (vtype)
198 pub vtype: u64,
199 /// Vector length register (vl)
200 pub vl: u64,
201 /// Vector start index register (vstart)
202 pub vstart: u64,
203 /// Vector fixed-point rounding mode register (vxrm)
204 pub vxrm: u64,
205 /// Vector fixed-point saturation flag (vxsat)
206 pub vxsat: u64,
207 /// Vector control and status register (vcsr) - combines vxrm and vxsat
208 pub vcsr: u64,
209 /// Cached vlenb value (VLEN/8 in bytes)
210 pub vlenb: u64,
211}
212
213impl VectorContext {
214 /// Create a new zeroed Vector context
215 pub const fn new() -> Self {
216 Self {
217 v: [[0; MAX_VLENB / 8]; 32],
218 vtype: 0,
219 vl: 0,
220 vstart: 0,
221 vxrm: 0,
222 vxsat: 0,
223 vcsr: 0,
224 vlenb: 0,
225 }
226 }
227
228 /// Save the current Vector state to this context
229 ///
230 /// # Safety
231 /// This function directly accesses Vector registers. The Vector extension must
232 /// be enabled (sstatus.VS != Off) before calling this function.
233 #[inline]
234 pub unsafe fn save(&mut self) {
235 // Read vlenb to know the actual vector register size
236 let vlenb: u64;
237 unsafe {
238 asm!(
239 ".option push",
240 ".option arch, +v",
241 "csrr {0}, vlenb",
242 ".option pop",
243 out(reg) vlenb,
244 options(nomem, nostack),
245 );
246 }
247 self.vlenb = vlenb;
248
249 // Save vector CSRs
250 unsafe {
251 asm!(
252 ".option push",
253 ".option arch, +v",
254 "csrr {0}, vtype",
255 "csrr {1}, vl",
256 "csrr {2}, vstart",
257 "csrr {3}, vcsr",
258 ".option pop",
259 out(reg) self.vtype,
260 out(reg) self.vl,
261 out(reg) self.vstart,
262 out(reg) self.vcsr,
263 options(nomem, nostack),
264 );
265 }
266
267 // Extract vxrm and vxsat from vcsr
268 self.vxrm = (self.vcsr >> 1) & 0x3;
269 self.vxsat = self.vcsr & 0x1;
270
271 // Save vector registers using vs1r.v (whole register store).
272 // Use the runtime vlenb as the stride so we only touch the bytes that
273 // the implementation actually uses. This avoids unnecessary memory
274 // traffic (e.g. QEMU virt often uses vlenb=16).
275 let ptr = self.v.as_mut_ptr() as *mut u8;
276 let stride = vlenb as usize;
277
278 // Use inline assembly to save each vector register
279 // vs1r.v stores one vector register (VLEN bits)
280 unsafe {
281 asm!(
282 ".option push",
283 ".option arch, +v",
284 "add t0, {ptr}, {stride}",
285 "vs1r.v v0, ({ptr})",
286 "vs1r.v v1, (t0)",
287 "add {ptr}, t0, {stride}",
288 "add t0, {ptr}, {stride}",
289 "vs1r.v v2, ({ptr})",
290 "vs1r.v v3, (t0)",
291 "add {ptr}, t0, {stride}",
292 "add t0, {ptr}, {stride}",
293 "vs1r.v v4, ({ptr})",
294 "vs1r.v v5, (t0)",
295 "add {ptr}, t0, {stride}",
296 "add t0, {ptr}, {stride}",
297 "vs1r.v v6, ({ptr})",
298 "vs1r.v v7, (t0)",
299 "add {ptr}, t0, {stride}",
300 "add t0, {ptr}, {stride}",
301 "vs1r.v v8, ({ptr})",
302 "vs1r.v v9, (t0)",
303 "add {ptr}, t0, {stride}",
304 "add t0, {ptr}, {stride}",
305 "vs1r.v v10, ({ptr})",
306 "vs1r.v v11, (t0)",
307 "add {ptr}, t0, {stride}",
308 "add t0, {ptr}, {stride}",
309 "vs1r.v v12, ({ptr})",
310 "vs1r.v v13, (t0)",
311 "add {ptr}, t0, {stride}",
312 "add t0, {ptr}, {stride}",
313 "vs1r.v v14, ({ptr})",
314 "vs1r.v v15, (t0)",
315 "add {ptr}, t0, {stride}",
316 "add t0, {ptr}, {stride}",
317 "vs1r.v v16, ({ptr})",
318 "vs1r.v v17, (t0)",
319 "add {ptr}, t0, {stride}",
320 "add t0, {ptr}, {stride}",
321 "vs1r.v v18, ({ptr})",
322 "vs1r.v v19, (t0)",
323 "add {ptr}, t0, {stride}",
324 "add t0, {ptr}, {stride}",
325 "vs1r.v v20, ({ptr})",
326 "vs1r.v v21, (t0)",
327 "add {ptr}, t0, {stride}",
328 "add t0, {ptr}, {stride}",
329 "vs1r.v v22, ({ptr})",
330 "vs1r.v v23, (t0)",
331 "add {ptr}, t0, {stride}",
332 "add t0, {ptr}, {stride}",
333 "vs1r.v v24, ({ptr})",
334 "vs1r.v v25, (t0)",
335 "add {ptr}, t0, {stride}",
336 "add t0, {ptr}, {stride}",
337 "vs1r.v v26, ({ptr})",
338 "vs1r.v v27, (t0)",
339 "add {ptr}, t0, {stride}",
340 "add t0, {ptr}, {stride}",
341 "vs1r.v v28, ({ptr})",
342 "vs1r.v v29, (t0)",
343 "add {ptr}, t0, {stride}",
344 "add t0, {ptr}, {stride}",
345 "vs1r.v v30, ({ptr})",
346 "vs1r.v v31, (t0)",
347 ".option pop",
348 ptr = inout(reg) ptr => _,
349 stride = in(reg) stride,
350 out("t0") _,
351 options(nostack),
352 );
353 }
354 }
355
356 /// Restore the Vector state from this context
357 ///
358 /// # Safety
359 /// This function directly accesses Vector registers. The Vector extension must
360 /// be enabled (sstatus.VS != Off) before calling this function.
361 #[inline]
362 pub unsafe fn restore(&self) {
363 // Restore vector CSRs first
364 unsafe {
365 asm!(
366 ".option push",
367 ".option arch, +v",
368 "csrw vstart, {0}",
369 "csrw vcsr, {1}",
370 ".option pop",
371 in(reg) self.vstart,
372 in(reg) self.vcsr,
373 options(nomem, nostack),
374 );
375 }
376
377 // Restore vtype and vl using vsetvl
378 // This sets both vtype and vl atomically
379 unsafe {
380 asm!(
381 ".option push",
382 ".option arch, +v",
383 "vsetvl x0, {0}, {1}",
384 ".option pop",
385 in(reg) self.vl,
386 in(reg) self.vtype,
387 options(nomem, nostack),
388 );
389 }
390
391 // Restore vector registers using vl1r.v (whole register load)
392 let ptr = self.v.as_ptr() as *const u8;
393 // Use the saved vlenb if available; fall back to MAX_VLENB for
394 // never-saved (zero-initial) contexts.
395 let stride = if self.vlenb == 0 {
396 MAX_VLENB
397 } else {
398 self.vlenb as usize
399 };
400
401 unsafe {
402 asm!(
403 ".option push",
404 ".option arch, +v",
405 "add t0, {ptr}, {stride}",
406 "vl1r.v v0, ({ptr})",
407 "vl1r.v v1, (t0)",
408 "add {ptr}, t0, {stride}",
409 "add t0, {ptr}, {stride}",
410 "vl1r.v v2, ({ptr})",
411 "vl1r.v v3, (t0)",
412 "add {ptr}, t0, {stride}",
413 "add t0, {ptr}, {stride}",
414 "vl1r.v v4, ({ptr})",
415 "vl1r.v v5, (t0)",
416 "add {ptr}, t0, {stride}",
417 "add t0, {ptr}, {stride}",
418 "vl1r.v v6, ({ptr})",
419 "vl1r.v v7, (t0)",
420 "add {ptr}, t0, {stride}",
421 "add t0, {ptr}, {stride}",
422 "vl1r.v v8, ({ptr})",
423 "vl1r.v v9, (t0)",
424 "add {ptr}, t0, {stride}",
425 "add t0, {ptr}, {stride}",
426 "vl1r.v v10, ({ptr})",
427 "vl1r.v v11, (t0)",
428 "add {ptr}, t0, {stride}",
429 "add t0, {ptr}, {stride}",
430 "vl1r.v v12, ({ptr})",
431 "vl1r.v v13, (t0)",
432 "add {ptr}, t0, {stride}",
433 "add t0, {ptr}, {stride}",
434 "vl1r.v v14, ({ptr})",
435 "vl1r.v v15, (t0)",
436 "add {ptr}, t0, {stride}",
437 "add t0, {ptr}, {stride}",
438 "vl1r.v v16, ({ptr})",
439 "vl1r.v v17, (t0)",
440 "add {ptr}, t0, {stride}",
441 "add t0, {ptr}, {stride}",
442 "vl1r.v v18, ({ptr})",
443 "vl1r.v v19, (t0)",
444 "add {ptr}, t0, {stride}",
445 "add t0, {ptr}, {stride}",
446 "vl1r.v v20, ({ptr})",
447 "vl1r.v v21, (t0)",
448 "add {ptr}, t0, {stride}",
449 "add t0, {ptr}, {stride}",
450 "vl1r.v v22, ({ptr})",
451 "vl1r.v v23, (t0)",
452 "add {ptr}, t0, {stride}",
453 "add t0, {ptr}, {stride}",
454 "vl1r.v v24, ({ptr})",
455 "vl1r.v v25, (t0)",
456 "add {ptr}, t0, {stride}",
457 "add t0, {ptr}, {stride}",
458 "vl1r.v v26, ({ptr})",
459 "vl1r.v v27, (t0)",
460 "add {ptr}, t0, {stride}",
461 "add t0, {ptr}, {stride}",
462 "vl1r.v v28, ({ptr})",
463 "vl1r.v v29, (t0)",
464 "add {ptr}, t0, {stride}",
465 "add t0, {ptr}, {stride}",
466 "vl1r.v v30, ({ptr})",
467 "vl1r.v v31, (t0)",
468 ".option pop",
469 ptr = inout(reg) ptr => _,
470 stride = in(reg) stride,
471 out("t0") _,
472 options(nostack),
473 );
474 }
475 }
476}
477
478impl Default for VectorContext {
479 fn default() -> Self {
480 Self::new()
481 }
482}
483
/// Enable FPU access by setting sstatus.FS to Initial state
///
/// This must be called before user space can use floating-point instructions.
/// The FS field (sstatus bits 14:13) controls access to the FPU:
/// - 00: Off - FPU access causes illegal instruction exception
/// - 01: Initial - FPU is enabled with initial state
/// - 10: Clean - FPU is enabled, state has not been modified
/// - 11: Dirty - FPU is enabled, state has been modified
///
/// The field is updated with single-instruction `csrs`/`csrc` operations.
/// A `csrr`/`csrw` read-modify-write would let a trap taken in between have
/// its changes to other sstatus fields overwritten by a stale copy. Bit 13
/// is set before bit 14 is cleared, so an already-enabled FPU never passes
/// through Off.
#[inline]
pub fn enable_fpu() {
    // FS = 01 (Initial): bit 13 set, bit 14 clear.
    const SSTATUS_FS_BIT13: usize = 0x2000;
    const SSTATUS_FS_BIT14: usize = 0x4000;

    // SAFETY: only sstatus.FS is modified.
    unsafe {
        asm!(
            "csrs sstatus, {set}",
            "csrc sstatus, {clear}",
            set = in(reg) SSTATUS_FS_BIT13,
            clear = in(reg) SSTATUS_FS_BIT14,
            options(nomem, nostack),
        );
    }
}
512
/// Disable FPU access by setting sstatus.FS to Off.
///
/// When FS is Off, any FPU instruction executed (in S/U) raises an illegal
/// instruction exception. The kernel should re-enable FS temporarily when it
/// needs to save/restore user state.
///
/// Uses a single `csrc`, so no other sstatus field is read or rewritten.
#[inline]
pub fn disable_fpu() {
    const SSTATUS_FS_MASK: usize = 0x6000;

    // SAFETY: only sstatus.FS is cleared.
    unsafe {
        asm!(
            "csrc sstatus, {mask}",
            mask = in(reg) SSTATUS_FS_MASK,
            options(nomem, nostack),
        );
    }
}
533
/// Report whether floating-point instructions are currently permitted,
/// i.e. sstatus.FS holds anything other than Off (0b00).
#[inline]
pub fn is_fpu_enabled() -> bool {
    // sstatus.FS occupies bits 14:13.
    const FS_SHIFT: usize = 13;

    let status: usize;
    // SAFETY: reading sstatus has no side effects.
    unsafe {
        asm!("csrr {s}, sstatus", s = out(reg) status, options(nomem, nostack));
    }
    let fs = (status >> FS_SHIFT) & 0b11;
    fs != 0
}
548
/// Report whether sstatus marks the FPU state as Dirty (FS == 0b11).
#[inline]
pub fn is_fpu_dirty() -> bool {
    // sstatus.FS occupies bits 14:13.
    const FS_SHIFT: usize = 13;
    const FS_DIRTY: usize = 0b11;

    let status: usize;
    // SAFETY: reading sstatus has no side effects.
    unsafe {
        asm!("csrr {s}, sstatus", s = out(reg) status, options(nomem, nostack));
    }
    ((status >> FS_SHIFT) & 0b11) == FS_DIRTY
}
562
/// Mark the FPU state as Clean in sstatus (FS = 0b10).
///
/// This is useful after saving/restoring FPU state so that a task that doesn't
/// touch the FPU in its next timeslice won't incur an unnecessary save.
///
/// Bit 14 is set before bit 13 is cleared, using single-instruction
/// `csrs`/`csrc` updates. An enabled FPU therefore never passes through Off,
/// and no other sstatus field is read back and rewritten.
#[inline]
pub fn mark_fpu_clean() {
    // FS = 10 (Clean): bit 14 set, bit 13 clear.
    const SSTATUS_FS_BIT14: usize = 0x4000;
    const SSTATUS_FS_BIT13: usize = 0x2000;

    // SAFETY: only sstatus.FS is modified.
    unsafe {
        asm!(
            "csrs sstatus, {set}",
            "csrc sstatus, {clear}",
            set = in(reg) SSTATUS_FS_BIT14,
            clear = in(reg) SSTATUS_FS_BIT13,
            options(nomem, nostack),
        );
    }
}
585
/// Enable Vector extension access by setting sstatus.VS to Initial state
///
/// This must be called before user space can use vector instructions.
/// The VS field (sstatus bits 10:9) controls access to the Vector extension:
/// - 00: Off - Vector access causes illegal instruction exception
/// - 01: Initial - Vector is enabled with initial state
/// - 10: Clean - Vector is enabled, state has not been modified
/// - 11: Dirty - Vector is enabled, state has been modified
///
/// The field is updated with single-instruction `csrs`/`csrc` operations
/// rather than a read-modify-write, so a trap taken in between cannot have
/// its changes to other sstatus fields overwritten. Bit 9 is set before
/// bit 10 is cleared, so an already-enabled unit never passes through Off.
#[inline]
pub fn enable_vector() {
    // VS = 01 (Initial): bit 9 set, bit 10 clear.
    const SSTATUS_VS_BIT9: usize = 0x200;
    const SSTATUS_VS_BIT10: usize = 0x400;

    // SAFETY: only sstatus.VS is modified.
    unsafe {
        asm!(
            "csrs sstatus, {set}",
            "csrc sstatus, {clear}",
            set = in(reg) SSTATUS_VS_BIT9,
            clear = in(reg) SSTATUS_VS_BIT10,
            options(nomem, nostack),
        );
    }
}
614
/// Disable Vector extension access by setting sstatus.VS to Off.
///
/// Uses a single `csrc`, so no other sstatus field is read or rewritten.
#[inline]
pub fn disable_vector() {
    const SSTATUS_VS_MASK: usize = 0x600;

    // SAFETY: only sstatus.VS is cleared.
    unsafe {
        asm!(
            "csrc sstatus, {mask}",
            mask = in(reg) SSTATUS_VS_MASK,
            options(nomem, nostack),
        );
    }
}
631
/// Report whether vector instructions are currently permitted,
/// i.e. sstatus.VS holds anything other than Off (0b00).
#[inline]
pub fn is_vector_enabled() -> bool {
    // sstatus.VS occupies bits 10:9.
    const VS_SHIFT: usize = 9;

    let status: usize;
    // SAFETY: reading sstatus has no side effects.
    unsafe {
        asm!("csrr {s}, sstatus", s = out(reg) status, options(nomem, nostack));
    }
    let vs = (status >> VS_SHIFT) & 0b11;
    vs != 0
}
646
/// Report whether sstatus marks the Vector state as Dirty (VS == 0b11).
#[inline]
pub fn is_vector_dirty() -> bool {
    // sstatus.VS occupies bits 10:9.
    const VS_SHIFT: usize = 9;
    const VS_DIRTY: usize = 0b11;

    let status: usize;
    // SAFETY: reading sstatus has no side effects.
    unsafe {
        asm!("csrr {s}, sstatus", s = out(reg) status, options(nomem, nostack));
    }
    ((status >> VS_SHIFT) & 0b11) == VS_DIRTY
}
660
/// Mark the Vector state as Clean in sstatus (VS = 0b10).
///
/// Bit 10 is set before bit 9 is cleared, using single-instruction
/// `csrs`/`csrc` updates. An enabled unit therefore never passes through Off,
/// and no other sstatus field is read back and rewritten.
#[inline]
pub fn mark_vector_clean() {
    // VS = 10 (Clean): bit 10 set, bit 9 clear.
    const SSTATUS_VS_BIT10: usize = 0x400;
    const SSTATUS_VS_BIT9: usize = 0x200;

    // SAFETY: only sstatus.VS is modified.
    unsafe {
        asm!(
            "csrs sstatus, {set}",
            "csrc sstatus, {clear}",
            set = in(reg) SSTATUS_VS_BIT10,
            clear = in(reg) SSTATUS_VS_BIT9,
            options(nomem, nostack),
        );
    }
}
680
681/// Get the vector length in bytes (vlenb = VLEN / 8)
682///
683/// Returns 0 if the Vector extension is not available.
684#[inline]
685pub fn get_vlenb() -> usize {
686 if !is_vector_enabled() {
687 return 0;
688 }
689 let vlenb: usize;
690 unsafe {
691 asm!(
692 "csrr {0}, vlenb",
693 out(reg) vlenb,
694 options(nomem, nostack),
695 );
696 }
697 vlenb
698}