1use alloc::{
21 collections::VecDeque,
22 format,
23 string::{String, ToString},
24 sync::{Arc, Weak},
25 vec::Vec,
26};
27use core::any::Any;
28use spin::RwLock;
29
30use super::{
31 LocalSocketAddress, NetworkManager, ShutdownHow, SocketAddress, SocketControl, SocketDomain,
32 SocketError, SocketObject, SocketProtocol, SocketState, SocketType,
33};
34use crate::ipc::StreamIpcOps;
35use crate::object::KernelObject;
36use crate::object::capability::{
37 ControlOps, ReadyInterest, ReadySet, SelectWaitOutcome, Selectable, StreamError, StreamOps,
38};
39use crate::sync::Waker;
40
/// Compile-time switch for this module's debug tracing.
///
/// While this is `false`, every `localsocket_log!` invocation is a no-op.
const LOCALSOCKET_LOG: bool = false;

/// Forwards its arguments to `crate::println!`, but only when
/// `LOCALSOCKET_LOG` is enabled.
macro_rules! localsocket_log {
    ($($tokens:tt)*) => {
        if LOCALSOCKET_LOG {
            crate::println!($($tokens)*);
        }
    };
}
50
/// Upper bound, in bytes, on the data queued in a single `SocketBuffer`.
const MAX_BUFFER_SIZE: usize = 64 * 1024;

/// Upper bound on the number of kernel objects queued on one socket.
const MAX_HANDLE_QUEUE_SIZE: usize = 64;
57
58struct SocketBuffer {
60 data: RwLock<VecDeque<u8>>,
62 closed: RwLock<bool>,
64}
65
66impl SocketBuffer {
67 fn new() -> Arc<Self> {
68 Arc::new(Self {
69 data: RwLock::new(VecDeque::with_capacity(MAX_BUFFER_SIZE)),
70 closed: RwLock::new(false),
71 })
72 }
73}
74
75pub struct LocalSocket {
81 socket_type: SocketType,
83
84 self_weak: RwLock<Weak<LocalSocket>>,
89
90 protocol: SocketProtocol,
92
93 state: RwLock<SocketState>,
95
96 local_addr: RwLock<Option<String>>,
98
99 peer_addr: RwLock<Option<String>>,
101
102 read_buffer: RwLock<Arc<SocketBuffer>>,
104
105 peer_read_buffer: RwLock<Option<Arc<SocketBuffer>>>,
108
109 peer_socket: RwLock<Option<Weak<LocalSocket>>>,
111
112 backlog: RwLock<Vec<Arc<LocalSocket>>>,
115
116 max_backlog: RwLock<usize>,
118
119 accept_waker: Waker,
121
122 read_waker: Waker,
124
125 handle_waker: Waker,
127
128 handle_queue: RwLock<VecDeque<KernelObject>>,
131
132 nonblocking: RwLock<bool>,
134}
135
136impl LocalSocket {
137 pub(crate) fn init_self_weak(this: &Arc<Self>) {
138 *this.self_weak.write() = Arc::downgrade(this);
139 }
140
141 pub fn from_socket_object(socket: &Arc<dyn SocketObject>) -> Option<&Self> {
146 socket.as_any().downcast_ref::<LocalSocket>()
148 }
149
150 pub fn new(socket_type: SocketType, protocol: SocketProtocol) -> Self {
161 Self {
162 socket_type,
163 protocol,
164 state: RwLock::new(SocketState::Unconnected),
165 local_addr: RwLock::new(None),
166 peer_addr: RwLock::new(None),
167 read_buffer: RwLock::new(SocketBuffer::new()),
168 peer_read_buffer: RwLock::new(None),
169 peer_socket: RwLock::new(None),
170 backlog: RwLock::new(Vec::new()),
171 max_backlog: RwLock::new(0),
172 accept_waker: Waker::new_interruptible("socket_accept"),
173 read_waker: Waker::new_interruptible("socket_read"),
174 handle_waker: Waker::new_interruptible("socket_handle"),
175 handle_queue: RwLock::new(VecDeque::new()),
176 self_weak: RwLock::new(Weak::new()),
177 nonblocking: RwLock::new(false),
178 }
179 }
180
181 pub fn send_handle(&self, object: KernelObject) -> Result<(), crate::ipc::IpcError> {
185 use crate::ipc::IpcError;
186
187 if *self.state.read() != SocketState::Connected {
189 return Err(IpcError::InvalidState);
190 }
191
192 let peer_weak = self.peer_socket.read();
194 let peer_weak_ref = peer_weak.as_ref().ok_or(IpcError::PeerClosed)?;
195 let peer = peer_weak_ref.upgrade().ok_or(IpcError::PeerClosed)?;
196
197 let mut peer_queue = peer.handle_queue.write();
199 if peer_queue.len() >= MAX_HANDLE_QUEUE_SIZE {
200 return Err(IpcError::ChannelFull);
201 }
202
203 peer_queue.push_back(object);
205 drop(peer_queue);
206
207 peer.handle_waker.wake_one();
209
210 Ok(())
211 }
212
213 pub fn send_handle_and_data(
222 &self,
223 object: KernelObject,
224 data: &[u8],
225 ) -> Result<(), crate::ipc::IpcError> {
226 use crate::ipc::IpcError;
227
228 localsocket_log!(
229 "[LocalSocket] send_handle_and_data: self={:p}, data_len={}",
230 self as *const _,
231 data.len()
232 );
233
234 if *self.state.read() != SocketState::Connected {
236 localsocket_log!("[LocalSocket] send_handle_and_data: not connected");
237 return Err(IpcError::InvalidState);
238 }
239
240 let peer_weak = self.peer_socket.read();
242 let peer_weak_ref = peer_weak.as_ref().ok_or(IpcError::PeerClosed)?;
243 let peer = peer_weak_ref.upgrade().ok_or(IpcError::PeerClosed)?;
244
245 localsocket_log!(
246 "[LocalSocket] send_handle_and_data: peer={:p}",
247 peer.as_ref() as *const _
248 );
249
250 let mut peer_handle_queue = peer.handle_queue.write();
252 if peer_handle_queue.len() >= MAX_HANDLE_QUEUE_SIZE {
253 localsocket_log!("[LocalSocket] send_handle_and_data: handle queue full");
254 return Err(IpcError::ChannelFull);
255 }
256
257 let peer_buffer_option = peer.peer_read_buffer.read();
259 let peer_sock_buffer = peer_buffer_option.as_ref().ok_or(IpcError::PeerClosed)?;
260
261 let mut peer_buffer = peer_sock_buffer.data.write();
263 if peer_buffer.len() + data.len() > MAX_BUFFER_SIZE {
264 localsocket_log!(
265 "[LocalSocket] send_handle_and_data: buffer full, current_len={}, adding_len={}",
266 peer_buffer.len(),
267 data.len()
268 );
269 drop(peer_buffer);
270 drop(peer_buffer_option);
271 drop(peer_handle_queue);
272 return Err(IpcError::ChannelFull);
273 }
274
275 localsocket_log!(
276 "[LocalSocket] send_handle_and_data: before send - handle_queue_len={}, buffer_len={}",
277 peer_handle_queue.len(),
278 peer_buffer.len()
279 );
280
281 peer_handle_queue.push_back(object);
283 let queue_len = peer_handle_queue.len();
284 drop(peer_handle_queue);
285
286 peer_buffer.extend(data.iter().copied());
288 let buffer_len = peer_buffer.len();
289 drop(peer_buffer);
290 drop(peer_buffer_option);
291
292 localsocket_log!(
293 "[LocalSocket] send_handle_and_data: after send - handle_queue_len={}, buffer_len={}",
294 queue_len,
295 buffer_len
296 );
297
298 peer.handle_waker.wake_one();
300 peer.read_waker.wake_one();
301
302 Ok(())
303 }
304
305 pub fn recv_handle_and_data(
319 &self,
320 max_data_len: usize,
321 ) -> Result<(KernelObject, Vec<u8>), crate::ipc::IpcError> {
322 use crate::ipc::IpcError;
323
324 localsocket_log!(
325 "[LocalSocket] recv_handle_and_data: self={:p}, max_data_len={}",
326 self as *const _,
327 max_data_len
328 );
329
330 if *self.state.read() != SocketState::Connected {
332 localsocket_log!("[LocalSocket] recv_handle_and_data: not connected");
333 return Err(IpcError::InvalidState);
334 }
335
336 let mut queue = self.handle_queue.write();
338 localsocket_log!(
339 "[LocalSocket] recv_handle_and_data: handle_queue_len={}",
340 queue.len()
341 );
342
343 let handle = match queue.pop_front() {
344 Some(h) => h,
345 None => {
346 localsocket_log!(
347 "[LocalSocket] recv_handle_and_data: handle queue empty - returning ChannelEmpty"
348 );
349 return Err(IpcError::ChannelEmpty);
350 }
351 };
352 drop(queue);
353
354 let read_buffer = self.read_buffer.read();
356 let mut buffer_data = read_buffer.data.write();
357 localsocket_log!(
358 "[LocalSocket] recv_handle_and_data: buffer_len={}, max_data_len={}",
359 buffer_data.len(),
360 max_data_len
361 );
362
363 let actual_len = buffer_data.len().min(max_data_len);
365 let mut data = Vec::with_capacity(actual_len);
366 for _ in 0..actual_len {
367 data.push(buffer_data.pop_front().unwrap());
368 }
369 drop(buffer_data);
370 drop(read_buffer);
371
372 localsocket_log!(
373 "[LocalSocket] recv_handle_and_data: returning handle and {} bytes of data",
374 data.len()
375 );
376
377 Ok((handle, data))
378 }
379
380 pub fn recv_handle(&self) -> Result<KernelObject, crate::ipc::IpcError> {
382 use crate::ipc::IpcError;
383
384 if *self.state.read() != SocketState::Connected {
386 return Err(IpcError::InvalidState);
387 }
388
389 let mut queue = self.handle_queue.write();
391 queue.pop_front().ok_or(IpcError::ChannelEmpty)
392 }
393
394 pub fn accept_blocking(
408 &self,
409 task_id: usize,
410 trapframe: &mut crate::arch::Trapframe,
411 ) -> Result<Arc<dyn SocketObject>, SocketError> {
412 let state = self.state.read();
413 if *state != SocketState::Listening {
414 return Err(SocketError::NotListening);
415 }
416 drop(state);
417
418 loop {
420 {
421 let mut backlog = self.backlog.write();
422 if let Some(client_socket) = backlog.pop() {
423 return Ok(client_socket);
424 }
425 } self.accept_waker.wait(task_id, trapframe);
429
430 }
433 }
434
435 pub fn create_connected_pair(local_addr: String, peer_addr: String) -> (Arc<Self>, Arc<Self>) {
448 let local_read_buffer = SocketBuffer::new();
450 let peer_read_buffer = SocketBuffer::new();
451
452 let local_socket = Arc::new(Self {
455 socket_type: SocketType::Stream,
456 protocol: SocketProtocol::Default,
457 state: RwLock::new(SocketState::Connected),
458 local_addr: RwLock::new(Some(local_addr.clone())),
459 peer_addr: RwLock::new(Some(peer_addr.clone())),
460 read_buffer: RwLock::new(local_read_buffer.clone()),
461 peer_read_buffer: RwLock::new(Some(peer_read_buffer.clone())),
462 peer_socket: RwLock::new(None),
463 backlog: RwLock::new(Vec::new()),
464 max_backlog: RwLock::new(0),
465 accept_waker: Waker::new_interruptible("socket_accept"),
466 read_waker: Waker::new_interruptible("socket_read"),
467 handle_waker: Waker::new_interruptible("socket_handle"),
468 handle_queue: RwLock::new(VecDeque::new()),
469 self_weak: RwLock::new(Weak::new()),
470 nonblocking: RwLock::new(false),
471 });
472
473 let peer_socket = Arc::new(Self {
476 socket_type: SocketType::Stream,
477 protocol: SocketProtocol::Default,
478 state: RwLock::new(SocketState::Connected),
479 local_addr: RwLock::new(Some(peer_addr)),
480 peer_addr: RwLock::new(Some(local_addr)),
481 read_buffer: RwLock::new(peer_read_buffer.clone()),
482 peer_read_buffer: RwLock::new(Some(local_read_buffer.clone())),
483 peer_socket: RwLock::new(None),
484 backlog: RwLock::new(Vec::new()),
485 max_backlog: RwLock::new(0),
486 accept_waker: Waker::new_interruptible("socket_accept"),
487 read_waker: Waker::new_interruptible("socket_read"),
488 handle_waker: Waker::new_interruptible("socket_handle"),
489 handle_queue: RwLock::new(VecDeque::new()),
490 self_weak: RwLock::new(Weak::new()),
491 nonblocking: RwLock::new(false),
492 });
493
494 Self::init_self_weak(&local_socket);
495 Self::init_self_weak(&peer_socket);
496
497 *local_socket.peer_socket.write() = Some(Arc::downgrade(&peer_socket));
499 *peer_socket.peer_socket.write() = Some(Arc::downgrade(&local_socket));
500
501 (local_socket, peer_socket)
502 }
503
504 pub fn recv_handle_blocking(
509 &self,
510 task_id: usize,
511 trapframe: &mut crate::arch::Trapframe,
512 ) -> Result<KernelObject, crate::ipc::IpcError> {
513 use crate::ipc::IpcError;
514
515 loop {
516 {
518 let state = self.state.read();
519 if *state != SocketState::Connected {
520 return Err(IpcError::InvalidState);
521 }
522 }
523
524 {
526 let mut queue = self.handle_queue.write();
527 if let Some(obj) = queue.pop_front() {
528 return Ok(obj);
529 }
530 }
531
532 {
535 let peer_weak_opt = self.peer_socket.read();
536 if let Some(peer_weak) = peer_weak_opt.as_ref() {
537 if let Some(peer) = peer_weak.upgrade() {
538 let peer_state = peer.state.read();
539 if *peer_state == SocketState::Closed {
540 return Err(IpcError::PeerClosed);
541 }
542 } else {
543 return Err(IpcError::PeerClosed);
544 }
545 }
546 }
547
548 {
550 let read_buf = self.read_buffer.read();
551 let closed = read_buf.closed.read();
552 if *closed {
553 return Err(IpcError::PeerClosed);
554 }
555 }
556
557 self.handle_waker.wait(task_id, trapframe);
559 }
560 }
561}
562
563impl StreamOps for LocalSocket {
564 fn read(&self, buffer: &mut [u8]) -> Result<usize, StreamError> {
565 use crate::task::mytask;
566
567 static READ_ATTEMPT_COUNTER: core::sync::atomic::AtomicUsize =
569 core::sync::atomic::AtomicUsize::new(0);
570 let attempt = READ_ATTEMPT_COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
571
572 loop {
573 {
574 let read_buf_arc = self.read_buffer.read();
575 let mut read_data = read_buf_arc.data.write();
576 let is_nonblocking = *self.nonblocking.read();
577 let has_data = !read_data.is_empty();
578
579 if !read_data.is_empty() {
592 let bytes_to_read = buffer.len().min(read_data.len());
593 for i in 0..bytes_to_read {
594 buffer[i] = read_data.pop_front().unwrap();
595 }
596
597 return Ok(bytes_to_read);
605 }
606 } if *self.nonblocking.read() {
610 return Err(StreamError::WouldBlock);
618 }
619
620 {
621 let read_buf_arc = self.read_buffer.read();
622
623 let my_state = *self.state.read();
626 if my_state == SocketState::Closed {
627 return Ok(0);
628 }
629
630 if let Some(peer_weak) = self.peer_socket.read().as_ref() {
632 if let Some(peer) = peer_weak.upgrade() {
633 let peer_state = *peer.state.read();
634 if peer_state == SocketState::Closed {
635 return Ok(0); }
637 } else {
638 return Ok(0); }
640 }
641
642 if *read_buf_arc.closed.read() {
644 return Ok(0);
645 }
646
647 if let Some(task) = mytask() {
649 drop(read_buf_arc);
650
651 self.read_waker.wait(task.get_id(), task.get_trapframe());
653 } else {
654 return Err(StreamError::WouldBlock);
655 }
656 } }
659 }
660
661 fn write(&self, data: &[u8]) -> Result<usize, StreamError> {
662 let peer_buffer = self.peer_read_buffer.read();
663 match peer_buffer.as_ref() {
664 Some(peer_sock_buffer) => {
665 let mut peer_data = peer_sock_buffer.data.write();
666
667 if peer_data.len() + data.len() > MAX_BUFFER_SIZE {
669 return Err(StreamError::WouldBlock);
670 }
671
672 peer_data.extend(data.iter().copied());
674 let bytes_written = data.len();
675
676 drop(peer_data); if let Some(peer_weak) = self.peer_socket.read().as_ref() {
680 if let Some(peer) = peer_weak.upgrade() {
681 peer.read_waker.wake_one();
682 }
683 }
684
685 drop(peer_buffer); Ok(bytes_written)
688 }
689 None => {
690 Err(StreamError::Closed)
692 }
693 }
694 }
695}
696
697impl StreamIpcOps for LocalSocket {
698 fn is_connected(&self) -> bool {
699 *self.state.read() == SocketState::Connected
700 }
701
702 fn peer_count(&self) -> usize {
703 if StreamIpcOps::is_connected(self) {
704 1
705 } else {
706 0
707 }
708 }
709
710 fn description(&self) -> String {
711 let local = self.local_addr.read();
712 let peer = self.peer_addr.read();
713 format!("LocalSocket[{:?} -> {:?}]", local.as_ref(), peer.as_ref())
714 }
715}
716
717impl SocketControl for LocalSocket {
718 fn bind(&self, address: &SocketAddress) -> Result<(), SocketError> {
719 let mut state = self.state.write();
721 if *state != SocketState::Unconnected {
722 return Err(SocketError::AlreadyConnected);
723 }
724
725 let path = match address {
727 SocketAddress::Local(addr) => addr.path(),
728 _ => return Err(SocketError::InvalidAddress),
729 };
730
731 *self.local_addr.write() = Some(path.to_string());
735 *state = SocketState::Bound;
736
737 Ok(())
738 }
739
740 fn listen(&self, backlog: usize) -> Result<(), SocketError> {
741 let mut state = self.state.write();
742 if *state != SocketState::Bound {
743 return Err(SocketError::InvalidOperation);
744 }
745
746 *self.max_backlog.write() = backlog;
747 *state = SocketState::Listening;
748
749 Ok(())
750 }
751
752 fn accept(&self) -> Result<Arc<dyn SocketObject>, SocketError> {
753 let state = self.state.read();
754 if *state != SocketState::Listening {
755 return Err(SocketError::NotListening);
756 }
757 drop(state);
758
759 let mut backlog = self.backlog.write();
761 if let Some(client_socket) = backlog.pop() {
762 Ok(client_socket)
763 } else {
764 Err(SocketError::WouldBlock)
765 }
766 }
767
768 fn connect(&self, address: &SocketAddress) -> Result<(), SocketError> {
769 let state = self.state.read();
771 if *state != SocketState::Unconnected {
772 return Err(SocketError::AlreadyConnected);
773 }
774 drop(state);
775
776 let path = match address {
778 SocketAddress::Local(addr) => addr.path(),
779 _ => return Err(SocketError::InvalidAddress),
780 };
781
782 let manager = NetworkManager::get_manager();
784 let server_socket = match manager.lookup_named_socket(path) {
785 Ok(socket) => socket,
786 Err(e) => return Err(e),
787 };
788
789 if server_socket.state() != SocketState::Listening {
791 return Err(SocketError::ConnectionRefused);
792 }
793
794 let local_addr = format!("anon-{}", self as *const _ as usize);
800
801 let client_read_buffer = SocketBuffer::new();
803 let server_read_buffer = SocketBuffer::new();
804
805 let server_conn = Arc::new(Self {
807 socket_type: SocketType::Stream,
808 protocol: SocketProtocol::Default,
809 state: RwLock::new(SocketState::Connected),
810 local_addr: RwLock::new(Some(path.to_string())),
811 peer_addr: RwLock::new(Some(local_addr.clone())),
812 read_buffer: RwLock::new(server_read_buffer.clone()),
813 peer_read_buffer: RwLock::new(Some(client_read_buffer.clone())),
814 peer_socket: RwLock::new(None), backlog: RwLock::new(Vec::new()),
816 max_backlog: RwLock::new(0),
817 accept_waker: Waker::new_interruptible("socket_accept"),
818 read_waker: Waker::new_interruptible("socket_read"),
819 handle_waker: Waker::new_interruptible("socket_handle"),
820 handle_queue: RwLock::new(VecDeque::new()),
821 self_weak: RwLock::new(Weak::new()),
822 nonblocking: RwLock::new(false),
823 });
824
825 Self::init_self_weak(&server_conn);
826
827 *self.read_buffer.write() = client_read_buffer.clone();
829 *self.peer_read_buffer.write() = Some(server_read_buffer.clone());
830 *self.local_addr.write() = Some(local_addr);
831 *self.peer_addr.write() = Some(path.to_string());
832 *self.state.write() = SocketState::Connected;
833
834 *self.peer_socket.write() = Some(Arc::downgrade(&server_conn));
837
838 let client_arc = self
841 .self_weak
842 .read()
843 .upgrade()
844 .ok_or(SocketError::InvalidOperation)?;
845 *server_conn.peer_socket.write() = Some(Arc::downgrade(&client_arc));
846
847 let server_local = match Self::from_socket_object(&server_socket) {
849 Some(socket) => socket,
850 None => return Err(SocketError::InvalidOperation), };
852 let mut server_backlog = server_local.backlog.write();
853 let max_backlog = *server_local.max_backlog.read();
854
855 if server_backlog.len() >= max_backlog {
856 *self.read_buffer.write() = SocketBuffer::new();
858 *self.state.write() = SocketState::Unconnected;
859 *self.local_addr.write() = None;
860 *self.peer_addr.write() = None;
861 *self.peer_read_buffer.write() = None;
862 *self.peer_socket.write() = None;
863 return Err(SocketError::ConnectionRefused);
864 }
865 server_backlog.push(server_conn);
866 drop(server_backlog); server_local.accept_waker.wake_one();
870
871 Ok(())
872 }
873
874 fn shutdown(&self, how: ShutdownHow) -> Result<(), SocketError> {
875 let mut state = self.state.write();
876 if *state != SocketState::Connected {
877 return Err(SocketError::NotConnected);
878 }
879
880 match how {
883 ShutdownHow::Read | ShutdownHow::Write | ShutdownHow::Both => {
884 *state = SocketState::Closed;
885
886 if let Some(peer_buf) = self.peer_read_buffer.read().as_ref() {
888 *peer_buf.closed.write() = true;
890 }
891
892 if let Some(peer_weak) = self.peer_socket.read().as_ref() {
894 if let Some(peer) = peer_weak.upgrade() {
895 peer.read_waker.wake_one();
897 peer.handle_waker.wake_all();
899 } else {
900 }
902 } else {
903 self.read_waker.wake_all(); self.handle_waker.wake_all(); }
910
911 Ok(())
912 }
913 }
914 }
915
916 fn is_connected(&self) -> bool {
917 *self.state.read() == SocketState::Connected
918 }
919
920 fn state(&self) -> SocketState {
921 *self.state.read()
922 }
923
924 fn getpeername(&self) -> Result<SocketAddress, SocketError> {
925 let peer = self.peer_addr.read();
926 match peer.as_ref() {
927 Some(path) => Ok(SocketAddress::Local(
928 LocalSocketAddress::from_path(path)
929 .unwrap_or_else(|_| LocalSocketAddress::unnamed()),
930 )),
931 None => Err(SocketError::NotConnected),
932 }
933 }
934
935 fn getsockname(&self) -> Result<SocketAddress, SocketError> {
936 let local = self.local_addr.read();
937 match local.as_ref() {
938 Some(path) => Ok(SocketAddress::Local(
939 LocalSocketAddress::from_path(path)
940 .unwrap_or_else(|_| LocalSocketAddress::unnamed()),
941 )),
942 None => Err(SocketError::InvalidOperation),
943 }
944 }
945}
946
947impl SocketObject for LocalSocket {
948 fn socket_type(&self) -> SocketType {
949 self.socket_type
950 }
951
952 fn socket_domain(&self) -> SocketDomain {
953 SocketDomain::Local
954 }
955
956 fn socket_protocol(&self) -> SocketProtocol {
957 self.protocol
958 }
959
960 fn as_any(&self) -> &dyn Any {
961 self
962 }
963
964 fn as_selectable(&self) -> Option<&dyn Selectable> {
965 Some(self)
966 }
967
968 fn as_control_ops(&self) -> Option<&dyn crate::object::capability::ControlOps> {
969 Some(self)
970 }
971}
972
973impl Selectable for LocalSocket {
974 fn current_ready(&self, interest: ReadyInterest) -> ReadySet {
975 let mut ready = ReadySet::none();
976
977 let state = *self.state.read();
978
979 match state {
980 SocketState::Listening => {
981 if interest.read {
983 let backlog = self.backlog.read();
984 ready.read = !backlog.is_empty();
985 }
986 if interest.write {
988 ready.write = false;
989 }
990 }
991 SocketState::Connected => {
992 if interest.read {
994 let read_buffer = self.read_buffer.read();
995 let data = read_buffer.data.read();
996 let closed = *read_buffer.closed.read();
997 ready.read = !data.is_empty() || closed;
998 }
999 if interest.write {
1001 if let Some(peer_buffer) = self.peer_read_buffer.read().as_ref() {
1002 let data = peer_buffer.data.read();
1003 let closed = *peer_buffer.closed.read();
1004 ready.write = data.len() < MAX_BUFFER_SIZE && !closed;
1005 } else {
1006 ready.write = false;
1007 }
1008 }
1009 }
1010 _ => {
1011 ready.read = false;
1013 ready.write = false;
1014 }
1015 }
1016
1017 ready
1018 }
1019
1020 fn wait_until_ready(
1021 &self,
1022 interest: ReadyInterest,
1023 trapframe: &mut crate::arch::Trapframe,
1024 timeout_ticks: Option<u64>,
1025 ) -> SelectWaitOutcome {
1026 let current = self.current_ready(interest);
1028 if (interest.read && current.read) || (interest.write && current.write) {
1029 return SelectWaitOutcome::Ready;
1030 }
1031
1032 let state = *self.state.read();
1033
1034 let task_id = {
1036 use crate::arch::get_cpu;
1037 use crate::sched::scheduler::get_scheduler;
1038 let cpu_id = get_cpu().get_cpuid();
1039 get_scheduler().get_current_task_id(cpu_id).unwrap_or(0)
1040 };
1041
1042 let woke = match state {
1045 SocketState::Listening if interest.read => {
1046 self.accept_waker
1048 .wait_with_timeout(task_id, trapframe, timeout_ticks)
1049 }
1050 SocketState::Connected if interest.read => {
1051 self.read_waker
1053 .wait_with_timeout(task_id, trapframe, timeout_ticks)
1054 }
1055 SocketState::Connected if interest.write => {
1056 true
1059 }
1060 _ => {
1061 true
1063 }
1064 };
1065
1066 if timeout_ticks.is_some() && !woke {
1067 let after = self.current_ready(interest);
1068 if (interest.read && !after.read) && (interest.write && !after.write) {
1069 return SelectWaitOutcome::TimedOut;
1070 }
1071 }
1072
1073 SelectWaitOutcome::Ready
1076 }
1077
1078 fn set_nonblocking(&self, enabled: bool) {
1079 *self.nonblocking.write() = enabled;
1085 let verify = *self.nonblocking.read();
1086 }
1092
1093 fn is_nonblocking(&self) -> bool {
1094 let value = *self.nonblocking.read();
1095 value
1101 }
1102}
1103
1104impl ControlOps for LocalSocket {
1105 fn control(&self, command: u32, arg: usize) -> Result<i32, &'static str> {
1106 match command {
1108 crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK => {
1109 let enabled = arg != 0;
1110 self.set_nonblocking(enabled);
1112 let verify = self.is_nonblocking();
1113 Ok(0)
1115 }
1116 crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK => {
1117 let is_nonblocking = self.is_nonblocking();
1118 Ok(if is_nonblocking { 1 } else { 0 })
1123 }
1124 _ => {
1125 localsocket_log!("[LocalSocket::control] Unknown command");
1126 Err("Unknown control command")
1127 }
1128 }
1129 }
1130
1131 fn supported_control_commands(&self) -> alloc::vec::Vec<(u32, &'static str)> {
1132 alloc::vec![
1133 (
1134 crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK,
1135 "Set non-blocking mode",
1136 ),
1137 (
1138 crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK,
1139 "Get non-blocking mode",
1140 ),
1141 ]
1142 }
1143}
1144
1145pub fn local_socket_factory(
1150 socket_type: SocketType,
1151 protocol: SocketProtocol,
1152) -> Result<Arc<dyn SocketObject>, SocketError> {
1153 Ok(Arc::new(LocalSocket::new(socket_type, protocol)))
1154}
1155
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed socket is unconnected and in the local domain.
    #[test_case]
    fn test_socket_creation() {
        let socket = LocalSocket::new(SocketType::Stream, SocketProtocol::Default);
        assert_eq!(socket.state(), SocketState::Unconnected);
        assert_eq!(socket.socket_domain(), SocketDomain::Local);
    }

    /// The factory yields a local stream socket.
    #[test_case]
    fn test_socket_factory() {
        let socket = local_socket_factory(SocketType::Stream, SocketProtocol::Default).unwrap();
        assert_eq!(socket.socket_domain(), SocketDomain::Local);
        assert_eq!(socket.socket_type(), SocketType::Stream);
    }

    /// Both ends of a connected pair start out connected.
    #[test_case]
    fn test_connected_pair() {
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());
        assert_eq!(sock1.state(), SocketState::Connected);
        assert_eq!(sock2.state(), SocketState::Connected);
    }

    /// Bytes written on one end can be read back on the other.
    #[test_case]
    fn test_read_write() {
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        let data = b"Hello, World!";
        let written = sock1.write(data).unwrap();
        assert_eq!(written, data.len());

        let mut buffer = [0u8; 32];
        let read = sock2.read(&mut buffer).unwrap();
        assert_eq!(read, data.len());
        assert_eq!(&buffer[..read], data);
    }

    /// Data flows independently in both directions.
    #[test_case]
    fn test_bidirectional_communication() {
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        sock1.write(b"ping").unwrap();
        let mut buf = [0u8; 4];
        sock2.read(&mut buf).unwrap();
        assert_eq!(&buf, b"ping");

        sock2.write(b"pong").unwrap();
        let mut buf = [0u8; 4];
        sock1.read(&mut buf).unwrap();
        assert_eq!(&buf, b"pong");
    }

    /// A single handle sent on one end is received on the other.
    #[test_case]
    fn test_handle_transfer_send_recv() {
        use crate::ipc::SharedMemory;
        use alloc::sync::Arc;

        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        let shmem = match SharedMemory::new(4096, 0x3) {
            Ok(shmem) => shmem,
            Err(_) => {
                crate::println!("SharedMemory::new failed, skipping test");
                return;
            }
        };
        let shmem_obj = KernelObject::from_shared_memory_object(Arc::new(shmem));

        let result = sock1.send_handle(shmem_obj);
        assert!(result.is_ok(), "send_handle should succeed");

        let received = sock2.recv_handle();
        assert!(received.is_ok(), "recv_handle should succeed");

        let received_obj = received.unwrap();
        assert!(
            received_obj.as_shared_memory().is_some(),
            "Received object should be SharedMemory"
        );
    }

    /// Every handle that was sent is received, and the queue is empty
    /// afterwards.
    #[test_case]
    fn test_handle_transfer_multiple_handles() {
        use crate::ipc::SharedMemory;
        use alloc::sync::Arc;

        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        // Count what was actually sent: an allocation failure skips a send,
        // and the receive loop must not expect more handles than were queued.
        let mut sent = 0;
        for i in 0..3 {
            if let Ok(shmem) = SharedMemory::new(4096 * (i + 1), 0x3) {
                let shmem_obj = KernelObject::from_shared_memory_object(Arc::new(shmem));
                assert!(sock1.send_handle(shmem_obj).is_ok());
                sent += 1;
            }
        }

        for _ in 0..sent {
            let received = sock2.recv_handle();
            assert!(received.is_ok(), "recv_handle should succeed");
            assert!(
                received.unwrap().as_shared_memory().is_some(),
                "Received object should be SharedMemory"
            );
        }

        let result = sock2.recv_handle();
        assert!(
            result.is_err(),
            "recv_handle should fail when queue is empty"
        );
    }

    /// Handle transfer is rejected on a socket that is not connected.
    #[test_case]
    fn test_handle_transfer_on_disconnected_socket() {
        use crate::ipc::SharedMemory;
        use alloc::sync::Arc;

        let sock = LocalSocket::new(SocketType::Stream, SocketProtocol::Default);

        if let Ok(shmem) = SharedMemory::new(4096, 0x3) {
            let shmem_obj = KernelObject::from_shared_memory_object(Arc::new(shmem));
            let result = sock.send_handle(shmem_obj);
            assert!(
                result.is_err(),
                "send_handle should fail on disconnected socket"
            );
        }

        let result = sock.recv_handle();
        assert!(
            result.is_err(),
            "recv_handle should fail on disconnected socket"
        );
    }

    /// Receiving from an empty handle queue fails instead of blocking.
    #[test_case]
    fn test_handle_transfer_empty_queue() {
        let (_, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        let result = sock2.recv_handle();
        assert!(
            result.is_err(),
            "recv_handle should fail when queue is empty"
        );
    }
}