kernel/
timer.rs

1//! Kernel timer module.
2//!
3//! This module provides the kernel timer functionality, which is responsible for
4//! managing the system timer and scheduling tasks based on time intervals.
5//!
6
7use crate::arch::Trapframe;
8use crate::arch::timer::ArchTimer;
9use crate::environment::MAX_NUM_CPUS;
10use crate::sched::scheduler::get_scheduler;
11use core::cell::UnsafeCell;
12use core::sync::atomic::{AtomicU64, Ordering};
13extern crate alloc;
14use alloc::collections::BinaryHeap;
15use alloc::sync::{Arc, Weak};
16use core::cmp::Ordering as CmpOrdering;
17
pub struct KernelTimer {
    // One hardware timer handle per CPU, indexed by CPU id.
    // SAFETY: Each CPU only accesses its own timer via cpu_id index.
    // UnsafeCell allows per-CPU mutable access without data races.
    core_local_timer: [UnsafeCell<ArchTimer>; MAX_NUM_CPUS],
    // Tick interval; `new()` initializes it to u64::MAX.
    // NOTE(review): nothing in this file writes this field after construction —
    // confirm whether external modules rely on it before changing.
    pub interval: u64,
}
24
// SAFETY: KernelTimer is thread-safe because each CPU only accesses its own timer.
// The ArchTimer instances are per-CPU, and the hardware registers are CPU-local.
// This invariant is upheld by convention only: callers must pass their own CPU id
// to the accessor methods; there is no runtime enforcement of the discipline.
unsafe impl Sync for KernelTimer {}
28
29static KERNEL_TIMER: spin::Once<KernelTimer> = spin::Once::new();
30
31pub fn get_kernel_timer() -> &'static KernelTimer {
32    KERNEL_TIMER.call_once(|| KernelTimer::new())
33}
34
35impl KernelTimer {
36    fn new() -> Self {
37        KernelTimer {
38            core_local_timer: core::array::from_fn(|_| UnsafeCell::new(ArchTimer::new())),
39            interval: 0xffffffff_ffffffff,
40        }
41    }
42
43    /// Initialize the timer for a specific CPU.
44    /// This must be called by each CPU individually during its initialization.
45    ///
46    /// # Arguments
47    /// * `cpu_id` - The ID of the CPU whose timer should be initialized
48    pub fn init(&self, cpu_id: usize) {
49        // SAFETY: Only the specified CPU's timer is accessed, maintaining
50        // the per-CPU access invariant.
51        unsafe { (*self.core_local_timer[cpu_id].get()).stop() };
52    }
53
54    pub fn start(&self, cpu_id: usize) {
55        // SAFETY: Each CPU only accesses its own timer
56        unsafe { (*self.core_local_timer[cpu_id].get()).start() };
57    }
58
59    pub fn stop(&self, cpu_id: usize) {
60        // SAFETY: Each CPU only accesses its own timer
61        unsafe { (*self.core_local_timer[cpu_id].get()).stop() };
62    }
63
64    pub fn restart(&self, cpu_id: usize) {
65        self.stop(cpu_id);
66        self.start(cpu_id);
67    }
68
69    /* Set the interval in microseconds */
70    pub fn set_interval_us(&self, cpu_id: usize, interval: u64) {
71        // SAFETY: Each CPU only accesses its own timer
72        unsafe { (*self.core_local_timer[cpu_id].get()).set_interval_us(interval) };
73    }
74
75    pub fn get_time_us(&self, cpu_id: usize) -> u64 {
76        // SAFETY: Each CPU only accesses its own timer
77        unsafe { (*self.core_local_timer[cpu_id].get()).get_time_us() }
78    }
79}
80
// Global tick counter (monotonic, incremented by the timer interrupt in `tick`).
static TICK_COUNT: AtomicU64 = AtomicU64::new(0);
83
84/// Increment the global tick counter. Call this from the timer interrupt handler.
85pub fn tick(trapframe: &mut Trapframe) {
86    let cpu_id = crate::arch::get_cpu().get_cpuid();
87    let timer = get_kernel_timer();
88    timer.set_interval_us(cpu_id, TICK_INTERVAL_US);
89    timer.start(cpu_id);
90    let now = TICK_COUNT.fetch_add(1, Ordering::Relaxed) + 1;
91    check_software_timers(now);
92    // Call scheduler tick handler to manage time slices
93    let scheduler = get_scheduler();
94    // crate::println!("[timer] Tick: {}, CPU: {}", now, cpu_id);
95    scheduler.on_tick(cpu_id, trapframe);
96}
97
98/// Get the current tick count (monotonic, since boot)
99pub fn get_tick() -> u64 {
100    TICK_COUNT.load(Ordering::Relaxed)
101}
102
103pub fn get_time_ns() -> u64 {
104    let cpu_id = crate::arch::get_cpu().get_cpuid();
105    let timer = get_kernel_timer();
106    timer.get_time_us(cpu_id) * 1_000
107}
108
109pub fn get_time_us() -> u64 {
110    let cpu_id = crate::arch::get_cpu().get_cpuid();
111    let timer = get_kernel_timer();
112    timer.get_time_us(cpu_id)
113}
114
/// Trait for timer expiration callback.
///
/// Handlers are invoked from `check_software_timers`, which runs on the
/// `tick()` interrupt path, so implementations should be short and
/// non-blocking.
pub trait TimerHandler: Send + Sync {
    /// Called once when the timer expires. `context` is the value that was
    /// passed to `add_timer`. Receives `self: Arc<Self>` so the handler can
    /// retain an owning reference to itself (e.g. to re-arm a timer).
    fn on_timer_expired(self: Arc<Self>, context: usize);
}
119
/// Software timer structure.
///
/// Ordered by `expires` (see the `Ord` impl below) so the global
/// `BinaryHeap` behaves as a min-heap on expiration time.
pub struct SoftwareTimer {
    pub id: u64,                         // Unique timer ID (from TIMER_ID_COUNTER)
    pub expires: u64,                    // Expiration tick (compared against the global tick count)
    pub handler: Weak<dyn TimerHandler>, // Weak reference to callback handler; timer fizzles if dropped
    pub context: usize,                  // Opaque user context handed back to the handler
    pub active: bool,                    // Is this timer active? NOTE: the authoritative active state
                                         // lives in TIMER_ACTIVE_FLAGS; cancel/check read that map,
                                         // not this field.
}
128
// Global timer ID counter; monotonically increasing, starting at 1.
static TIMER_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
131
132impl PartialEq for SoftwareTimer {
133    fn eq(&self, other: &Self) -> bool {
134        self.expires == other.expires
135            && self.context == other.context
136            && self.active == other.active
137    }
138}
139
140impl Eq for SoftwareTimer {}
141
142impl Ord for SoftwareTimer {
143    fn cmp(&self, other: &Self) -> CmpOrdering {
144        // Reverse order for min-heap (BinaryHeap is max-heap by default)
145        other.expires.cmp(&self.expires)
146    }
147}
148
149impl PartialOrd for SoftwareTimer {
150    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
151        Some(self.cmp(other))
152    }
153}
154
155use alloc::collections::BTreeMap;
156use spin::{Mutex, RwLock};
157
// Heap-based timer list (protected by spin::Mutex).
// Behaves as a min-heap on `expires` via SoftwareTimer's reversed Ord impl.
static SOFTWARE_TIMER_HEAP: Mutex<BinaryHeap<SoftwareTimer>> = Mutex::new(BinaryHeap::new());

// Active timer flags (protected by RwLock for efficient concurrent reads)
// Maps timer ID -> active status. Entries are inserted by add_timer, flipped
// to false by cancel_timer, and removed by check_software_timers.
static TIMER_ACTIVE_FLAGS: RwLock<BTreeMap<u64, bool>> = RwLock::new(BTreeMap::new());
164
165/// Add a new software timer. Returns timer id.
166pub fn add_timer(expires: u64, handler: &Arc<dyn TimerHandler>, context: usize) -> u64 {
167    let id = TIMER_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
168    let timer = SoftwareTimer {
169        id,
170        expires,
171        handler: Arc::downgrade(handler),
172        context,
173        active: true,
174    };
175
176    // Mark as active in the flags map
177    TIMER_ACTIVE_FLAGS.write().insert(id, true);
178
179    SOFTWARE_TIMER_HEAP.lock().push(timer);
180    id
181}
182
183/// Cancel a timer by id (O(1) operation - just marks as inactive)
184pub fn cancel_timer(id: u64) {
185    // Simply mark as inactive - the timer will be skipped in check_software_timers()
186    // and cleaned up when it expires
187    if let Some(active) = TIMER_ACTIVE_FLAGS.write().get_mut(&id) {
188        *active = false;
189    }
190}
191
192/// Check if a timer is active (used by check_software_timers)
193#[inline]
194fn is_timer_active(id: u64) -> bool {
195    TIMER_ACTIVE_FLAGS.read().get(&id).copied().unwrap_or(false)
196}
197
198/// Call this from tick() to check and fire expired timers
199fn check_software_timers(now: u64) {
200    use alloc::vec::Vec;
201    let mut expired = Vec::new();
202    let mut cleanup_ids = Vec::new();
203
204    {
205        let mut heap = SOFTWARE_TIMER_HEAP.lock();
206        let active_flags = TIMER_ACTIVE_FLAGS.read();
207
208        while let Some(timer) = heap.peek() {
209            if timer.expires <= now {
210                let timer = heap.pop().unwrap();
211                // Check if still active
212                if active_flags.get(&timer.id).copied().unwrap_or(false) {
213                    expired.push(timer);
214                } else {
215                    // Mark for cleanup (will be done outside locks)
216                    cleanup_ids.push(timer.id);
217                }
218            } else {
219                break;
220            }
221        }
222    } // Unlock the heap
223
224    // Clean up inactive timers (outside of read lock to avoid deadlock)
225    if !cleanup_ids.is_empty() {
226        let mut active_flags = TIMER_ACTIVE_FLAGS.write();
227        for id in cleanup_ids {
228            active_flags.remove(&id);
229        }
230    }
231
232    // Execute callbacks outside of all locks
233    for timer in expired {
234        // Double-check active status before executing
235        let should_execute = {
236            let active_flags = TIMER_ACTIVE_FLAGS.read();
237            active_flags.get(&timer.id).copied().unwrap_or(false)
238        };
239
240        if should_execute {
241            // Clean up from flags map before executing handler
242            TIMER_ACTIVE_FLAGS.write().remove(&timer.id);
243
244            if let Some(handler) = timer.handler.upgrade() {
245                handler.on_timer_expired(timer.context);
246            }
247        }
248    }
249}
250
// Tick interval in microseconds (10ms tick).
pub const TICK_INTERVAL_US: u64 = 10_000;

/// Convert milliseconds to ticks (truncating toward zero).
#[inline]
pub fn ms_to_ticks(ms: u64) -> u64 {
    us_to_ticks(ms * 1_000)
}

/// Convert microseconds to ticks (truncating toward zero).
#[inline]
pub fn us_to_ticks(us: u64) -> u64 {
    us / TICK_INTERVAL_US
}

/// Convert nanoseconds to ticks (truncating toward zero).
#[inline]
pub fn ns_to_ticks(ns: u64) -> u64 {
    us_to_ticks(ns / 1_000)
}

/// Convert ticks to milliseconds (truncating toward zero).
#[inline]
pub fn ticks_to_ms(ticks: u64) -> u64 {
    ticks_to_us(ticks) / 1_000
}

/// Convert ticks to microseconds (exact).
#[inline]
pub fn ticks_to_us(ticks: u64) -> u64 {
    ticks * TICK_INTERVAL_US
}

/// Convert ticks to nanoseconds (exact while within u64 range).
#[inline]
pub fn ticks_to_ns(ticks: u64) -> u64 {
    ticks_to_us(ticks) * 1_000
}
289
290// static mut TEST_HANDLER: Option<Arc<dyn TimerHandler>> = None;
291
292// // TEST
293// fn register_test_timer() {
294//     use alloc::sync::Arc;
295
296//     struct TestHandler;
297//     impl TimerHandler for TestHandler {
298//         #[allow(static_mut_refs)]
299//         fn on_timer_expired(&self, context: usize) {
300//             crate::early_println!("[Software Timer] Test timer expired with context: {}", context);
301//             if let Some(handler) = unsafe { TEST_HANDLER.clone() } {
302//                 crate::early_println!("[Software Timer] Test handler is still available.");
303//                 let handler = handler.clone();
304//                 add_timer(get_tick() + 100, &handler, context);
305//             } else {
306//                 crate::early_println!("[Software Timer] Test handler is no longer available.");
307//             }
308//         }
309//     }
310
311//     let handler: Arc<dyn TimerHandler>  = Arc::new(TestHandler);
312//     let target_tick = get_tick() + 100; // 100 ticks from now
313//     let id = add_timer(target_tick, &handler, 42);
314//     crate::early_println!("Test timer registered with ID: {}, tick: {}", id, target_tick);
315//     unsafe {
316//         TEST_HANDLER = Some(handler);
317//     }
318// }
319
320// late_initcall!(register_test_timer);