kernel/network/
local.rs

1//! Local Socket Implementation
2//!
3//! This module implements local (Unix-like) domain sockets for inter-process
4//! communication through named socket paths in the filesystem namespace.
5//!
6//! # Design
7//!
8//! - **Named Sockets**: Sockets can be bound to filesystem paths
9//! - **Connection Oriented**: Uses stream sockets for reliable, ordered data transfer
10//! - **NetworkManager Integration**: Uses global NetworkManager for socket registry
11//! - **Direct Buffer Management**: Uses VecDeque for efficient data queuing
12//!
13//! # Socket States
14//!
15//! 1. **Unconnected**: Initial state after creation
16//! 2. **Bound**: Socket bound to a local address
17//! 3. **Listening**: Server socket accepting connections
18//! 4. **Connected**: Client socket or accepted connection
19
20use alloc::{
21    collections::VecDeque,
22    format,
23    string::{String, ToString},
24    sync::{Arc, Weak},
25    vec::Vec,
26};
27use core::any::Any;
28use spin::RwLock;
29
30use super::{
31    LocalSocketAddress, NetworkManager, ShutdownHow, SocketAddress, SocketControl, SocketDomain,
32    SocketError, SocketObject, SocketProtocol, SocketState, SocketType,
33};
34use crate::ipc::StreamIpcOps;
35use crate::object::KernelObject;
36use crate::object::capability::{
37    ControlOps, ReadyInterest, ReadySet, SelectWaitOutcome, Selectable, StreamError, StreamOps,
38};
39use crate::sync::Waker;
40
/// Compile-time switch for LocalSocket debug tracing; set to `true` to enable it.
const LOCALSOCKET_LOG: bool = false;

/// Emit a LocalSocket trace line on the kernel console.
///
/// Accepts the same arguments as `crate::println!`. Nothing is printed while
/// [`LOCALSOCKET_LOG`] is `false`.
macro_rules! localsocket_log {
    ($($fmt_args:tt)*) => {{
        if LOCALSOCKET_LOG {
            crate::println!($($fmt_args)*);
        }
    }};
}
50
/// Upper bound, in bytes, on the data queued in one socket's read buffer (64 KiB).
const MAX_BUFFER_SIZE: usize = 64 * 1024;

/// Upper bound on kernel-object handles waiting in one socket's receive queue.
///
/// Capping the queue keeps a sender from growing the receiver's memory without limit.
const MAX_HANDLE_QUEUE_SIZE: usize = 64;
57
58/// Shared buffer structure for socket data
59struct SocketBuffer {
60    /// Data buffer
61    data: RwLock<VecDeque<u8>>,
62    /// Flag indicating this buffer has been closed (peer shutdown)
63    closed: RwLock<bool>,
64}
65
66impl SocketBuffer {
67    fn new() -> Arc<Self> {
68        Arc::new(Self {
69            data: RwLock::new(VecDeque::with_capacity(MAX_BUFFER_SIZE)),
70            closed: RwLock::new(false),
71        })
72    }
73}
74
75/// Local Socket Implementation
76///
77/// This socket type provides local (Unix-like) domain socket functionality.
78/// It uses VecDeque buffers internally for data transfer and integrates with
79/// the NetworkManager for socket registry.
80pub struct LocalSocket {
81    /// Socket type (Stream, Datagram, etc.)
82    socket_type: SocketType,
83
84    /// Weak self reference (initialized when wrapped in Arc)
85    ///
86    /// This is used to establish peer relationships in methods that only
87    /// have `&self` (e.g., connect()), where we still need an `Arc<Self>`.
88    self_weak: RwLock<Weak<LocalSocket>>,
89
90    /// Socket protocol
91    protocol: SocketProtocol,
92
93    /// Current socket state
94    state: RwLock<SocketState>,
95
96    /// Local address (if bound)
97    local_addr: RwLock<Option<String>>,
98
99    /// Peer address (if connected)
100    peer_addr: RwLock<Option<String>>,
101
102    /// Read buffer: data received from peer (shared with peer for writing)
103    read_buffer: RwLock<Arc<SocketBuffer>>,
104
105    /// Write buffer reference: shared with peer socket for writing
106    /// When we write, we push to peer's read_buffer
107    peer_read_buffer: RwLock<Option<Arc<SocketBuffer>>>,
108
109    /// Peer socket reference (for waking read waiters)
110    peer_socket: RwLock<Option<Weak<LocalSocket>>>,
111
112    /// Backlog queue for listening sockets
113    /// Contains pending connections waiting to be accepted
114    backlog: RwLock<Vec<Arc<LocalSocket>>>,
115
116    /// Maximum backlog size (set by listen())
117    max_backlog: RwLock<usize>,
118
119    /// Waker for blocking accept() operations
120    accept_waker: Waker,
121
122    /// Waker for blocking read() operations
123    read_waker: Waker,
124
125    /// Waker for blocking recv_handle() operations
126    handle_waker: Waker,
127
128    /// Queue of handles (KernelObjects) received from peer
129    /// This allows passing file descriptors / kernel objects between tasks
130    handle_queue: RwLock<VecDeque<KernelObject>>,
131
132    /// Nonblocking I/O flag
133    nonblocking: RwLock<bool>,
134}
135
136impl LocalSocket {
137    pub(crate) fn init_self_weak(this: &Arc<Self>) {
138        *this.self_weak.write() = Arc::downgrade(this);
139    }
140
141    /// Safely downcast a SocketObject to LocalSocket using Any trait
142    ///
143    /// Returns None if the socket is not a LocalSocket.
144    /// This is completely safe and does not use any unsafe code.
145    pub fn from_socket_object(socket: &Arc<dyn SocketObject>) -> Option<&Self> {
146        // Use SocketObject's as_any() to get &dyn Any
147        socket.as_any().downcast_ref::<LocalSocket>()
148    }
149
150    /// Create a new local socket
151    ///
152    /// # Arguments
153    ///
154    /// * `socket_type` - Socket type (Stream, Datagram, etc.)
155    /// * `protocol` - Socket protocol
156    ///
157    /// # Returns
158    ///
159    /// A new socket in the Unconnected state
160    pub fn new(socket_type: SocketType, protocol: SocketProtocol) -> Self {
161        Self {
162            socket_type,
163            protocol,
164            state: RwLock::new(SocketState::Unconnected),
165            local_addr: RwLock::new(None),
166            peer_addr: RwLock::new(None),
167            read_buffer: RwLock::new(SocketBuffer::new()),
168            peer_read_buffer: RwLock::new(None),
169            peer_socket: RwLock::new(None),
170            backlog: RwLock::new(Vec::new()),
171            max_backlog: RwLock::new(0),
172            accept_waker: Waker::new_interruptible("socket_accept"),
173            read_waker: Waker::new_interruptible("socket_read"),
174            handle_waker: Waker::new_interruptible("socket_handle"),
175            handle_queue: RwLock::new(VecDeque::new()),
176            self_weak: RwLock::new(Weak::new()),
177            nonblocking: RwLock::new(false),
178        }
179    }
180
181    /// Send a KernelObject handle through this socket
182    ///
183    /// This is LocalSocket-only (SCM_RIGHTS equivalent) and uses dup() semantics.
184    pub fn send_handle(&self, object: KernelObject) -> Result<(), crate::ipc::IpcError> {
185        use crate::ipc::IpcError;
186
187        // Verify socket is connected
188        if *self.state.read() != SocketState::Connected {
189            return Err(IpcError::InvalidState);
190        }
191
192        // Get peer socket reference
193        let peer_weak = self.peer_socket.read();
194        let peer_weak_ref = peer_weak.as_ref().ok_or(IpcError::PeerClosed)?;
195        let peer = peer_weak_ref.upgrade().ok_or(IpcError::PeerClosed)?;
196
197        // Check if peer's handle queue is full to prevent DoS attacks
198        let mut peer_queue = peer.handle_queue.write();
199        if peer_queue.len() >= MAX_HANDLE_QUEUE_SIZE {
200            return Err(IpcError::ChannelFull);
201        }
202
203        // Add handle to peer's receive queue
204        peer_queue.push_back(object);
205        drop(peer_queue);
206
207        // Wake one task potentially blocked on recv_handle
208        peer.handle_waker.wake_one();
209
210        Ok(())
211    }
212
213    /// Send a handle and data together atomically for Wayland protocol
214    ///
215    /// This method ensures that both the handle and data are available before
216    /// waking the peer, preventing race conditions where recvmsg might get
217    /// the handle but not the data (or vice versa).
218    ///
219    /// This is needed for Wayland protocol messages with file descriptors,
220    /// where the client expects both the FD and message data in a single recvmsg call.
221    pub fn send_handle_and_data(
222        &self,
223        object: KernelObject,
224        data: &[u8],
225    ) -> Result<(), crate::ipc::IpcError> {
226        use crate::ipc::IpcError;
227
228        localsocket_log!(
229            "[LocalSocket] send_handle_and_data: self={:p}, data_len={}",
230            self as *const _,
231            data.len()
232        );
233
234        // Verify socket is connected
235        if *self.state.read() != SocketState::Connected {
236            localsocket_log!("[LocalSocket] send_handle_and_data: not connected");
237            return Err(IpcError::InvalidState);
238        }
239
240        // Get peer socket reference
241        let peer_weak = self.peer_socket.read();
242        let peer_weak_ref = peer_weak.as_ref().ok_or(IpcError::PeerClosed)?;
243        let peer = peer_weak_ref.upgrade().ok_or(IpcError::PeerClosed)?;
244
245        localsocket_log!(
246            "[LocalSocket] send_handle_and_data: peer={:p}",
247            peer.as_ref() as *const _
248        );
249
250        // Check if peer's handle queue is full to prevent DoS attacks
251        let mut peer_handle_queue = peer.handle_queue.write();
252        if peer_handle_queue.len() >= MAX_HANDLE_QUEUE_SIZE {
253            localsocket_log!("[LocalSocket] send_handle_and_data: handle queue full");
254            return Err(IpcError::ChannelFull);
255        }
256
257        // Get peer's data buffer through peer_read_buffer
258        let peer_buffer_option = peer.peer_read_buffer.read();
259        let peer_sock_buffer = peer_buffer_option.as_ref().ok_or(IpcError::PeerClosed)?;
260
261        // Check if peer's data buffer has space
262        let mut peer_buffer = peer_sock_buffer.data.write();
263        if peer_buffer.len() + data.len() > MAX_BUFFER_SIZE {
264            localsocket_log!(
265                "[LocalSocket] send_handle_and_data: buffer full, current_len={}, adding_len={}",
266                peer_buffer.len(),
267                data.len()
268            );
269            drop(peer_buffer);
270            drop(peer_buffer_option);
271            drop(peer_handle_queue);
272            return Err(IpcError::ChannelFull);
273        }
274
275        localsocket_log!(
276            "[LocalSocket] send_handle_and_data: before send - handle_queue_len={}, buffer_len={}",
277            peer_handle_queue.len(),
278            peer_buffer.len()
279        );
280
281        // Add handle to peer's receive queue
282        peer_handle_queue.push_back(object);
283        let queue_len = peer_handle_queue.len();
284        drop(peer_handle_queue);
285
286        // Add data to peer's buffer
287        peer_buffer.extend(data.iter().copied());
288        let buffer_len = peer_buffer.len();
289        drop(peer_buffer);
290        drop(peer_buffer_option);
291
292        localsocket_log!(
293            "[LocalSocket] send_handle_and_data: after send - handle_queue_len={}, buffer_len={}",
294            queue_len,
295            buffer_len
296        );
297
298        // Wake the peer after BOTH handle and data are available
299        peer.handle_waker.wake_one();
300        peer.read_waker.wake_one();
301
302        Ok(())
303    }
304
305    /// Receive a handle and data together atomically for Wayland protocol
306    ///
307    /// Returns both a handle and data in a single atomic operation.
308    /// This is the counterpart to send_handle_and_data().
309    ///
310    /// # Arguments
311    ///
312    /// * `max_data_len` - Maximum amount of data to read
313    ///
314    /// # Returns
315    ///
316    /// * `(KernelObject, Vec<u8>)` - Handle and data on success
317    /// * `IpcError` - Error if no handle/data available or other error
318    pub fn recv_handle_and_data(
319        &self,
320        max_data_len: usize,
321    ) -> Result<(KernelObject, Vec<u8>), crate::ipc::IpcError> {
322        use crate::ipc::IpcError;
323
324        localsocket_log!(
325            "[LocalSocket] recv_handle_and_data: self={:p}, max_data_len={}",
326            self as *const _,
327            max_data_len
328        );
329
330        // Verify socket is connected
331        if *self.state.read() != SocketState::Connected {
332            localsocket_log!("[LocalSocket] recv_handle_and_data: not connected");
333            return Err(IpcError::InvalidState);
334        }
335
336        // Try to get a handle from the queue
337        let mut queue = self.handle_queue.write();
338        localsocket_log!(
339            "[LocalSocket] recv_handle_and_data: handle_queue_len={}",
340            queue.len()
341        );
342
343        let handle = match queue.pop_front() {
344            Some(h) => h,
345            None => {
346                localsocket_log!(
347                    "[LocalSocket] recv_handle_and_data: handle queue empty - returning ChannelEmpty"
348                );
349                return Err(IpcError::ChannelEmpty);
350            }
351        };
352        drop(queue);
353
354        // Read data from read buffer
355        let read_buffer = self.read_buffer.read();
356        let mut buffer_data = read_buffer.data.write();
357        localsocket_log!(
358            "[LocalSocket] recv_handle_and_data: buffer_len={}, max_data_len={}",
359            buffer_data.len(),
360            max_data_len
361        );
362
363        // Read up to max_data_len bytes
364        let actual_len = buffer_data.len().min(max_data_len);
365        let mut data = Vec::with_capacity(actual_len);
366        for _ in 0..actual_len {
367            data.push(buffer_data.pop_front().unwrap());
368        }
369        drop(buffer_data);
370        drop(read_buffer);
371
372        localsocket_log!(
373            "[LocalSocket] recv_handle_and_data: returning handle and {} bytes of data",
374            data.len()
375        );
376
377        Ok((handle, data))
378    }
379
380    /// Receive a KernelObject handle from this socket (non-blocking)
381    pub fn recv_handle(&self) -> Result<KernelObject, crate::ipc::IpcError> {
382        use crate::ipc::IpcError;
383
384        // Verify socket is connected
385        if *self.state.read() != SocketState::Connected {
386            return Err(IpcError::InvalidState);
387        }
388
389        // Try to get a handle from the queue
390        let mut queue = self.handle_queue.write();
391        queue.pop_front().ok_or(IpcError::ChannelEmpty)
392    }
393
394    /// Accept a connection with blocking behavior
395    ///
396    /// This method blocks the calling task until a connection is available in the backlog.
397    /// It uses the waker mechanism to properly suspend and wake the task.
398    ///
399    /// # Arguments
400    ///
401    /// * `task_id` - ID of the task calling accept
402    /// * `trapframe` - Trapframe for scheduler context switching
403    ///
404    /// # Returns
405    ///
406    /// Arc to the accepted socket connection
407    pub fn accept_blocking(
408        &self,
409        task_id: usize,
410        trapframe: &mut crate::arch::Trapframe,
411    ) -> Result<Arc<dyn SocketObject>, SocketError> {
412        let state = self.state.read();
413        if *state != SocketState::Listening {
414            return Err(SocketError::NotListening);
415        }
416        drop(state);
417
418        // Try to get a pending connection from backlog
419        loop {
420            {
421                let mut backlog = self.backlog.write();
422                if let Some(client_socket) = backlog.pop() {
423                    return Ok(client_socket);
424                }
425            } // Release backlog lock
426
427            // No connection available, block the task
428            self.accept_waker.wait(task_id, trapframe);
429
430            // When we reach here, task has been woken up
431            // Check again if there's a connection
432        }
433    }
434
435    /// Create a connected socket pair (for internal use)
436    ///
437    /// This creates two connected sockets, useful for accept() implementation.
438    ///
439    /// # Arguments
440    ///
441    /// * `local_addr` - Local address for the first socket
442    /// * `peer_addr` - Peer address for the second socket
443    ///
444    /// # Returns
445    ///
446    /// A tuple of (local_socket, peer_socket) that are connected
447    pub fn create_connected_pair(local_addr: String, peer_addr: String) -> (Arc<Self>, Arc<Self>) {
448        // Create shared buffers for bidirectional communication
449        let local_read_buffer = SocketBuffer::new();
450        let peer_read_buffer = SocketBuffer::new();
451
452        // Create local socket (server side)
453        // It reads from local_read_buffer, writes to peer_read_buffer
454        let local_socket = Arc::new(Self {
455            socket_type: SocketType::Stream,
456            protocol: SocketProtocol::Default,
457            state: RwLock::new(SocketState::Connected),
458            local_addr: RwLock::new(Some(local_addr.clone())),
459            peer_addr: RwLock::new(Some(peer_addr.clone())),
460            read_buffer: RwLock::new(local_read_buffer.clone()),
461            peer_read_buffer: RwLock::new(Some(peer_read_buffer.clone())),
462            peer_socket: RwLock::new(None),
463            backlog: RwLock::new(Vec::new()),
464            max_backlog: RwLock::new(0),
465            accept_waker: Waker::new_interruptible("socket_accept"),
466            read_waker: Waker::new_interruptible("socket_read"),
467            handle_waker: Waker::new_interruptible("socket_handle"),
468            handle_queue: RwLock::new(VecDeque::new()),
469            self_weak: RwLock::new(Weak::new()),
470            nonblocking: RwLock::new(false),
471        });
472
473        // Create peer socket (client side)
474        // It reads from peer_read_buffer, writes to local_read_buffer
475        let peer_socket = Arc::new(Self {
476            socket_type: SocketType::Stream,
477            protocol: SocketProtocol::Default,
478            state: RwLock::new(SocketState::Connected),
479            local_addr: RwLock::new(Some(peer_addr)),
480            peer_addr: RwLock::new(Some(local_addr)),
481            read_buffer: RwLock::new(peer_read_buffer.clone()),
482            peer_read_buffer: RwLock::new(Some(local_read_buffer.clone())),
483            peer_socket: RwLock::new(None),
484            backlog: RwLock::new(Vec::new()),
485            max_backlog: RwLock::new(0),
486            accept_waker: Waker::new_interruptible("socket_accept"),
487            read_waker: Waker::new_interruptible("socket_read"),
488            handle_waker: Waker::new_interruptible("socket_handle"),
489            handle_queue: RwLock::new(VecDeque::new()),
490            self_weak: RwLock::new(Weak::new()),
491            nonblocking: RwLock::new(false),
492        });
493
494        Self::init_self_weak(&local_socket);
495        Self::init_self_weak(&peer_socket);
496
497        // Set peer references
498        *local_socket.peer_socket.write() = Some(Arc::downgrade(&peer_socket));
499        *peer_socket.peer_socket.write() = Some(Arc::downgrade(&local_socket));
500
501        (local_socket, peer_socket)
502    }
503
504    /// Blocking handle receive operation
505    ///
506    /// This method blocks the calling task until a handle is available in the
507    /// handle queue, or the peer is closed.
508    pub fn recv_handle_blocking(
509        &self,
510        task_id: usize,
511        trapframe: &mut crate::arch::Trapframe,
512    ) -> Result<KernelObject, crate::ipc::IpcError> {
513        use crate::ipc::IpcError;
514
515        loop {
516            // Verify socket is connected
517            {
518                let state = self.state.read();
519                if *state != SocketState::Connected {
520                    return Err(IpcError::InvalidState);
521                }
522            }
523
524            // Fast path: handle already queued
525            {
526                let mut queue = self.handle_queue.write();
527                if let Some(obj) = queue.pop_front() {
528                    return Ok(obj);
529                }
530            }
531
532            // If peer has shut down (or been dropped), don't block forever.
533            // We reuse the same conditions as read_blocking() uses for EOF.
534            {
535                let peer_weak_opt = self.peer_socket.read();
536                if let Some(peer_weak) = peer_weak_opt.as_ref() {
537                    if let Some(peer) = peer_weak.upgrade() {
538                        let peer_state = peer.state.read();
539                        if *peer_state == SocketState::Closed {
540                            return Err(IpcError::PeerClosed);
541                        }
542                    } else {
543                        return Err(IpcError::PeerClosed);
544                    }
545                }
546            }
547
548            // If peer performed shutdown(), our read buffer is marked closed.
549            {
550                let read_buf = self.read_buffer.read();
551                let closed = read_buf.closed.read();
552                if *closed {
553                    return Err(IpcError::PeerClosed);
554                }
555            }
556
557            // No handle available, block the task
558            self.handle_waker.wait(task_id, trapframe);
559        }
560    }
561}
562
563impl StreamOps for LocalSocket {
564    fn read(&self, buffer: &mut [u8]) -> Result<usize, StreamError> {
565        use crate::task::mytask;
566
567        // Debug: count read attempts
568        static READ_ATTEMPT_COUNTER: core::sync::atomic::AtomicUsize =
569            core::sync::atomic::AtomicUsize::new(0);
570        let attempt = READ_ATTEMPT_COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
571
572        loop {
573            {
574                let read_buf_arc = self.read_buffer.read();
575                let mut read_data = read_buf_arc.data.write();
576                let is_nonblocking = *self.nonblocking.read();
577                let has_data = !read_data.is_empty();
578
579                // // Log every 100 attempts or first 5 attempts
580                // if attempt < 5 || attempt % 100 == 0 {
581                //     crate::println!(
582                //         "[LocalSocket::read] self={:p} attempt={} nonblocking={} has_data={} data_len={}",
583                //         self as *const _,
584                //         attempt,
585                //         is_nonblocking,
586                //         has_data,
587                //         read_data.len()
588                //     );
589                // }
590
591                if !read_data.is_empty() {
592                    let bytes_to_read = buffer.len().min(read_data.len());
593                    for i in 0..bytes_to_read {
594                        buffer[i] = read_data.pop_front().unwrap();
595                    }
596
597                    // if attempt < 5 || attempt % 100 == 0 {
598                    //     crate::println!(
599                    //         "[LocalSocket::read] attempt={} returning {} bytes",
600                    //         attempt,
601                    //         bytes_to_read
602                    //     );
603                    // }
604                    return Ok(bytes_to_read);
605                }
606            } // Release locks before checking nonblocking/EOF
607
608            // Check nonblocking mode before blocking
609            if *self.nonblocking.read() {
610                // // Nonblocking mode: return WouldBlock error immediately
611                // if attempt < 5 || attempt % 100 == 0 {
612                //     crate::println!(
613                //         "[LocalSocket::read] attempt={} returning WouldBlock",
614                //         attempt
615                //     );
616                // }
617                return Err(StreamError::WouldBlock);
618            }
619
620            {
621                let read_buf_arc = self.read_buffer.read();
622
623                // Check if socket is closed (peer shutdown)
624                // Return 0 to indicate EOF (not an error)
625                let my_state = *self.state.read();
626                if my_state == SocketState::Closed {
627                    return Ok(0);
628                }
629
630                // Check if peer is closed (they called shutdown)
631                if let Some(peer_weak) = self.peer_socket.read().as_ref() {
632                    if let Some(peer) = peer_weak.upgrade() {
633                        let peer_state = *peer.state.read();
634                        if peer_state == SocketState::Closed {
635                            return Ok(0); // Peer closed, return EOF
636                        }
637                    } else {
638                        return Ok(0); // Peer dropped, treat as EOF
639                    }
640                }
641
642                // Check if this read buffer has been closed by peer's shutdown()
643                if *read_buf_arc.closed.read() {
644                    return Ok(0);
645                }
646
647                // Register this task as waiting to read
648                if let Some(task) = mytask() {
649                    drop(read_buf_arc);
650
651                    // Block the task
652                    self.read_waker.wait(task.get_id(), task.get_trapframe());
653                } else {
654                    return Err(StreamError::WouldBlock);
655                }
656            } // Release lock
657            // When woken, loop back to check for data or shutdown
658        }
659    }
660
661    fn write(&self, data: &[u8]) -> Result<usize, StreamError> {
662        let peer_buffer = self.peer_read_buffer.read();
663        match peer_buffer.as_ref() {
664            Some(peer_sock_buffer) => {
665                let mut peer_data = peer_sock_buffer.data.write();
666
667                // Check if buffer has space
668                if peer_data.len() + data.len() > MAX_BUFFER_SIZE {
669                    return Err(StreamError::WouldBlock);
670                }
671
672                // Write data to peer's read buffer
673                peer_data.extend(data.iter().copied());
674                let bytes_written = data.len();
675
676                drop(peer_data); // Release data lock
677
678                // Wake tasks waiting on read/select/poll.
679                if let Some(peer_weak) = self.peer_socket.read().as_ref() {
680                    if let Some(peer) = peer_weak.upgrade() {
681                        peer.read_waker.wake_one();
682                    }
683                }
684
685                drop(peer_buffer); // Release peer_buffer lock
686
687                Ok(bytes_written)
688            }
689            None => {
690                // crate::println!("[LocalSocket] write: peer buffer is None (closed)");
691                Err(StreamError::Closed)
692            }
693        }
694    }
695}
696
697impl StreamIpcOps for LocalSocket {
698    fn is_connected(&self) -> bool {
699        *self.state.read() == SocketState::Connected
700    }
701
702    fn peer_count(&self) -> usize {
703        if StreamIpcOps::is_connected(self) {
704            1
705        } else {
706            0
707        }
708    }
709
710    fn description(&self) -> String {
711        let local = self.local_addr.read();
712        let peer = self.peer_addr.read();
713        format!("LocalSocket[{:?} -> {:?}]", local.as_ref(), peer.as_ref())
714    }
715}
716
717impl SocketControl for LocalSocket {
718    fn bind(&self, address: &SocketAddress) -> Result<(), SocketError> {
719        // Check socket is unconnected
720        let mut state = self.state.write();
721        if *state != SocketState::Unconnected {
722            return Err(SocketError::AlreadyConnected);
723        }
724
725        // Extract path from address
726        let path = match address {
727            SocketAddress::Local(addr) => addr.path(),
728            _ => return Err(SocketError::InvalidAddress),
729        };
730
731        // Update state
732        // Note: NetworkManager registration is done by the syscall layer
733        // to ensure the same Arc<Self> is registered that's in the handle table
734        *self.local_addr.write() = Some(path.to_string());
735        *state = SocketState::Bound;
736
737        Ok(())
738    }
739
740    fn listen(&self, backlog: usize) -> Result<(), SocketError> {
741        let mut state = self.state.write();
742        if *state != SocketState::Bound {
743            return Err(SocketError::InvalidOperation);
744        }
745
746        *self.max_backlog.write() = backlog;
747        *state = SocketState::Listening;
748
749        Ok(())
750    }
751
752    fn accept(&self) -> Result<Arc<dyn SocketObject>, SocketError> {
753        let state = self.state.read();
754        if *state != SocketState::Listening {
755            return Err(SocketError::NotListening);
756        }
757        drop(state);
758
759        // Try to get a pending connection from backlog
760        let mut backlog = self.backlog.write();
761        if let Some(client_socket) = backlog.pop() {
762            Ok(client_socket)
763        } else {
764            Err(SocketError::WouldBlock)
765        }
766    }
767
768    fn connect(&self, address: &SocketAddress) -> Result<(), SocketError> {
769        // Validate current state
770        let state = self.state.read();
771        if *state != SocketState::Unconnected {
772            return Err(SocketError::AlreadyConnected);
773        }
774        drop(state);
775
776        // Extract path from address
777        let path = match address {
778            SocketAddress::Local(addr) => addr.path(),
779            _ => return Err(SocketError::InvalidAddress),
780        };
781
782        // Lookup listening socket in NetworkManager
783        let manager = NetworkManager::get_manager();
784        let server_socket = match manager.lookup_named_socket(path) {
785            Ok(socket) => socket,
786            Err(e) => return Err(e),
787        };
788
789        // Check server is listening
790        if server_socket.state() != SocketState::Listening {
791            return Err(SocketError::ConnectionRefused);
792        }
793
794        // We need to create a proper Arc to self to be able to store a Weak reference in the peer
795        // Since we're in &self, we don't have access to the Arc. We'll need to store the
796        // connection information and let the server-side socket refer back through handle table.
797
798        // Instead, we'll use a different approach: create shared buffers and update both sockets
799        let local_addr = format!("anon-{}", self as *const _ as usize);
800
801        // Create shared buffers for bidirectional communication
802        let client_read_buffer = SocketBuffer::new();
803        let server_read_buffer = SocketBuffer::new();
804
805        // Create server-side connection socket that will be added to backlog
806        let server_conn = Arc::new(Self {
807            socket_type: SocketType::Stream,
808            protocol: SocketProtocol::Default,
809            state: RwLock::new(SocketState::Connected),
810            local_addr: RwLock::new(Some(path.to_string())),
811            peer_addr: RwLock::new(Some(local_addr.clone())),
812            read_buffer: RwLock::new(server_read_buffer.clone()),
813            peer_read_buffer: RwLock::new(Some(client_read_buffer.clone())),
814            peer_socket: RwLock::new(None), // Will be set below
815            backlog: RwLock::new(Vec::new()),
816            max_backlog: RwLock::new(0),
817            accept_waker: Waker::new_interruptible("socket_accept"),
818            read_waker: Waker::new_interruptible("socket_read"),
819            handle_waker: Waker::new_interruptible("socket_handle"),
820            handle_queue: RwLock::new(VecDeque::new()),
821            self_weak: RwLock::new(Weak::new()),
822            nonblocking: RwLock::new(false),
823        });
824
825        Self::init_self_weak(&server_conn);
826
827        // Update self (client socket) to use the shared buffers
828        *self.read_buffer.write() = client_read_buffer.clone();
829        *self.peer_read_buffer.write() = Some(server_read_buffer.clone());
830        *self.local_addr.write() = Some(local_addr);
831        *self.peer_addr.write() = Some(path.to_string());
832        *self.state.write() = SocketState::Connected;
833
834        // Set peer_socket references - IMPORTANT for shutdown()
835        // Client (self) points to server_conn
836        *self.peer_socket.write() = Some(Arc::downgrade(&server_conn));
837
838        // Server_conn needs to point back to client for handle transfer.
839        // We keep a Weak<Self> initialized at creation time, so upgrade it here.
840        let client_arc = self
841            .self_weak
842            .read()
843            .upgrade()
844            .ok_or(SocketError::InvalidOperation)?;
845        *server_conn.peer_socket.write() = Some(Arc::downgrade(&client_arc));
846
847        // Add server connection to server's backlog
848        let server_local = match Self::from_socket_object(&server_socket) {
849            Some(socket) => socket,
850            None => return Err(SocketError::InvalidOperation), // Not a LocalSocket
851        };
852        let mut server_backlog = server_local.backlog.write();
853        let max_backlog = *server_local.max_backlog.read();
854
855        if server_backlog.len() >= max_backlog {
856            // Rollback state change - restore original empty buffer
857            *self.read_buffer.write() = SocketBuffer::new();
858            *self.state.write() = SocketState::Unconnected;
859            *self.local_addr.write() = None;
860            *self.peer_addr.write() = None;
861            *self.peer_read_buffer.write() = None;
862            *self.peer_socket.write() = None;
863            return Err(SocketError::ConnectionRefused);
864        }
865        server_backlog.push(server_conn);
866        drop(server_backlog); // Release lock before waking
867
868        // Wake up any task waiting in accept()
869        server_local.accept_waker.wake_one();
870
871        Ok(())
872    }
873
874    fn shutdown(&self, how: ShutdownHow) -> Result<(), SocketError> {
875        let mut state = self.state.write();
876        if *state != SocketState::Connected {
877            return Err(SocketError::NotConnected);
878        }
879
880        // crate::println!("[LocalSocket] shutdown({:?}) called", how);
881
882        match how {
883            ShutdownHow::Read | ShutdownHow::Write | ShutdownHow::Both => {
884                *state = SocketState::Closed;
885
886                // Mark peer's read buffer as closed so they detect EOF
887                if let Some(peer_buf) = self.peer_read_buffer.read().as_ref() {
888                    // crate::println!("[LocalSocket] shutdown: marking peer_read_buffer as closed");
889                    *peer_buf.closed.write() = true;
890                }
891
892                // Wake up peer's read_waker so it can detect the shutdown
893                if let Some(peer_weak) = self.peer_socket.read().as_ref() {
894                    if let Some(peer) = peer_weak.upgrade() {
895                        // crate::println!("[LocalSocket] shutdown: waking peer's read_waker");
896                        peer.read_waker.wake_one();
897                        // Also wake any tasks waiting for handle transfer
898                        peer.handle_waker.wake_all();
899                    } else {
900                        // crate::println!("[LocalSocket] shutdown: peer already dropped");
901                    }
902                } else {
903                    // No direct peer reference - wake via waker
904                    // crate::println!(
905                    //     "[LocalSocket] shutdown: no peer_socket, waking via read_waker"
906                    // );
907                    self.read_waker.wake_all(); // Wake any waiting readers
908                    self.handle_waker.wake_all(); // Wake any waiting handle receivers
909                }
910
911                Ok(())
912            }
913        }
914    }
915
916    fn is_connected(&self) -> bool {
917        *self.state.read() == SocketState::Connected
918    }
919
920    fn state(&self) -> SocketState {
921        *self.state.read()
922    }
923
924    fn getpeername(&self) -> Result<SocketAddress, SocketError> {
925        let peer = self.peer_addr.read();
926        match peer.as_ref() {
927            Some(path) => Ok(SocketAddress::Local(
928                LocalSocketAddress::from_path(path)
929                    .unwrap_or_else(|_| LocalSocketAddress::unnamed()),
930            )),
931            None => Err(SocketError::NotConnected),
932        }
933    }
934
935    fn getsockname(&self) -> Result<SocketAddress, SocketError> {
936        let local = self.local_addr.read();
937        match local.as_ref() {
938            Some(path) => Ok(SocketAddress::Local(
939                LocalSocketAddress::from_path(path)
940                    .unwrap_or_else(|_| LocalSocketAddress::unnamed()),
941            )),
942            None => Err(SocketError::InvalidOperation),
943        }
944    }
945}
946
947impl SocketObject for LocalSocket {
948    fn socket_type(&self) -> SocketType {
949        self.socket_type
950    }
951
952    fn socket_domain(&self) -> SocketDomain {
953        SocketDomain::Local
954    }
955
956    fn socket_protocol(&self) -> SocketProtocol {
957        self.protocol
958    }
959
960    fn as_any(&self) -> &dyn Any {
961        self
962    }
963
964    fn as_selectable(&self) -> Option<&dyn Selectable> {
965        Some(self)
966    }
967
968    fn as_control_ops(&self) -> Option<&dyn crate::object::capability::ControlOps> {
969        Some(self)
970    }
971}
972
973impl Selectable for LocalSocket {
974    fn current_ready(&self, interest: ReadyInterest) -> ReadySet {
975        let mut ready = ReadySet::none();
976
977        let state = *self.state.read();
978
979        match state {
980            SocketState::Listening => {
981                // Listening sockets: readable when backlog has connections
982                if interest.read {
983                    let backlog = self.backlog.read();
984                    ready.read = !backlog.is_empty();
985                }
986                // Listening sockets are always writable (not applicable)
987                if interest.write {
988                    ready.write = false;
989                }
990            }
991            SocketState::Connected => {
992                // Connected sockets: readable when data available
993                if interest.read {
994                    let read_buffer = self.read_buffer.read();
995                    let data = read_buffer.data.read();
996                    let closed = *read_buffer.closed.read();
997                    ready.read = !data.is_empty() || closed;
998                }
999                // Connected sockets: writable when peer buffer not full
1000                if interest.write {
1001                    if let Some(peer_buffer) = self.peer_read_buffer.read().as_ref() {
1002                        let data = peer_buffer.data.read();
1003                        let closed = *peer_buffer.closed.read();
1004                        ready.write = data.len() < MAX_BUFFER_SIZE && !closed;
1005                    } else {
1006                        ready.write = false;
1007                    }
1008                }
1009            }
1010            _ => {
1011                // Unconnected/Bound/other: not ready
1012                ready.read = false;
1013                ready.write = false;
1014            }
1015        }
1016
1017        ready
1018    }
1019
1020    fn wait_until_ready(
1021        &self,
1022        interest: ReadyInterest,
1023        trapframe: &mut crate::arch::Trapframe,
1024        timeout_ticks: Option<u64>,
1025    ) -> SelectWaitOutcome {
1026        // Check if already ready
1027        let current = self.current_ready(interest);
1028        if (interest.read && current.read) || (interest.write && current.write) {
1029            return SelectWaitOutcome::Ready;
1030        }
1031
1032        let state = *self.state.read();
1033
1034        // Get current task ID
1035        let task_id = {
1036            use crate::arch::get_cpu;
1037            use crate::sched::scheduler::get_scheduler;
1038            let cpu_id = get_cpu().get_cpuid();
1039            get_scheduler().get_current_task_id(cpu_id).unwrap_or(0)
1040        };
1041
1042        // Wait based on state and interest
1043        // Note: timeout is not yet implemented - always blocks until ready
1044        let woke = match state {
1045            SocketState::Listening if interest.read => {
1046                // Wait for incoming connections
1047                self.accept_waker
1048                    .wait_with_timeout(task_id, trapframe, timeout_ticks)
1049            }
1050            SocketState::Connected if interest.read => {
1051                // Wait for data to arrive
1052                self.read_waker
1053                    .wait_with_timeout(task_id, trapframe, timeout_ticks)
1054            }
1055            SocketState::Connected if interest.write => {
1056                // For write readiness, treat as immediately ready (optimistic)
1057                // Most sockets are writable most of the time
1058                true
1059            }
1060            _ => {
1061                // Other states: immediately return as not ready
1062                true
1063            }
1064        };
1065
1066        if timeout_ticks.is_some() && !woke {
1067            let after = self.current_ready(interest);
1068            if (interest.read && !after.read) && (interest.write && !after.write) {
1069                return SelectWaitOutcome::TimedOut;
1070            }
1071        }
1072
1073        // After waking, consider it ready
1074        // TODO: properly check timeout and return TimedOut if needed
1075        SelectWaitOutcome::Ready
1076    }
1077
1078    fn set_nonblocking(&self, enabled: bool) {
1079        // crate::println!(
1080        //     "[LocalSocket::set_nonblocking] self={:p} enabled={}",
1081        //     self as *const _,
1082        //     enabled
1083        // );
1084        *self.nonblocking.write() = enabled;
1085        let verify = *self.nonblocking.read();
1086        // crate::println!(
1087        //     "[LocalSocket::set_nonblocking] self={:p} after write, read back={}",
1088        //     self as *const _,
1089        //     verify
1090        // );
1091    }
1092
1093    fn is_nonblocking(&self) -> bool {
1094        let value = *self.nonblocking.read();
1095        // crate::println!(
1096        //     "[LocalSocket::is_nonblocking] self={:p} returning={}",
1097        //     self as *const _,
1098        //     value
1099        // );
1100        value
1101    }
1102}
1103
1104impl ControlOps for LocalSocket {
1105    fn control(&self, command: u32, arg: usize) -> Result<i32, &'static str> {
1106        // crate::println!("[LocalSocket::control] command={} arg={}", command, arg);
1107        match command {
1108            crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK => {
1109                let enabled = arg != 0;
1110                // crate::println!("[LocalSocket::control] Setting nonblocking={}", enabled);
1111                self.set_nonblocking(enabled);
1112                let verify = self.is_nonblocking();
1113                // crate::println!("[LocalSocket::control] Verified nonblocking={}", verify);
1114                Ok(0)
1115            }
1116            crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK => {
1117                let is_nonblocking = self.is_nonblocking();
1118                // crate::println!(
1119                // "[LocalSocket::control] Getting nonblocking={}",
1120                // is_nonblocking
1121                // );
1122                Ok(if is_nonblocking { 1 } else { 0 })
1123            }
1124            _ => {
1125                localsocket_log!("[LocalSocket::control] Unknown command");
1126                Err("Unknown control command")
1127            }
1128        }
1129    }
1130
1131    fn supported_control_commands(&self) -> alloc::vec::Vec<(u32, &'static str)> {
1132        alloc::vec![
1133            (
1134                crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK,
1135                "Set non-blocking mode",
1136            ),
1137            (
1138                crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK,
1139                "Get non-blocking mode",
1140            ),
1141        ]
1142    }
1143}
1144
1145/// Socket factory function for local sockets
1146///
1147/// This function is registered with the NetworkManager to create
1148/// local domain sockets.
1149pub fn local_socket_factory(
1150    socket_type: SocketType,
1151    protocol: SocketProtocol,
1152) -> Result<Arc<dyn SocketObject>, SocketError> {
1153    Ok(Arc::new(LocalSocket::new(socket_type, protocol)))
1154}
1155
#[cfg(test)]
mod tests {
    use super::*;

    #[test_case]
    fn test_socket_creation() {
        let socket = LocalSocket::new(SocketType::Stream, SocketProtocol::Default);
        assert_eq!(socket.state(), SocketState::Unconnected);
        assert_eq!(socket.socket_domain(), SocketDomain::Local);
    }

    #[test_case]
    fn test_socket_factory() {
        let socket = local_socket_factory(SocketType::Stream, SocketProtocol::Default).unwrap();
        assert_eq!(socket.socket_domain(), SocketDomain::Local);
        assert_eq!(socket.socket_type(), SocketType::Stream);
    }

    #[test_case]
    fn test_connected_pair() {
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());
        assert_eq!(sock1.state(), SocketState::Connected);
        assert_eq!(sock2.state(), SocketState::Connected);
    }

    #[test_case]
    fn test_read_write() {
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        // Write from sock1 to sock2
        let data = b"Hello, World!";
        let written = sock1.write(data).unwrap();
        assert_eq!(written, data.len());

        // Read from sock2
        let mut buffer = [0u8; 32];
        let read = sock2.read(&mut buffer).unwrap();
        assert_eq!(read, data.len());
        assert_eq!(&buffer[..read], data);
    }

    #[test_case]
    fn test_bidirectional_communication() {
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        // sock1 -> sock2
        sock1.write(b"ping").unwrap();
        let mut buf = [0u8; 4];
        sock2.read(&mut buf).unwrap();
        assert_eq!(&buf, b"ping");

        // sock2 -> sock1
        sock2.write(b"pong").unwrap();
        let mut buf = [0u8; 4];
        sock1.read(&mut buf).unwrap();
        assert_eq!(&buf, b"pong");
    }

    #[test_case]
    fn test_shutdown_closes_connection() {
        let (sock1, _sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        assert!(sock1.shutdown(ShutdownHow::Both).is_ok());
        assert_eq!(sock1.state(), SocketState::Closed);
        assert!(!sock1.is_connected());

        // A second shutdown on a no-longer-connected socket must fail.
        assert!(
            sock1.shutdown(ShutdownHow::Both).is_err(),
            "shutdown should fail once the socket is closed"
        );
    }

    #[test_case]
    fn test_handle_transfer_send_recv() {
        use crate::ipc::SharedMemory;
        use alloc::sync::Arc;

        // Create a connected socket pair
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        // Create a shared memory object to transfer (flags 0x3 = READ | WRITE)
        let shmem = match SharedMemory::new(4096, 0x3) {
            Ok(shmem) => shmem,
            Err(_) => {
                crate::println!("SharedMemory::new failed, skipping test");
                return;
            }
        };
        let shmem_obj = KernelObject::from_shared_memory_object(Arc::new(shmem));

        // Send handle from sock1 to sock2
        let result = sock1.send_handle(shmem_obj);
        assert!(result.is_ok(), "send_handle should succeed");

        // Receive handle at sock2
        let received = sock2.recv_handle();
        assert!(received.is_ok(), "recv_handle should succeed");

        // Verify it's a SharedMemory object
        let received_obj = received.unwrap();
        assert!(
            received_obj.as_shared_memory().is_some(),
            "Received object should be SharedMemory"
        );
    }

    #[test_case]
    fn test_handle_transfer_multiple_handles() {
        use crate::ipc::SharedMemory;
        use alloc::sync::Arc;

        // Create a connected socket pair
        let (sock1, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        // Send up to three handles. Allocation may fail in constrained test setups,
        // so only count the handles that were actually queued.
        let mut sent = 0;
        for i in 0..3 {
            if let Ok(shmem) = SharedMemory::new(4096 * (i + 1), 0x3) {
                let shmem_obj = KernelObject::from_shared_memory_object(Arc::new(shmem));
                assert!(sock1.send_handle(shmem_obj).is_ok());
                sent += 1;
            }
        }

        // Receive every handle that was sent
        for _ in 0..sent {
            let received = sock2.recv_handle();
            assert!(received.is_ok(), "recv_handle should succeed");
            assert!(
                received.unwrap().as_shared_memory().is_some(),
                "Received object should be SharedMemory"
            );
        }

        // Queue should be empty now
        let result = sock2.recv_handle();
        assert!(
            result.is_err(),
            "recv_handle should fail when queue is empty"
        );
    }

    #[test_case]
    fn test_handle_transfer_on_disconnected_socket() {
        use crate::ipc::SharedMemory;
        use alloc::sync::Arc;

        // Create an unconnected socket
        let sock = LocalSocket::new(SocketType::Stream, SocketProtocol::Default);

        // Try to send handle on disconnected socket (flags 0x3 = READ | WRITE)
        if let Ok(shmem) = SharedMemory::new(4096, 0x3) {
            let shmem_obj = KernelObject::from_shared_memory_object(Arc::new(shmem));
            let result = sock.send_handle(shmem_obj);
            assert!(
                result.is_err(),
                "send_handle should fail on disconnected socket"
            );
        }

        // Try to receive handle on disconnected socket
        let result = sock.recv_handle();
        assert!(
            result.is_err(),
            "recv_handle should fail on disconnected socket"
        );
    }

    #[test_case]
    fn test_handle_transfer_empty_queue() {
        // Create a connected socket pair
        let (_, sock2) =
            LocalSocket::create_connected_pair("server".to_string(), "client".to_string());

        // Try to receive from empty queue
        let result = sock2.recv_handle();
        assert!(
            result.is_err(),
            "recv_handle should fail when queue is empty"
        );
    }
}