kernel/arch/riscv64/
context.rs

1//! Kernel context switching for RISC-V 64-bit
2//!
3//! This module implements kernel-level context switching between tasks.
4//! It handles saving and restoring callee-saved registers when switching
5//! between kernel threads.
6
7use alloc::boxed::Box;
8use core::arch::naked_asm;
9
10use crate::arch::Trapframe;
11use crate::mem::page::{Page, allocate_boxed_pages};
12use crate::vm::vmem::MemoryArea;
13
14/// Kernel context for RISC-V 64-bit
15///
16/// Contains callee-saved registers that need to be preserved across
17/// function calls and context switches in kernel mode, as well as
18/// the kernel stack information.
19#[repr(C, align(16))]
20#[derive(Debug, Clone)]
21pub struct KernelContext {
22    /// Stack pointer
23    pub sp: u64,
24    /// Return address
25    pub ra: u64,
26    /// Saved registers s0-s11 (callee-saved)
27    pub s: [u64; 12],
28    /// Kernel stack pages for this context (page-aligned, contiguous)
29    pub kernel_stack: Box<[Page]>,
30}
31
32impl KernelContext {
33    /// Create a new kernel context with kernel stack
34    ///
35    /// # Returns
36    /// A new KernelContext with allocated kernel stack ready for scheduling
37    pub fn new() -> Self {
38        // Allocate page-aligned contiguous pages for the kernel stack
39        let num_pages = crate::environment::TASK_KERNEL_STACK_SIZE / crate::environment::PAGE_SIZE;
40        let kernel_stack = allocate_boxed_pages(num_pages);
41        let stack_top = kernel_stack.as_ptr() as u64
42            + (kernel_stack.len() * crate::environment::PAGE_SIZE) as u64;
43
44        let trapframe_size = core::mem::size_of::<Trapframe>() as u64;
45        let trapframe_align = core::mem::align_of::<Trapframe>() as u64;
46        debug_assert!(trapframe_align.is_power_of_two());
47        let trapframe_addr = (stack_top - trapframe_size) & !(trapframe_align - 1);
48
49        Self {
50            sp: trapframe_addr, // Reserve aligned space for trapframe
51            ra: crate::task::task_initial_kernel_entrypoint as u64,
52            s: [0; 12],
53            kernel_stack,
54        }
55    }
56
57    /// Get the bottom of the kernel stack
58    pub fn get_kernel_stack_bottom_paddr(&self) -> u64 {
59        (self.kernel_stack.as_ptr() as u64)
60            + (self.kernel_stack.len() as u64 * crate::environment::PAGE_SIZE as u64)
61    }
62
63    pub fn get_kernel_stack_memory_area_paddr(&self) -> MemoryArea {
64        MemoryArea::new(
65            self.kernel_stack.as_ptr() as usize,
66            (self.get_kernel_stack_bottom_paddr() as usize) - 1,
67        )
68    }
69
70    pub fn get_kernel_stack_paddr(&self) -> *const u8 {
71        self.kernel_stack.as_ptr() as *const u8
72    }
73
74    /// Set entry point for this context
75    ///
76    /// # Arguments
77    /// * `entry_point` - Function address to set as entry point
78    ///
79    pub fn set_entry_point(&mut self, entry_point: u64) {
80        self.ra = entry_point;
81    }
82
83    /// Get entry point of this context
84    ///
85    /// # Returns
86    ///
87    /// Function address of the entry point
88    pub fn get_entry_point(&self) -> u64 {
89        self.ra
90    }
91
92    // Set stack pointer for this context (VA)
93    pub fn set_sp(&mut self, sp_vaddr: u64) {
94        self.sp = sp_vaddr;
95    }
96
97    /// Get a mutable reference to the trapframe
98    ///
99    /// The trapframe is located at the top of the kernel stack, reserved during
100    /// context creation. This provides access to the user-space register state.
101    ///
102    /// # Returns
103    /// A mutable reference to the Trapframe, or None if no kernel stack is allocated
104    pub fn get_trapframe(&mut self) -> &mut Trapframe {
105        let stack_top = self.kernel_stack.as_ptr() as usize
106            + (self.kernel_stack.len() * crate::environment::PAGE_SIZE);
107        let trapframe_size = core::mem::size_of::<Trapframe>();
108        let trapframe_align = core::mem::align_of::<Trapframe>();
109        debug_assert!(trapframe_align.is_power_of_two());
110
111        let trapframe_addr = (stack_top - trapframe_size) & !(trapframe_align - 1);
112        debug_assert_eq!(trapframe_addr % trapframe_align, 0);
113        unsafe { &mut *(trapframe_addr as *mut Trapframe) }
114    }
115}
116
117/// Switch from current context to target context
118///
119/// This function saves the current kernel context and loads the target context.
120/// When the target task is later switched away from, it will resume execution
121/// right after this function call.
122///
123/// # Arguments
124/// * `current` - Pointer to current task's kernel context (will be saved)
125/// * `target` - Pointer to target task's kernel context (will be loaded)
126///
127/// # Safety
128/// This function manipulates CPU registers directly and must only be called
129/// with valid context pointers. The caller must ensure proper stack alignment
130/// and that both contexts point to valid memory.
131#[unsafe(naked)]
132pub unsafe extern "C" fn switch_to(current: *mut KernelContext, target: *const KernelContext) {
133    naked_asm!(
134        // Save current context
135        "sd sp, 0(a0)",    // Save stack pointer
136        "sd ra, 8(a0)",    // Save return address
137        "sd s0, 16(a0)",   // Save s0
138        "sd s1, 24(a0)",   // Save s1
139        "sd s2, 32(a0)",   // Save s2
140        "sd s3, 40(a0)",   // Save s3
141        "sd s4, 48(a0)",   // Save s4
142        "sd s5, 56(a0)",   // Save s5
143        "sd s6, 64(a0)",   // Save s6
144        "sd s7, 72(a0)",   // Save s7
145        "sd s8, 80(a0)",   // Save s8
146        "sd s9, 88(a0)",   // Save s9
147        "sd s10, 96(a0)",  // Save s10
148        "sd s11, 104(a0)", // Save s11
149        // Load target context
150        "ld sp, 0(a1)",    // Load stack pointer
151        "ld ra, 8(a1)",    // Load return address
152        "ld s0, 16(a1)",   // Load s0
153        "ld s1, 24(a1)",   // Load s1
154        "ld s2, 32(a1)",   // Load s2
155        "ld s3, 40(a1)",   // Load s3
156        "ld s4, 48(a1)",   // Load s4
157        "ld s5, 56(a1)",   // Load s5
158        "ld s6, 64(a1)",   // Load s6
159        "ld s7, 72(a1)",   // Load s7
160        "ld s8, 80(a1)",   // Load s8
161        "ld s9, 88(a1)",   // Load s9
162        "ld s10, 96(a1)",  // Load s10
163        "ld s11, 104(a1)", // Load s11
164        // Return to target context
165        "ret",
166    );
167}