kernel/sync/
waker.rs

1//! Waker - Synchronization primitive for task waiting and waking
2//!
3//! This module provides the `Waker` struct, which manages asynchronous task waiting
4//! and waking mechanisms. It allows tasks to block on specific events and be woken
5//! up when those events occur, such as I/O completion or interrupt handling.
6
7extern crate alloc;
8
9use crate::arch::Trapframe;
10use crate::sched::scheduler::get_scheduler;
11use crate::task::{BlockedType, TaskState};
12use alloc::collections::VecDeque;
13use core::fmt;
14use spin::Mutex;
15
16/// A synchronization primitive that manages waiting and waking of tasks
17///
18/// The `Waker` struct provides a mechanism for tasks to wait for specific events
19/// and be woken up when those events occur. It maintains a queue of waiting task IDs
20/// and provides methods to block the current task or wake up waiting tasks.
21///
22/// # Examples
23///
24/// ```
25/// // Create a new interruptible waker for UART receive events
26/// static UART_RX_WAKER: Waker = Waker::new_interruptible("uart_rx");
27///
28/// // In a blocking read function
29/// UART_RX_WAKER.wait();
30///
31/// // In an interrupt handler
32/// UART_RX_WAKER.wake_one();
33/// ```
34pub struct Waker {
35    /// Queue of waiting task IDs
36    wait_queue: Mutex<VecDeque<usize>>,
37    /// The type of blocking this waker uses (interruptible or uninterruptible)
38    block_type: BlockedType,
39    /// Human-readable name for debugging purposes
40    name: &'static str,
41}
42
43impl Waker {
44    /// Create a new interruptible waker
45    ///
46    /// Interruptible wakers allow waiting tasks to be interrupted by signals
47    /// or other asynchronous events. This is suitable for user I/O operations
48    /// where cancellation might be needed.
49    ///
50    /// # Arguments
51    ///
52    /// * `name` - A human-readable name for debugging purposes
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// static KEYBOARD_WAKER: Waker = Waker::new_interruptible("keyboard");
58    /// ```
59    pub const fn new_interruptible(name: &'static str) -> Self {
60        Self {
61            wait_queue: Mutex::new(VecDeque::new()),
62            block_type: BlockedType::Interruptible,
63            name,
64        }
65    }
66
67    /// Create a new uninterruptible waker
68    ///
69    /// Uninterruptible wakers ensure that waiting tasks cannot be interrupted
70    /// and will wait until the event occurs. This is suitable for critical
71    /// operations like disk I/O where data integrity is important.
72    ///
73    /// # Arguments
74    ///
75    /// * `name` - A human-readable name for debugging purposes
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// static DISK_IO_WAKER: Waker = Waker::new_uninterruptible("disk_io");
81    /// ```
82    pub const fn new_uninterruptible(name: &'static str) -> Self {
83        Self {
84            wait_queue: Mutex::new(VecDeque::new()),
85            block_type: BlockedType::Uninterruptible,
86            name,
87        }
88    }
89
90    /// Block the current task and add it to the wait queue
91    ///
92    /// This method puts the current task into a blocked state and adds its ID
93    /// to the wait queue. The task will remain blocked until another part of
94    /// the system calls `wake_one()` or `wake_all()` on this waker.
95    ///
96    /// # Behavior
97    ///
98    /// 1. Gets the current task ID
99    /// 2. Sets the task state to `Blocked(self.block_type)` FIRST
100    /// 3. Adds the task ID to the wait queue
101    /// 4. Calls the scheduler to yield CPU to other tasks
102    /// 5. Returns when the task is woken up and rescheduled
103    ///
104    /// # Note
105    ///
106    /// This function returns when the task is woken up by another part of the system.
107    /// The calling code can then continue execution, typically to re-check the
108    /// condition that caused the wait.
109    ///
110    /// # Critical Section
111    ///
112    /// To prevent race conditions between wait() and wake_one()/wake_all():
113    /// 1. Set task state to Blocked BEFORE adding to queue
114    /// 2. This ensures wake_task() can safely operate even if called immediately
115    pub fn wait(&self, task_id: usize, trapframe: &mut Trapframe) {
116        // CRITICAL: Set task state to Blocked FIRST, before adding to queue
117        // This prevents race condition where wake_one() is called after queue.push_back()
118        // but before set_state(), which would leave the task in Running state but not in queue
119        if let Some(task) = get_scheduler().get_task_by_id(task_id) {
120            task.set_state(TaskState::Blocked(self.block_type));
121        } else {
122            panic!("[WAKER] Task ID {} not found in scheduler", task_id);
123        }
124
125        // Memory barrier to ensure state change is visible before queue operation
126        core::sync::atomic::fence(core::sync::atomic::Ordering::SeqCst);
127
128        // Now add task to wait queue - at this point task is already Blocked
129        // Even if wake_one() is called immediately, wake_task() will work correctly
130        {
131            let mut queue = self.wait_queue.lock();
132            queue.push_back(task_id);
133        }
134
135        // Memory barrier to ensure queue addition is visible before yielding
136        core::sync::atomic::fence(core::sync::atomic::Ordering::SeqCst);
137
138        // Yield CPU to scheduler - returns when woken
139        get_scheduler().schedule(trapframe);
140    }
141
142    /// Block the task until woken or the timeout elapses.
143    ///
144    /// Returns true if woken by event, false if timeout elapsed.
145    pub fn wait_with_timeout(
146        &self,
147        task_id: usize,
148        trapframe: &mut Trapframe,
149        timeout_ticks: Option<u64>,
150    ) -> bool {
151        if matches!(timeout_ticks, Some(0)) {
152            return false;
153        }
154
155        if let Some(ticks) = timeout_ticks {
156            use crate::timer::{TimerHandler, add_timer, cancel_timer, get_tick};
157            use alloc::sync::Arc;
158            use core::sync::atomic::{AtomicBool, Ordering};
159
160            struct TimeoutWake {
161                task_id: usize,
162                timed_out: AtomicBool,
163            }
164
165            impl TimerHandler for TimeoutWake {
166                fn on_timer_expired(self: Arc<Self>, _context: usize) {
167                    self.timed_out.store(true, Ordering::SeqCst);
168                    let scheduler = get_scheduler();
169                    let _ = scheduler.wake_task(self.task_id);
170                }
171            }
172
173            let handler: Arc<TimeoutWake> = Arc::new(TimeoutWake {
174                task_id,
175                timed_out: AtomicBool::new(false),
176            });
177            let handler_ref: Arc<dyn TimerHandler> = handler.clone();
178            let id = add_timer(get_tick().saturating_add(ticks), &handler_ref, 0);
179
180            self.wait(task_id, trapframe);
181
182            cancel_timer(id);
183
184            !handler.timed_out.load(Ordering::SeqCst)
185        } else {
186            self.wait(task_id, trapframe);
187            true
188        }
189    }
190
191    // /// Block any task (not limited to the current task) and add it to the wait queue
192    // ///
193    // /// This method is intended for blocking tasks other than the current one.
194    // /// It sets the specified task's state to Blocked and adds it to the wait queue.
195    // /// No scheduler switch or CPU state saving is performed.
196    // ///
197    // /// # Arguments
198    // /// * `task_id` - The ID of the task to be blocked
199    // pub fn block(&self, task_id: usize) {
200    //     {
201    //         let mut queue = self.wait_queue.lock();
202    //         queue.push_back(task_id);
203    //     }
204
205    //     if let Some(task) = get_scheduler().get_task_by_id(task_id) {
206    //         // Set task state to blocked
207    //         task.set_state(TaskState::Blocked(self.block_type));
208    //     } else {
209    //         panic!("[WAKER] Task ID {} not found in scheduler", task_id);
210    //     }
211
212    //     // Yield CPU to scheduler - this will return when the task is woken up
213    //     get_scheduler().schedule(cpu);
214
215    //     // When we reach here, the task has been woken up and rescheduled
216    //     // crate::println!("[WAKER] Task {} woken up from waker '{}'", task_id, self.name);
217    // }
218
219    /// Wake up one waiting task
220    ///
221    /// This method removes one task from the wait queue and moves it from
222    /// the blocked queue to the ready queue, making it eligible for scheduling again.
223    ///
224    /// # Returns
225    ///
226    /// * `true` if a task was woken up
227    /// * `false` if the wait queue was empty
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// // In an interrupt handler
233    /// if UART_RX_WAKER.wake_one() {
234    ///     // A task was woken up
235    /// }
236    /// ```
237    pub fn wake_one(&self) -> bool {
238        let task_id = {
239            let mut queue = self.wait_queue.lock();
240            queue.pop_front()
241        };
242
243        if let Some(task_id) = task_id {
244            // Use the scheduler's wake_task method to move from blocked to ready queue
245            get_scheduler().wake_task(task_id)
246        } else {
247            false
248        }
249    }
250
251    /// Wake up all waiting tasks
252    ///
253    /// This method removes all tasks from the wait queue and moves them from
254    /// the blocked queue to the ready queue, making them all eligible for scheduling again.
255    ///
256    /// # Returns
257    ///
258    /// The number of tasks that were woken up
259    ///
260    /// # Examples
261    ///
262    /// ```
263    /// // Wake all tasks waiting for a broadcast event
264    /// let woken_count = BROADCAST_WAKER.wake_all();
265    /// println!("Woke up {} tasks", woken_count);
266    /// ```
267    pub fn wake_all(&self) -> usize {
268        let task_ids = {
269            let mut queue = self.wait_queue.lock();
270            let ids: VecDeque<usize> = queue.drain(..).collect();
271            ids
272        };
273
274        let mut woken_count = 0;
275        for task_id in task_ids {
276            // Use the scheduler's wake_task method to move from blocked to ready queue
277            if get_scheduler().wake_task(task_id) {
278                woken_count += 1;
279            }
280        }
281
282        woken_count
283    }
284
285    /// Get the blocking type of this waker
286    ///
287    /// # Returns
288    ///
289    /// The `BlockedType` (either `Interruptible` or `Uninterruptible`)
290    pub fn block_type(&self) -> BlockedType {
291        self.block_type
292    }
293
294    /// Get the number of tasks currently waiting
295    ///
296    /// # Returns
297    ///
298    /// The number of tasks in the wait queue
299    pub fn waiting_count(&self) -> usize {
300        self.wait_queue.lock().len()
301    }
302
303    /// Get the name of this waker
304    ///
305    /// # Returns
306    ///
307    /// The human-readable name for debugging purposes
308    pub fn name(&self) -> &'static str {
309        self.name
310    }
311
312    /// Get a list of task IDs currently waiting in the queue
313    ///
314    /// This method returns a snapshot of all task IDs currently waiting
315    /// in this waker's queue. Useful for debugging and monitoring.
316    ///
317    /// # Returns
318    ///
319    /// A vector containing all waiting task IDs
320    ///
321    /// # Examples
322    ///
323    /// ```
324    /// let waiting_tasks = waker.get_waiting_task_ids();
325    /// println!("Tasks waiting: {:?}", waiting_tasks);
326    /// ```
327    pub fn get_waiting_task_ids(&self) -> VecDeque<usize> {
328        self.wait_queue.lock().clone()
329    }
330
331    /// Check if a specific task is waiting in this waker
332    ///
333    /// # Arguments
334    ///
335    /// * `task_id` - The ID of the task to check
336    ///
337    /// # Returns
338    ///
339    /// `true` if the task is waiting in this waker, `false` otherwise
340    pub fn is_task_waiting(&self, task_id: usize) -> bool {
341        self.wait_queue.lock().contains(&task_id)
342    }
343
344    /// Get detailed statistics about this waker
345    ///
346    /// This method provides detailed information about the current state
347    /// of the waker, including all waiting tasks and their metadata.
348    ///
349    /// # Returns
350    ///
351    /// A `WakerStats` struct containing comprehensive state information
352    ///
353    /// # Examples
354    ///
355    /// ```
356    /// let stats = uart_waker.get_stats();
357    /// // Use Debug trait to print the stats
358    /// ```
359    pub fn get_stats(&self) -> WakerStats {
360        let waiting_tasks = self.wait_queue.lock();
361        WakerStats {
362            name: self.name,
363            block_type: self.block_type,
364            waiting_count: waiting_tasks.len(),
365            waiting_task_ids: waiting_tasks.clone(),
366        }
367    }
368
369    /// Print debug information about this waker
370    ///
371    /// Outputs detailed information about the waker's current state
372    /// including name, blocking type, waiting task count, and task IDs.
373    /// Useful for debugging and monitoring system state.
374    ///
375    /// # Examples
376    ///
377    /// ```
378    /// waker.debug_print();
379    /// // Output:
380    /// // [Waker DEBUG] uart_rx: Interruptible, 3 waiting tasks: [42, 137, 89]
381    /// ```
382    /// Check if the waker has any waiting tasks
383    ///
384    /// # Returns
385    ///
386    /// `true` if there are no waiting tasks, `false` otherwise
387    pub fn is_empty(&self) -> bool {
388        self.wait_queue.lock().is_empty()
389    }
390
391    /// Clear all waiting tasks without waking them
392    ///
393    /// This is a dangerous operation that should only be used in
394    /// exceptional circumstances like system cleanup or error recovery.
395    /// The tasks will remain in blocked state and need to be handled
396    /// separately.
397    ///
398    /// # Returns
399    ///
400    /// The number of tasks that were removed from the queue
401    ///
402    /// # Safety
403    ///
404    /// This operation can leave tasks in a permanently blocked state.
405    /// Use with extreme caution.
406    pub fn clear_queue(&self) -> usize {
407        let mut queue = self.wait_queue.lock();
408        let count = queue.len();
409        queue.clear();
410        count
411    }
412}
413
414impl fmt::Debug for Waker {
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416        let waiting_tasks = self.wait_queue.lock();
417        f.debug_struct("Waker")
418            .field("name", &self.name)
419            .field("block_type", &self.block_type)
420            .field("waiting_count", &waiting_tasks.len())
421            .field("waiting_task_ids", &*waiting_tasks)
422            .finish()
423    }
424}
425
426/// Statistics and state information for a Waker
427///
428/// This struct provides a comprehensive view of a waker's current state,
429/// useful for debugging, monitoring, and system analysis.
430#[derive(Debug, Clone)]
431pub struct WakerStats {
432    /// Human-readable name of the waker
433    pub name: &'static str,
434    /// The blocking type (Interruptible or Uninterruptible)
435    pub block_type: BlockedType,
436    /// Number of tasks currently waiting
437    pub waiting_count: usize,
438    /// List of task IDs currently waiting
439    pub waiting_task_ids: VecDeque<usize>,
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `waker` carries the expected name and blocking type and
    /// has no waiters.
    fn assert_fresh(waker: &Waker, expected_name: &str, expected_type: BlockedType) {
        assert_eq!(waker.name(), expected_name);
        assert_eq!(waker.block_type(), expected_type);
        assert_eq!(waker.waiting_count(), 0);
    }

    #[test_case]
    fn test_waker_creation() {
        assert_fresh(
            &Waker::new_interruptible("test_int"),
            "test_int",
            BlockedType::Interruptible,
        );
        assert_fresh(
            &Waker::new_uninterruptible("test_unint"),
            "test_unint",
            BlockedType::Uninterruptible,
        );
    }

    #[test_case]
    fn test_wake_empty_queue() {
        let waker = Waker::new_interruptible("empty_test");
        // Nothing is queued, so neither wake variant can find a task.
        assert!(!waker.wake_one());
        assert_eq!(waker.wake_all(), 0);
    }

    #[test_case]
    fn test_debug_functionality() {
        let waker = Waker::new_interruptible("debug_test");

        // Introspection on a waker nobody is waiting on.
        assert!(waker.is_empty());
        assert_eq!(waker.waiting_count(), 0);
        assert!(waker.get_waiting_task_ids().is_empty());
        assert!(!waker.is_task_waiting(42));

        // The stats snapshot mirrors the waker's configuration.
        let WakerStats {
            name,
            block_type,
            waiting_count,
            waiting_task_ids,
        } = waker.get_stats();
        assert_eq!(name, "debug_test");
        assert_eq!(block_type, BlockedType::Interruptible);
        assert_eq!(waiting_count, 0);
        assert!(waiting_task_ids.is_empty());
    }

    #[test_case]
    fn test_debug_trait() {
        // The Debug output must mention the name, the blocking type and the count.
        let rendered = alloc::format!("{:?}", Waker::new_uninterruptible("debug_trait_test"));
        for needle in ["debug_trait_test", "Uninterruptible", "waiting_count: 0"].iter() {
            assert!(rendered.contains(*needle));
        }
    }

    #[test_case]
    fn test_clear_queue() {
        let waker = Waker::new_interruptible("clear_test");

        // Clearing an already-empty queue removes nothing and leaves it empty.
        let removed = waker.clear_queue();
        assert_eq!(removed, 0);
        assert!(waker.is_empty());
    }

    #[test_case]
    fn test_waker_stats_debug() {
        // WakerStats' derived Debug output includes the name and blocking type.
        let rendered = alloc::format!("{:?}", Waker::new_interruptible("stats_test").get_stats());
        for needle in ["stats_test", "Interruptible"].iter() {
            assert!(rendered.contains(*needle));
        }
    }
}