1use alloc::{string::String, string::ToString, sync::Arc};
12use spin::Mutex;
13
14use crate::object::KernelObject;
15use crate::object::capability::selectable::{
16 ReadyInterest, ReadySet, SelectWaitOutcome, Selectable,
17};
18use crate::object::capability::{CloneOps, StreamError, StreamOps};
19use crate::sync::waker::Waker;
20
/// Mutable state of one counter, always accessed under the owning mutex.
struct CounterState {
    /// When `true`, each successful read consumes exactly 1 instead of the
    /// whole counter value.
    semaphore: bool,
    /// Current counter value. Reads consume it; writes add to it.
    counter: u64,
}
28
29struct SharedCounterData {
31 state: Mutex<CounterState>,
33 read_waker: Waker,
35 write_waker: Waker,
37}
38
39impl SharedCounterData {
40 fn new(initval: u32, semaphore: bool) -> Arc<Self> {
41 Arc::new(Self {
42 state: Mutex::new(CounterState {
43 counter: initval as u64,
44 semaphore,
45 }),
46 read_waker: Waker::new_interruptible("counter_read"),
47 write_waker: Waker::new_interruptible("counter_write"),
48 })
49 }
50}
51
52pub struct Counter {
54 data: Arc<SharedCounterData>,
56 id: String,
58 nonblocking: core::sync::atomic::AtomicBool,
60}
61
62impl Counter {
63 pub fn new(initval: u32, semaphore: bool) -> Self {
65 Self {
66 data: SharedCounterData::new(initval, semaphore),
67 id: "counter".to_string(),
68 nonblocking: core::sync::atomic::AtomicBool::new(false),
69 }
70 }
71
72 pub fn create_pair(initval: u32, semaphore: bool) -> (KernelObject, KernelObject) {
74 let data = SharedCounterData::new(initval, semaphore);
75
76 let counter1 = Self {
77 data: data.clone(),
78 id: "counter_1".to_string(),
79 nonblocking: core::sync::atomic::AtomicBool::new(false),
80 };
81
82 let counter2 = Self {
83 data: data.clone(),
84 id: "counter_2".to_string(),
85 nonblocking: core::sync::atomic::AtomicBool::new(false),
86 };
87
88 let obj1 = KernelObject::from_counter(Arc::new(counter1));
90 let obj2 = KernelObject::from_counter(Arc::new(counter2));
91
92 (obj1, obj2)
93 }
94
95 pub fn create_kernel_object(initval: u32, flags: u32) -> KernelObject {
97 const EFD_SEMAPHORE: u32 = 0x00000001;
98 const EFD_NONBLOCK: u32 = 0o00004000;
99
100 let semaphore = (flags & EFD_SEMAPHORE) != 0;
101 let nonblocking = (flags & EFD_NONBLOCK) != 0;
102
103 let mut counter = Self::new(initval, semaphore);
104 if nonblocking {
105 counter
106 .nonblocking
107 .store(true, core::sync::atomic::Ordering::Relaxed);
108 }
109
110 KernelObject::from_counter(Arc::new(counter))
111 }
112
113 fn do_read(&self, buffer: &mut [u8]) -> Result<usize, StreamError> {
115 if buffer.len() < 8 {
116 return Err(StreamError::InvalidArgument);
117 }
118
119 loop {
120 let mut state = self.data.state.lock();
121
122 if state.counter == 0 {
123 if self.nonblocking.load(core::sync::atomic::Ordering::Relaxed) {
125 return Err(StreamError::WouldBlock);
126 }
127
128 return Err(StreamError::WouldBlock);
130 }
131
132 let value = if state.semaphore {
134 state.counter -= 1;
136 1u64
137 } else {
138 let value = state.counter;
140 state.counter = 0;
141 value
142 };
143
144 drop(state);
146
147 buffer[0..8].copy_from_slice(&value.to_ne_bytes());
149
150 self.data.write_waker.wake_all();
152
153 return Ok(8);
154 }
155 }
156
157 fn do_write(&self, buffer: &[u8]) -> Result<usize, StreamError> {
159 if buffer.len() < 8 {
160 return Err(StreamError::InvalidArgument);
161 }
162
163 let mut value_bytes = [0u8; 8];
165 value_bytes.copy_from_slice(&buffer[0..8]);
166 let add_value = u64::from_ne_bytes(value_bytes);
167
168 if add_value == u64::MAX {
170 return Err(StreamError::InvalidArgument);
171 }
172
173 loop {
174 let mut state = self.data.state.lock();
175
176 if state.counter > u64::MAX - add_value - 1 {
178 if self.nonblocking.load(core::sync::atomic::Ordering::Relaxed) {
180 return Err(StreamError::WouldBlock);
181 }
182
183 return Err(StreamError::WouldBlock);
185 }
186
187 state.counter = state.counter.wrapping_add(add_value);
189
190 drop(state);
192
193 self.data.read_waker.wake_all();
195
196 return Ok(8);
197 }
198 }
199}
200
201impl StreamOps for Counter {
202 fn read(&self, buffer: &mut [u8]) -> Result<usize, StreamError> {
203 self.do_read(buffer)
204 }
205
206 fn write(&self, buffer: &[u8]) -> Result<usize, StreamError> {
207 self.do_write(buffer)
208 }
209}
210
211impl CloneOps for Counter {
212 fn custom_clone(&self) -> KernelObject {
213 KernelObject::from_counter(Arc::new(self.clone()))
215 }
216}
217
218impl Clone for Counter {
219 fn clone(&self) -> Self {
220 let mut new_id = String::from(self.id.as_str());
221 new_id.push_str("_clone");
222 Self {
223 data: self.data.clone(),
224 id: new_id,
225 nonblocking: core::sync::atomic::AtomicBool::new(
226 self.nonblocking.load(core::sync::atomic::Ordering::Relaxed),
227 ),
228 }
229 }
230}
231
232impl Selectable for Counter {
233 fn current_ready(&self, interest: ReadyInterest) -> ReadySet {
234 let mut set = ReadySet::none();
235 let state = self.data.state.lock();
236
237 if interest.read {
238 set.read = state.counter > 0;
240 }
241 if interest.write {
242 set.write = state.counter < u64::MAX - 1;
244 }
245 if interest.except {
246 set.except = false;
247 }
248
249 set
250 }
251
252 fn wait_until_ready(
253 &self,
254 interest: ReadyInterest,
255 trapframe: &mut crate::arch::Trapframe,
256 timeout_ticks: Option<u64>,
257 ) -> SelectWaitOutcome {
258 let current = self.current_ready(interest);
259 if (interest.read && current.read) || (interest.write && current.write) {
260 return SelectWaitOutcome::Ready;
261 }
262
263 let task_id = {
264 use crate::arch::get_cpu;
265 use crate::sched::scheduler::get_scheduler;
266 let cpu_id = get_cpu().get_cpuid();
267 get_scheduler().get_current_task_id(cpu_id).unwrap_or(0)
268 };
269
270 let woke = if interest.read {
271 self.data
272 .read_waker
273 .wait_with_timeout(task_id, trapframe, timeout_ticks)
274 } else if interest.write {
275 self.data
276 .write_waker
277 .wait_with_timeout(task_id, trapframe, timeout_ticks)
278 } else {
279 false
280 };
281
282 let after = self.current_ready(interest);
283 if timeout_ticks.is_some() && !woke && !after.read && !after.write {
284 SelectWaitOutcome::TimedOut
285 } else {
286 SelectWaitOutcome::Ready
287 }
288 }
289
290 fn set_nonblocking(&self, enabled: bool) {
291 self.nonblocking
292 .store(enabled, core::sync::atomic::Ordering::Relaxed);
293 }
294
295 fn is_nonblocking(&self) -> bool {
296 self.nonblocking.load(core::sync::atomic::Ordering::Relaxed)
297 }
298}
299
300pub trait CounterObject: StreamOps + Selectable + CloneOps {
302 fn is_semaphore(&self) -> bool;
304}
305
306impl CounterObject for Counter {
307 fn is_semaphore(&self) -> bool {
308 self.data.state.lock().semaphore
309 }
310}