kernel/arch/riscv64/
fpu.rs

//! Floating-Point Unit and Vector context for RISC-V 64-bit
//!
//! This module provides the FPU and Vector context structures for saving and restoring
//! floating-point and vector register state during context switches.
//!
//! ## FPU (F/D Extensions)
//! RISC-V uses the F (single-precision) and D (double-precision) extensions with
//! 32 floating-point registers (f0-f31, each 64 bits wide with the D extension) and the
//! `fcsr` control/status register.
//!
//! ## Vector (V Extension)
//! The RISC-V Vector extension provides 32 vector registers (v0-v31) with a configurable
//! VLEN (vector length). The actual size depends on the implementation. This module
//! supports VLEN up to 256 bits (32 bytes per register, vlenb=32); larger VLEN values
//! do not fit in the reserved storage.
15
16use core::arch::asm;
17
18mod fpu_switch;
19
20pub use fpu_switch::{
21    kernel_switch_in_user_fpu, kernel_switch_out_user_fpu, kernel_switch_out_user_vector,
22};
23
/// Saved floating-point state of one task on RISC-V 64-bit (F/D extensions).
///
/// Holds a snapshot of every floating-point register together with `fcsr`.
/// A context switch writes the outgoing task's state into this structure and
/// later reloads it when the task is scheduled again.
///
/// The layout is `repr(C)` with 16-byte alignment: the register array comes
/// first, followed by `fcsr`.
#[repr(C, align(16))]
#[derive(Debug, Clone)]
pub struct FpuContext {
    /// Register file `f0`..`f31`; one 64-bit slot per register (D extension width).
    pub f: [u64; 32],
    /// Snapshot of the floating-point control and status register (`fcsr`).
    pub fcsr: u32,
}
36
37impl FpuContext {
38    /// Create a new zeroed FPU context
39    pub const fn new() -> Self {
40        Self {
41            f: [0; 32],
42            fcsr: 0,
43        }
44    }
45
46    /// Save the current FPU state to this context
47    ///
48    /// # Safety
49    /// This function directly accesses FPU registers. The FPU must be enabled
50    /// (sstatus.FS != Off) before calling this function.
51    #[inline]
52    pub unsafe fn save(&mut self) {
53        let ptr = self.f.as_mut_ptr();
54        unsafe {
55            asm!(
56                ".option push",
57                ".option arch, +f, +d",
58                // Save all 32 floating-point registers
59                "fsd f0, 0*8({0})",
60                "fsd f1, 1*8({0})",
61                "fsd f2, 2*8({0})",
62                "fsd f3, 3*8({0})",
63                "fsd f4, 4*8({0})",
64                "fsd f5, 5*8({0})",
65                "fsd f6, 6*8({0})",
66                "fsd f7, 7*8({0})",
67                "fsd f8, 8*8({0})",
68                "fsd f9, 9*8({0})",
69                "fsd f10, 10*8({0})",
70                "fsd f11, 11*8({0})",
71                "fsd f12, 12*8({0})",
72                "fsd f13, 13*8({0})",
73                "fsd f14, 14*8({0})",
74                "fsd f15, 15*8({0})",
75                "fsd f16, 16*8({0})",
76                "fsd f17, 17*8({0})",
77                "fsd f18, 18*8({0})",
78                "fsd f19, 19*8({0})",
79                "fsd f20, 20*8({0})",
80                "fsd f21, 21*8({0})",
81                "fsd f22, 22*8({0})",
82                "fsd f23, 23*8({0})",
83                "fsd f24, 24*8({0})",
84                "fsd f25, 25*8({0})",
85                "fsd f26, 26*8({0})",
86                "fsd f27, 27*8({0})",
87                "fsd f28, 28*8({0})",
88                "fsd f29, 29*8({0})",
89                "fsd f30, 30*8({0})",
90                "fsd f31, 31*8({0})",
91                ".option pop",
92                in(reg) ptr,
93                options(nostack),
94            );
95        }
96        // Save fcsr
97        let fcsr: u32;
98        unsafe {
99            asm!(
100                ".option push",
101                ".option arch, +f, +d",
102                "frcsr {0}",
103                ".option pop",
104                out(reg) fcsr,
105                options(nomem, nostack),
106            );
107        }
108        self.fcsr = fcsr;
109    }
110
111    /// Restore the FPU state from this context
112    ///
113    /// # Safety
114    /// This function directly accesses FPU registers. The FPU must be enabled
115    /// (sstatus.FS != Off) before calling this function.
116    #[inline]
117    pub unsafe fn restore(&self) {
118        // Restore fcsr first
119        unsafe {
120            asm!(
121                ".option push",
122                ".option arch, +f, +d",
123                "fscsr {0}",
124                ".option pop",
125                in(reg) self.fcsr,
126                options(nomem, nostack),
127            );
128        }
129        let ptr = self.f.as_ptr();
130        unsafe {
131            asm!(
132                ".option push",
133                ".option arch, +f, +d",
134                // Restore all 32 floating-point registers
135                "fld f0, 0*8({0})",
136                "fld f1, 1*8({0})",
137                "fld f2, 2*8({0})",
138                "fld f3, 3*8({0})",
139                "fld f4, 4*8({0})",
140                "fld f5, 5*8({0})",
141                "fld f6, 6*8({0})",
142                "fld f7, 7*8({0})",
143                "fld f8, 8*8({0})",
144                "fld f9, 9*8({0})",
145                "fld f10, 10*8({0})",
146                "fld f11, 11*8({0})",
147                "fld f12, 12*8({0})",
148                "fld f13, 13*8({0})",
149                "fld f14, 14*8({0})",
150                "fld f15, 15*8({0})",
151                "fld f16, 16*8({0})",
152                "fld f17, 17*8({0})",
153                "fld f18, 18*8({0})",
154                "fld f19, 19*8({0})",
155                "fld f20, 20*8({0})",
156                "fld f21, 21*8({0})",
157                "fld f22, 22*8({0})",
158                "fld f23, 23*8({0})",
159                "fld f24, 24*8({0})",
160                "fld f25, 25*8({0})",
161                "fld f26, 26*8({0})",
162                "fld f27, 27*8({0})",
163                "fld f28, 28*8({0})",
164                "fld f29, 29*8({0})",
165                "fld f30, 30*8({0})",
166                "fld f31, 31*8({0})",
167                ".option pop",
168                in(reg) ptr,
169                options(nostack),
170            );
171        }
172    }
173}
174
175impl Default for FpuContext {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
/// Largest per-register vector size, in bytes (VLEN / 8), that this module reserves room for.
///
/// Corresponds to VLEN = 256 bits, i.e. 32 bytes per vector register.
/// The QEMU virt machine typically uses VLEN=128 (vlenb=16).
pub const MAX_VLENB: usize = 256 / 8;
185
186/// Vector context for RISC-V 64-bit (V extension)
187///
188/// Contains all vector registers and vector CSRs.
189/// The actual number of bytes used per register depends on the implementation's VLEN.
190/// This structure reserves space for VLEN up to 256 bits.
191#[repr(C, align(16))]
192#[derive(Debug, Clone)]
193pub struct VectorContext {
194    /// Vector registers v0-v31 (up to 256 bits = 32 bytes each)
195    /// Stored as arrays of u64 for alignment
196    pub v: [[u64; MAX_VLENB / 8]; 32],
197    /// Vector type register (vtype)
198    pub vtype: u64,
199    /// Vector length register (vl)
200    pub vl: u64,
201    /// Vector start index register (vstart)
202    pub vstart: u64,
203    /// Vector fixed-point rounding mode register (vxrm)
204    pub vxrm: u64,
205    /// Vector fixed-point saturation flag (vxsat)
206    pub vxsat: u64,
207    /// Vector control and status register (vcsr) - combines vxrm and vxsat
208    pub vcsr: u64,
209    /// Cached vlenb value (VLEN/8 in bytes)
210    pub vlenb: u64,
211}
212
213impl VectorContext {
214    /// Create a new zeroed Vector context
215    pub const fn new() -> Self {
216        Self {
217            v: [[0; MAX_VLENB / 8]; 32],
218            vtype: 0,
219            vl: 0,
220            vstart: 0,
221            vxrm: 0,
222            vxsat: 0,
223            vcsr: 0,
224            vlenb: 0,
225        }
226    }
227
228    /// Save the current Vector state to this context
229    ///
230    /// # Safety
231    /// This function directly accesses Vector registers. The Vector extension must
232    /// be enabled (sstatus.VS != Off) before calling this function.
233    #[inline]
234    pub unsafe fn save(&mut self) {
235        // Read vlenb to know the actual vector register size
236        let vlenb: u64;
237        unsafe {
238            asm!(
239                ".option push",
240                ".option arch, +v",
241                "csrr {0}, vlenb",
242                ".option pop",
243                out(reg) vlenb,
244                options(nomem, nostack),
245            );
246        }
247        self.vlenb = vlenb;
248
249        // Save vector CSRs
250        unsafe {
251            asm!(
252                ".option push",
253                ".option arch, +v",
254                "csrr {0}, vtype",
255                "csrr {1}, vl",
256                "csrr {2}, vstart",
257                "csrr {3}, vcsr",
258                ".option pop",
259                out(reg) self.vtype,
260                out(reg) self.vl,
261                out(reg) self.vstart,
262                out(reg) self.vcsr,
263                options(nomem, nostack),
264            );
265        }
266
267        // Extract vxrm and vxsat from vcsr
268        self.vxrm = (self.vcsr >> 1) & 0x3;
269        self.vxsat = self.vcsr & 0x1;
270
271        // Save vector registers using vs1r.v (whole register store).
272        // Use the runtime vlenb as the stride so we only touch the bytes that
273        // the implementation actually uses. This avoids unnecessary memory
274        // traffic (e.g. QEMU virt often uses vlenb=16).
275        let ptr = self.v.as_mut_ptr() as *mut u8;
276        let stride = vlenb as usize;
277
278        // Use inline assembly to save each vector register
279        // vs1r.v stores one vector register (VLEN bits)
280        unsafe {
281            asm!(
282                ".option push",
283                ".option arch, +v",
284                "add t0, {ptr}, {stride}",
285                "vs1r.v v0, ({ptr})",
286                "vs1r.v v1, (t0)",
287                "add {ptr}, t0, {stride}",
288                "add t0, {ptr}, {stride}",
289                "vs1r.v v2, ({ptr})",
290                "vs1r.v v3, (t0)",
291                "add {ptr}, t0, {stride}",
292                "add t0, {ptr}, {stride}",
293                "vs1r.v v4, ({ptr})",
294                "vs1r.v v5, (t0)",
295                "add {ptr}, t0, {stride}",
296                "add t0, {ptr}, {stride}",
297                "vs1r.v v6, ({ptr})",
298                "vs1r.v v7, (t0)",
299                "add {ptr}, t0, {stride}",
300                "add t0, {ptr}, {stride}",
301                "vs1r.v v8, ({ptr})",
302                "vs1r.v v9, (t0)",
303                "add {ptr}, t0, {stride}",
304                "add t0, {ptr}, {stride}",
305                "vs1r.v v10, ({ptr})",
306                "vs1r.v v11, (t0)",
307                "add {ptr}, t0, {stride}",
308                "add t0, {ptr}, {stride}",
309                "vs1r.v v12, ({ptr})",
310                "vs1r.v v13, (t0)",
311                "add {ptr}, t0, {stride}",
312                "add t0, {ptr}, {stride}",
313                "vs1r.v v14, ({ptr})",
314                "vs1r.v v15, (t0)",
315                "add {ptr}, t0, {stride}",
316                "add t0, {ptr}, {stride}",
317                "vs1r.v v16, ({ptr})",
318                "vs1r.v v17, (t0)",
319                "add {ptr}, t0, {stride}",
320                "add t0, {ptr}, {stride}",
321                "vs1r.v v18, ({ptr})",
322                "vs1r.v v19, (t0)",
323                "add {ptr}, t0, {stride}",
324                "add t0, {ptr}, {stride}",
325                "vs1r.v v20, ({ptr})",
326                "vs1r.v v21, (t0)",
327                "add {ptr}, t0, {stride}",
328                "add t0, {ptr}, {stride}",
329                "vs1r.v v22, ({ptr})",
330                "vs1r.v v23, (t0)",
331                "add {ptr}, t0, {stride}",
332                "add t0, {ptr}, {stride}",
333                "vs1r.v v24, ({ptr})",
334                "vs1r.v v25, (t0)",
335                "add {ptr}, t0, {stride}",
336                "add t0, {ptr}, {stride}",
337                "vs1r.v v26, ({ptr})",
338                "vs1r.v v27, (t0)",
339                "add {ptr}, t0, {stride}",
340                "add t0, {ptr}, {stride}",
341                "vs1r.v v28, ({ptr})",
342                "vs1r.v v29, (t0)",
343                "add {ptr}, t0, {stride}",
344                "add t0, {ptr}, {stride}",
345                "vs1r.v v30, ({ptr})",
346                "vs1r.v v31, (t0)",
347                ".option pop",
348                ptr = inout(reg) ptr => _,
349                stride = in(reg) stride,
350                out("t0") _,
351                options(nostack),
352            );
353        }
354    }
355
356    /// Restore the Vector state from this context
357    ///
358    /// # Safety
359    /// This function directly accesses Vector registers. The Vector extension must
360    /// be enabled (sstatus.VS != Off) before calling this function.
361    #[inline]
362    pub unsafe fn restore(&self) {
363        // Restore vector CSRs first
364        unsafe {
365            asm!(
366                ".option push",
367                ".option arch, +v",
368                "csrw vstart, {0}",
369                "csrw vcsr, {1}",
370                ".option pop",
371                in(reg) self.vstart,
372                in(reg) self.vcsr,
373                options(nomem, nostack),
374            );
375        }
376
377        // Restore vtype and vl using vsetvl
378        // This sets both vtype and vl atomically
379        unsafe {
380            asm!(
381                ".option push",
382                ".option arch, +v",
383                "vsetvl x0, {0}, {1}",
384                ".option pop",
385                in(reg) self.vl,
386                in(reg) self.vtype,
387                options(nomem, nostack),
388            );
389        }
390
391        // Restore vector registers using vl1r.v (whole register load)
392        let ptr = self.v.as_ptr() as *const u8;
393        // Use the saved vlenb if available; fall back to MAX_VLENB for
394        // never-saved (zero-initial) contexts.
395        let stride = if self.vlenb == 0 {
396            MAX_VLENB
397        } else {
398            self.vlenb as usize
399        };
400
401        unsafe {
402            asm!(
403                ".option push",
404                ".option arch, +v",
405                "add t0, {ptr}, {stride}",
406                "vl1r.v v0, ({ptr})",
407                "vl1r.v v1, (t0)",
408                "add {ptr}, t0, {stride}",
409                "add t0, {ptr}, {stride}",
410                "vl1r.v v2, ({ptr})",
411                "vl1r.v v3, (t0)",
412                "add {ptr}, t0, {stride}",
413                "add t0, {ptr}, {stride}",
414                "vl1r.v v4, ({ptr})",
415                "vl1r.v v5, (t0)",
416                "add {ptr}, t0, {stride}",
417                "add t0, {ptr}, {stride}",
418                "vl1r.v v6, ({ptr})",
419                "vl1r.v v7, (t0)",
420                "add {ptr}, t0, {stride}",
421                "add t0, {ptr}, {stride}",
422                "vl1r.v v8, ({ptr})",
423                "vl1r.v v9, (t0)",
424                "add {ptr}, t0, {stride}",
425                "add t0, {ptr}, {stride}",
426                "vl1r.v v10, ({ptr})",
427                "vl1r.v v11, (t0)",
428                "add {ptr}, t0, {stride}",
429                "add t0, {ptr}, {stride}",
430                "vl1r.v v12, ({ptr})",
431                "vl1r.v v13, (t0)",
432                "add {ptr}, t0, {stride}",
433                "add t0, {ptr}, {stride}",
434                "vl1r.v v14, ({ptr})",
435                "vl1r.v v15, (t0)",
436                "add {ptr}, t0, {stride}",
437                "add t0, {ptr}, {stride}",
438                "vl1r.v v16, ({ptr})",
439                "vl1r.v v17, (t0)",
440                "add {ptr}, t0, {stride}",
441                "add t0, {ptr}, {stride}",
442                "vl1r.v v18, ({ptr})",
443                "vl1r.v v19, (t0)",
444                "add {ptr}, t0, {stride}",
445                "add t0, {ptr}, {stride}",
446                "vl1r.v v20, ({ptr})",
447                "vl1r.v v21, (t0)",
448                "add {ptr}, t0, {stride}",
449                "add t0, {ptr}, {stride}",
450                "vl1r.v v22, ({ptr})",
451                "vl1r.v v23, (t0)",
452                "add {ptr}, t0, {stride}",
453                "add t0, {ptr}, {stride}",
454                "vl1r.v v24, ({ptr})",
455                "vl1r.v v25, (t0)",
456                "add {ptr}, t0, {stride}",
457                "add t0, {ptr}, {stride}",
458                "vl1r.v v26, ({ptr})",
459                "vl1r.v v27, (t0)",
460                "add {ptr}, t0, {stride}",
461                "add t0, {ptr}, {stride}",
462                "vl1r.v v28, ({ptr})",
463                "vl1r.v v29, (t0)",
464                "add {ptr}, t0, {stride}",
465                "add t0, {ptr}, {stride}",
466                "vl1r.v v30, ({ptr})",
467                "vl1r.v v31, (t0)",
468                ".option pop",
469                ptr = inout(reg) ptr => _,
470                stride = in(reg) stride,
471                out("t0") _,
472                options(nostack),
473            );
474        }
475    }
476}
477
478impl Default for VectorContext {
479    fn default() -> Self {
480        Self::new()
481    }
482}
483
/// Turn on FPU access by putting `sstatus.FS` into the Initial state.
///
/// User space cannot execute floating-point instructions until this has been
/// called. The two-bit FS field of `sstatus` encodes the FPU state:
/// - 00: Off - FPU access causes illegal instruction exception
/// - 01: Initial - FPU is enabled with initial state
/// - 10: Clean - FPU is enabled, state has not been modified
/// - 11: Dirty - FPU is enabled, state has been modified
#[inline]
pub fn enable_fpu() {
    // FS occupies sstatus bits 14:13.
    const FS_FIELD: usize = 0b11 << 13;
    // Initial is encoded as 0b01 in that field (0x2000).
    const FS_INITIAL: usize = 0b01 << 13;

    // SAFETY: the sequence reads sstatus and writes it back with only the FS
    // field replaced; it does not touch memory or the stack.
    unsafe {
        asm!(
            "csrr {tmp}, sstatus",
            "and {tmp}, {tmp}, {keep}", // drop the old FS value
            "or {tmp}, {tmp}, {set}",   // FS = Initial
            "csrw sstatus, {tmp}",
            tmp = out(reg) _,
            keep = in(reg) !FS_FIELD,
            set = in(reg) FS_INITIAL,
            options(nomem, nostack),
        );
    }
}
512
/// Turn off FPU access by clearing `sstatus.FS` to Off.
///
/// While FS is Off, executing any FPU instruction (in S/U) raises an illegal
/// instruction exception. The kernel should re-enable FS temporarily whenever
/// it has to save or restore user FPU state.
#[inline]
pub fn disable_fpu() {
    // FS occupies sstatus bits 14:13; Off is the all-zero encoding.
    const FS_FIELD: usize = 0b11 << 13;

    // SAFETY: the sequence reads sstatus and writes it back with only the FS
    // field cleared; it does not touch memory or the stack.
    unsafe {
        asm!(
            "csrr {tmp}, sstatus",
            "and {tmp}, {tmp}, {keep}",
            "csrw sstatus, {tmp}",
            tmp = out(reg) _,
            keep = in(reg) !FS_FIELD,
            options(nomem, nostack),
        );
    }
}
533
/// Report whether the FPU is currently enabled, i.e. `sstatus.FS` is not Off.
#[inline]
pub fn is_fpu_enabled() -> bool {
    // FS occupies sstatus bits 14:13.
    const FS_FIELD: usize = 0x6000;

    let status: usize;
    // SAFETY: a plain read of sstatus; no memory or stack access.
    unsafe {
        asm!(
            "csrr {value}, sstatus",
            value = out(reg) status,
            options(nomem, nostack),
        );
    }
    // Any non-zero FS encoding (Initial, Clean or Dirty) means "enabled".
    let fs_bits = status & FS_FIELD;
    fs_bits != 0
}
548
/// Report whether `sstatus` marks the FPU state as Dirty (FS == 0b11).
#[inline]
pub fn is_fpu_dirty() -> bool {
    // FS occupies sstatus bits 14:13; Dirty has both bits set.
    const FS_DIRTY: usize = 0x6000;

    let status: usize;
    // SAFETY: a plain read of sstatus; no memory or stack access.
    unsafe {
        asm!(
            "csrr {value}, sstatus",
            value = out(reg) status,
            options(nomem, nostack),
        );
    }
    let fs_bits = status & FS_DIRTY;
    fs_bits == FS_DIRTY
}
562
/// Set `sstatus.FS` to Clean (FS = 0b10).
///
/// Intended for use right after FPU state has been saved or restored, so that
/// a task which leaves the FPU untouched during its next timeslice does not
/// trigger an unnecessary save.
#[inline]
pub fn mark_fpu_clean() {
    // FS occupies sstatus bits 14:13.
    const FS_FIELD: usize = 0b11 << 13;
    // Clean is encoded as 0b10 in that field (0x4000).
    const FS_CLEAN: usize = 0b10 << 13;

    // SAFETY: the sequence reads sstatus and writes it back with only the FS
    // field replaced; it does not touch memory or the stack.
    unsafe {
        asm!(
            "csrr {tmp}, sstatus",
            "and {tmp}, {tmp}, {keep}",
            "or {tmp}, {tmp}, {set}",
            "csrw sstatus, {tmp}",
            tmp = out(reg) _,
            keep = in(reg) !FS_FIELD,
            set = in(reg) FS_CLEAN,
            options(nomem, nostack),
        );
    }
}
585
/// Turn on Vector extension access by putting `sstatus.VS` into the Initial state.
///
/// User space cannot execute vector instructions until this has been called.
/// The two-bit VS field of `sstatus` encodes the Vector state:
/// - 00: Off - Vector access causes illegal instruction exception
/// - 01: Initial - Vector is enabled with initial state
/// - 10: Clean - Vector is enabled, state has not been modified
/// - 11: Dirty - Vector is enabled, state has been modified
#[inline]
pub fn enable_vector() {
    // VS occupies sstatus bits 10:9.
    const VS_FIELD: usize = 0b11 << 9;
    // Initial is encoded as 0b01 in that field (0x200).
    const VS_INITIAL: usize = 0b01 << 9;

    // SAFETY: the sequence reads sstatus and writes it back with only the VS
    // field replaced; it does not touch memory or the stack.
    unsafe {
        asm!(
            "csrr {tmp}, sstatus",
            "and {tmp}, {tmp}, {keep}", // drop the old VS value
            "or {tmp}, {tmp}, {set}",   // VS = Initial
            "csrw sstatus, {tmp}",
            tmp = out(reg) _,
            keep = in(reg) !VS_FIELD,
            set = in(reg) VS_INITIAL,
            options(nomem, nostack),
        );
    }
}
614
/// Turn off Vector extension access by clearing `sstatus.VS` to Off.
#[inline]
pub fn disable_vector() {
    // VS occupies sstatus bits 10:9; Off is the all-zero encoding.
    const VS_FIELD: usize = 0b11 << 9;

    // SAFETY: the sequence reads sstatus and writes it back with only the VS
    // field cleared; it does not touch memory or the stack.
    unsafe {
        asm!(
            "csrr {tmp}, sstatus",
            "and {tmp}, {tmp}, {keep}",
            "csrw sstatus, {tmp}",
            tmp = out(reg) _,
            keep = in(reg) !VS_FIELD,
            options(nomem, nostack),
        );
    }
}
631
/// Report whether the Vector extension is currently enabled, i.e. `sstatus.VS` is not Off.
#[inline]
pub fn is_vector_enabled() -> bool {
    // VS occupies sstatus bits 10:9.
    const VS_FIELD: usize = 0x600;

    let status: usize;
    // SAFETY: a plain read of sstatus; no memory or stack access.
    unsafe {
        asm!(
            "csrr {value}, sstatus",
            value = out(reg) status,
            options(nomem, nostack),
        );
    }
    // Any non-zero VS encoding (Initial, Clean or Dirty) means "enabled".
    let vs_bits = status & VS_FIELD;
    vs_bits != 0
}
646
/// Report whether `sstatus` marks the Vector state as Dirty (VS == 0b11).
#[inline]
pub fn is_vector_dirty() -> bool {
    // VS occupies sstatus bits 10:9; Dirty has both bits set.
    const VS_DIRTY: usize = 0x600;

    let status: usize;
    // SAFETY: a plain read of sstatus; no memory or stack access.
    unsafe {
        asm!(
            "csrr {value}, sstatus",
            value = out(reg) status,
            options(nomem, nostack),
        );
    }
    let vs_bits = status & VS_DIRTY;
    vs_bits == VS_DIRTY
}
660
/// Set `sstatus.VS` to Clean (VS = 0b10).
#[inline]
pub fn mark_vector_clean() {
    // VS occupies sstatus bits 10:9.
    const VS_FIELD: usize = 0b11 << 9;
    // Clean is encoded as 0b10 in that field (0x400).
    const VS_CLEAN: usize = 0b10 << 9;

    // SAFETY: the sequence reads sstatus and writes it back with only the VS
    // field replaced; it does not touch memory or the stack.
    unsafe {
        asm!(
            "csrr {tmp}, sstatus",
            "and {tmp}, {tmp}, {keep}",
            "or {tmp}, {tmp}, {set}",
            "csrw sstatus, {tmp}",
            tmp = out(reg) _,
            keep = in(reg) !VS_FIELD,
            set = in(reg) VS_CLEAN,
            options(nomem, nostack),
        );
    }
}
680
681/// Get the vector length in bytes (vlenb = VLEN / 8)
682///
683/// Returns 0 if the Vector extension is not available.
684#[inline]
685pub fn get_vlenb() -> usize {
686    if !is_vector_enabled() {
687        return 0;
688    }
689    let vlenb: usize;
690    unsafe {
691        asm!(
692            "csrr {0}, vlenb",
693            out(reg) vlenb,
694            options(nomem, nostack),
695        );
696    }
697    vlenb
698}