kernel/abi/linux/riscv64/
socket.rs

1use crate::ipc::IpcError;
2use crate::object::capability::StreamError;
3use crate::{
4    abi::linux::riscv64::LinuxRiscv64Abi,
5    abi::linux::riscv64::errno,
6    abi::linux::riscv64::fs::{FD_CLOEXEC, IoVec, O_NONBLOCK},
7    arch::Trapframe,
8    network::{NetworkManager, SocketDomain, SocketProtocol, SocketType, local::LocalSocket},
9    object::KernelObject,
10    object::capability::selectable::Selectable,
11    sched::scheduler::get_scheduler,
12    task::mytask,
13};
14use alloc::sync::Arc;
15use alloc::vec::Vec;
16use core::mem::size_of;
17
/// Linux address family: Unix domain (local) sockets.
pub const AF_UNIX: i32 = 1;
/// Linux address family: IPv4 Internet protocols.
pub const AF_INET: i32 = 2;
/// Linux address family: IPv6 Internet protocols.
pub const AF_INET6: i32 = 10;

// `sa_family` in a user-supplied sockaddr is a 16-bit field; these narrowed
// copies let `match` arms compare against it directly.
const AF_UNIX_U16: u16 = AF_UNIX as u16;
const AF_INET_U16: u16 = AF_INET as u16;
26
27/// IPv4 socket address structure
28#[repr(C)]
29#[derive(Clone, Copy)]
30pub struct SockaddrIn {
31    pub sin_family: u16,
32    pub sin_port: u16,
33    pub sin_addr: u32,
34    pub sin_zero: [u8; 8],
35}
36
37impl SockaddrIn {
38    pub fn new() -> Self {
39        Self {
40            sin_family: AF_INET as u16,
41            sin_port: 0,
42            sin_addr: 0,
43            sin_zero: [0; 8],
44        }
45    }
46}
47
/// Linux socket type: connection-oriented byte stream.
pub const SOCK_STREAM: i32 = 1;
/// Linux socket type: connectionless datagrams.
pub const SOCK_DGRAM: i32 = 2;
/// Linux socket type: raw protocol access.
pub const SOCK_RAW: i32 = 3;
/// Linux socket type: sequenced, connection-oriented records.
pub const SOCK_SEQPACKET: i32 = 5;
/// `socket(2)` type flag: create the socket in non-blocking mode.
pub const SOCK_NONBLOCK: i32 = 0x800;
/// `socket(2)` type flag: set close-on-exec on the new descriptor.
pub const SOCK_CLOEXEC: i32 = 0x80000;
/// Mask selecting the base socket type out of the `type` argument (flags stripped).
pub const SOCK_TYPE_MASK: i32 = 0xF;

/// Socket-level option / control-message level.
pub const SOL_SOCKET: i32 = 1;
/// Control-message type used for passing file descriptors.
pub const SCM_RIGHTS: i32 = 1;
/// Per-call non-blocking flag for send/recv operations.
pub const MSG_DONTWAIT: i32 = 0x40;
60
/// Mirror of the 64-bit Linux `struct msghdr` as laid out in user memory.
///
/// Pointer members of the C struct are kept as raw `u64` user virtual
/// addresses. The explicit `__pad*` fields reproduce the padding that follows
/// the 32-bit members so that every offset matches the C layout (56 bytes total).
/// NOTE(review): presumably consumed by the sendmsg/recvmsg handlers in the
/// part of this file not shown here — confirm.
#[repr(C)]
#[derive(Clone, Copy)]
struct LinuxMsghdr {
    /// User address of the optional peer-address buffer.
    msg_name: u64,
    /// Size in bytes of the buffer at `msg_name`.
    msg_namelen: u32,
    /// Padding so that `msg_iov` is 8-byte aligned.
    __pad1: u32,
    /// User address of the iovec array.
    msg_iov: u64,
    /// Number of entries in the iovec array.
    msg_iovlen: u64,
    /// User address of the ancillary-data (control message) buffer.
    msg_control: u64,
    /// Size in bytes of the buffer at `msg_control`.
    msg_controllen: u64,
    /// Message flags.
    msg_flags: u32,
    /// Trailing padding to an 8-byte multiple.
    __pad2: u32,
}

/// Mirror of the Linux `struct cmsghdr` that prefixes each control message
/// inside the `msg_control` buffer (16 bytes when `usize` is 64-bit).
#[repr(C)]
#[derive(Clone, Copy)]
struct LinuxCmsghdr {
    /// Length of this header plus its payload, in bytes.
    cmsg_len: usize,
    /// Originating protocol level (e.g. `SOL_SOCKET`).
    cmsg_level: i32,
    /// Protocol-specific type (e.g. `SCM_RIGHTS`).
    cmsg_type: i32,
}
82
83/// Linux sys_socket implementation
84///
85/// Creates a socket endpoint for communication. Now properly integrated with
86/// NetworkManager and VFS for Unix domain socket support.
87///
88/// Arguments:
89/// - abi: LinuxRiscv64Abi context
90/// - trapframe: Trapframe containing syscall arguments
91///   - arg0: domain (communication domain, e.g., AF_UNIX, AF_INET)
92///   - arg1: type (socket type, e.g., SOCK_STREAM, SOCK_DGRAM)
93///   - arg2: protocol (protocol to use, usually 0)
94///
95/// Returns:
96/// - file descriptor on success
97/// - usize::MAX (Linux -1) on error
98pub fn sys_socket(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
99    let task = match mytask() {
100        Some(t) => t,
101        None => return usize::MAX,
102    };
103
104    let domain = trapframe.get_arg(0) as i32;
105    let socket_type = trapframe.get_arg(1) as i32;
106    let _protocol = trapframe.get_arg(2) as i32;
107    let socket_base_type = socket_type & SOCK_TYPE_MASK;
108    let socket_flags = socket_type & !SOCK_TYPE_MASK;
109    let set_nonblock = (socket_flags & SOCK_NONBLOCK) != 0;
110    let set_cloexec = (socket_flags & SOCK_CLOEXEC) != 0;
111
112    // Increment PC to avoid infinite loop
113    trapframe.increment_pc_next(task);
114
115    // Map Linux socket domain to Scarlet domain
116    let scarlet_domain = match domain {
117        AF_UNIX => SocketDomain::Local,
118        AF_INET => SocketDomain::Inet4,
119        AF_INET6 => SocketDomain::Inet6,
120        _ => {
121            crate::early_println!("[linux socket] unsupported domain {}", domain);
122            return usize::MAX;
123        }
124    };
125
126    // Map Linux socket type to Scarlet type
127    let scarlet_type = match socket_base_type {
128        SOCK_STREAM => SocketType::Stream,
129        SOCK_DGRAM => SocketType::Datagram,
130        SOCK_RAW => SocketType::Raw,
131        SOCK_SEQPACKET => SocketType::SeqPacket,
132        _ => {
133            crate::early_println!("[linux socket] unsupported type {}", socket_type);
134            return usize::MAX;
135        }
136    };
137
138    // Map protocol
139    let scarlet_protocol = match scarlet_domain {
140        SocketDomain::Local => SocketProtocol::Default,
141        SocketDomain::Inet4 | SocketDomain::Inet6 => match (_protocol, socket_base_type) {
142            (0, SOCK_STREAM) => SocketProtocol::Tcp,
143            (0, SOCK_DGRAM) => SocketProtocol::Udp,
144            (6, _) => SocketProtocol::Tcp,
145            (17, _) => SocketProtocol::Udp,
146            (1, _) => SocketProtocol::Icmp,
147            _ => SocketProtocol::Default,
148        },
149        _ => SocketProtocol::Default,
150    };
151
152    let socket_obj: Arc<dyn crate::network::SocketObject> = match scarlet_domain {
153        SocketDomain::Local => {
154            let local_socket = Arc::new(LocalSocket::new(scarlet_type, SocketProtocol::Default));
155            LocalSocket::init_self_weak(&local_socket);
156            local_socket as Arc<dyn crate::network::SocketObject>
157        }
158        SocketDomain::Inet4 | SocketDomain::Inet6 => {
159            let socket = match scarlet_protocol {
160                SocketProtocol::Tcp => {
161                    NetworkManager::get_manager().get_layer("tcp").map(|layer| {
162                        let tcp = layer
163                            .as_any()
164                            .downcast_ref::<crate::network::tcp::TcpLayer>()
165                            .expect("tcp layer type mismatch");
166                        tcp.create_socket() as Arc<dyn crate::network::SocketObject>
167                    })
168                }
169                SocketProtocol::Udp => {
170                    NetworkManager::get_manager().get_layer("udp").map(|layer| {
171                        let udp = layer
172                            .as_any()
173                            .downcast_ref::<crate::network::udp::UdpLayer>()
174                            .expect("udp layer type mismatch");
175                        udp.create_socket() as Arc<dyn crate::network::SocketObject>
176                    })
177                }
178                SocketProtocol::Icmp => {
179                    NetworkManager::get_manager()
180                        .get_layer("icmp")
181                        .map(|layer| {
182                            let icmp = layer
183                                .as_any()
184                                .downcast_ref::<crate::network::icmp::IcmpLayer>()
185                                .expect("icmp layer type mismatch");
186                            icmp.create_socket() as Arc<dyn crate::network::SocketObject>
187                        })
188                }
189                _ => None,
190            };
191
192            match socket {
193                Some(socket) => socket,
194                None => {
195                    crate::early_println!(
196                        "[linux socket] failed to create INET socket protocol={:?}",
197                        scarlet_protocol
198                    );
199                    return usize::MAX;
200                }
201            }
202        }
203        _ => {
204            crate::early_println!("[linux socket] unsupported domain {:?}", scarlet_domain);
205            return usize::MAX;
206        }
207    };
208    if NetworkManager::get_manager()
209        .allocate_socket_id(Arc::clone(&socket_obj))
210        .is_err()
211    {
212        crate::early_println!("[linux socket] allocate_socket_id failed");
213    }
214
215    if set_nonblock {
216        if let Some(local_socket) = LocalSocket::from_socket_object(&socket_obj) {
217            local_socket.set_nonblocking(true);
218        }
219    }
220
221    // Wrap in KernelObject
222    let kernel_obj = KernelObject::Socket(socket_obj);
223
224    // Insert into handle table
225    match task.handle_table.insert(kernel_obj) {
226        Ok(handle) => {
227            // Allocate a file descriptor for the socket
228            match abi.allocate_fd(handle) {
229                Ok(fd) => {
230                    if set_cloexec {
231                        let _ = abi.set_fd_flags(fd, FD_CLOEXEC);
232                    }
233                    fd
234                }
235                Err(_) => {
236                    // Clean up on error
237                    let _ = task.handle_table.remove(handle);
238                    usize::MAX
239                }
240            }
241        }
242        Err(_) => {
243            crate::early_println!("[linux socket] handle table insert failed");
244            usize::MAX
245        }
246    }
247}
248
249/// Linux sys_bind implementation
250///
251/// Binds a socket to an address. For AF_UNIX sockets, this creates a socket file
252/// in the VFS at the specified path.
253///
254/// Arguments:
255/// - abi: LinuxRiscv64Abi context
256/// - trapframe: Trapframe containing syscall arguments
257///   - arg0: sockfd (socket file descriptor)
258///   - arg1: addr (pointer to socket address structure)
259///   - arg2: addrlen (size of address structure)
260///
261/// Returns:
262/// - 0 on success
263/// - usize::MAX (Linux -1) indicating failure
264pub fn sys_bind(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
265    let task = match mytask() {
266        Some(t) => t,
267        None => return usize::MAX,
268    };
269
270    let sockfd = trapframe.get_arg(0) as i32;
271    let addr_ptr = trapframe.get_arg(1);
272    let addrlen = trapframe.get_arg(2) as u32;
273
274    // Increment PC to avoid infinite loop
275    trapframe.increment_pc_next(task);
276
277    // Get the file descriptor handle
278    let handle_id = match abi.get_handle(sockfd as usize) {
279        Some(h) => h,
280        None => {
281            crate::early_println!("[linux socket] bind invalid fd {}", sockfd);
282            return usize::MAX;
283        }
284    };
285
286    // Get the socket object from handle table
287    let socket_arc = match task.handle_table.get(handle_id) {
288        Some(KernelObject::Socket(socket)) => socket.clone(),
289        _ => {
290            crate::early_println!("[linux socket] bind fd {} not socket", sockfd);
291            return usize::MAX;
292        }
293    };
294
295    // Translate address pointer to physical
296    let addr_paddr = match task.vm_manager.translate_vaddr(addr_ptr) {
297        Some(addr) => addr,
298        None => {
299            crate::early_println!("[linux socket] bind bad addr {:x}", addr_ptr);
300            return usize::MAX;
301        }
302    };
303
304    // Read sockaddr structure from userspace
305    // sockaddr_un structure: { sa_family: u16, sun_path: [u8; 108] }
306    if addrlen < 2 {
307        return usize::MAX; // Too small
308    }
309
310    unsafe {
311        let sa_family = *(addr_paddr as *const u16);
312
313        match sa_family {
314            AF_UNIX_U16 => {
315                // Read the socket path (starts at offset 2)
316                let path_start = (addr_paddr + 2) as *const u8;
317                let max_path_len = (addrlen - 2) as usize;
318
319                // Find the null terminator or max length
320                let mut path_len = 0;
321                while path_len < max_path_len && *path_start.add(path_len) != 0 {
322                    path_len += 1;
323                }
324
325                if path_len == 0 || path_len > 108 {
326                    crate::early_println!("[linux socket] bind invalid path length {}", path_len);
327                    return usize::MAX;
328                }
329
330                // Convert to string
331                let path_bytes = core::slice::from_raw_parts(path_start, path_len);
332                let path = match core::str::from_utf8(path_bytes) {
333                    Ok(s) => s,
334                    Err(_) => {
335                        crate::early_println!("[linux socket] bind path utf8 error");
336                        return usize::MAX;
337                    }
338                };
339
340                // Bind the socket to the address
341                let socket_addr = match crate::network::LocalSocketAddress::from_path(path) {
342                    Ok(addr) => crate::network::SocketAddress::Local(addr),
343                    Err(_) => return usize::MAX,
344                };
345
346                if socket_arc.bind(&socket_addr).is_err() {
347                    crate::early_println!("[linux socket] bind failed for {}", path);
348                    return usize::MAX;
349                }
350
351                if NetworkManager::get_manager()
352                    .register_named_socket(path, socket_arc.clone())
353                    .is_err()
354                {
355                    crate::early_println!("[linux socket] register_named_socket failed {}", path);
356                    return usize::MAX;
357                }
358
359                // Get the socket ID from NetworkManager
360                let socket_id = match NetworkManager::get_manager().get_socket_id(&socket_arc) {
361                    Some(id) => id,
362                    None => {
363                        crate::early_println!("[linux socket] get_socket_id failed {}", path);
364                        return usize::MAX;
365                    }
366                };
367
368                // Create socket file in VFS on a best-effort basis
369                // The socket has already been successfully bound, so VFS file creation
370                // is optional - the socket remains functional even if this fails
371                let vfs = match task.vfs.read().clone() {
372                    Some(vfs) => vfs.clone(),
373                    None => {
374                        // Use global VFS if task doesn't have its own
375                        crate::fs::vfs_v2::manager::get_global_vfs_manager()
376                    }
377                };
378
379                let socket_file_type =
380                    crate::fs::FileType::Socket(crate::fs::SocketFileInfo { socket_id });
381
382                // Attempt to create the socket file - log on failure but don't fail the bind
383                if let Err(e) = vfs.create_file(path, socket_file_type) {
384                    crate::early_println!(
385                        "[sys_bind] Warning: Failed to create VFS socket file at '{}': {:?}",
386                        path,
387                        e
388                    );
389                }
390
391                0 // Success
392            }
393            AF_INET_U16 => {
394                let addr_struct = &*(addr_paddr as *const SockaddrIn);
395                let port = u16::from_be(addr_struct.sin_port);
396                let addr_bytes = u32::to_be(addr_struct.sin_addr).to_be_bytes();
397                let socket_addr = crate::network::SocketAddress::Inet(
398                    crate::network::Inet4SocketAddress::new(addr_bytes, port),
399                );
400
401                if socket_arc.bind(&socket_addr).is_err() {
402                    crate::early_println!("[linux socket] bind failed for INET address");
403                    return usize::MAX;
404                }
405
406                0
407            }
408            _ => {
409                crate::early_println!("[linux socket] bind unsupported family {}", sa_family);
410                usize::MAX
411            }
412        }
413    }
414}
415
416/// Linux sys_listen implementation (mock)
417///
418/// Marks a socket as passive, ready to accept connections. This is a mock
419/// implementation that always succeeds to allow applications to proceed.
420///
421/// Arguments:
422/// - abi: LinuxRiscv64Abi context
423/// - trapframe: Trapframe containing syscall arguments
424///   - arg0: sockfd (socket file descriptor)
425///   - arg1: backlog (maximum queue length for pending connections)
426///
427/// Returns:
428/// - 0 on success
429/// - usize::MAX (Linux -1) indicating failure
430pub fn sys_listen(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
431    let task = match mytask() {
432        Some(t) => t,
433        None => return usize::MAX,
434    };
435
436    let sockfd = trapframe.get_arg(0) as i32;
437    let backlog = trapframe.get_arg(1) as i32;
438
439    // Increment PC to avoid infinite loop
440    trapframe.increment_pc_next(task);
441
442    let handle_id = match abi.get_handle(sockfd as usize) {
443        Some(h) => h,
444        None => {
445            crate::early_println!("[linux socket] listen invalid fd {}", sockfd);
446            return usize::MAX;
447        }
448    };
449
450    let socket_arc = match task.handle_table.get(handle_id) {
451        Some(KernelObject::Socket(socket)) => socket.clone(),
452        _ => {
453            crate::early_println!("[linux socket] listen fd {} not socket", sockfd);
454            return usize::MAX;
455        }
456    };
457
458    if socket_arc.listen(backlog.max(0) as usize).is_err() {
459        crate::early_println!("[linux socket] listen failed fd {}", sockfd);
460        return usize::MAX;
461    }
462
463    0
464}
465
466/// Linux sys_accept implementation (mock)
467///
468/// Accepts a connection on a socket. This is a mock implementation that
469/// creates a new pipe and returns it as a "connected" socket fd.
470///
471/// Arguments:
472/// - abi: LinuxRiscv64Abi context
473/// - trapframe: Trapframe containing syscall arguments
474///   - arg0: sockfd (socket file descriptor)
475///   - arg1: addr (pointer to socket address structure for peer)
476///   - arg2: addrlen (pointer to size of address structure)
477///
478/// Returns:
479/// - new socket file descriptor on success
480/// - usize::MAX (Linux -1) indicating failure
481pub fn sys_accept(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
482    let task = match mytask() {
483        Some(t) => t,
484        None => return usize::MAX,
485    };
486
487    let sockfd = trapframe.get_arg(0) as i32;
488    let _addr_ptr = trapframe.get_arg(1);
489    let _addrlen_ptr = trapframe.get_arg(2);
490
491    // Increment PC to avoid infinite loop
492    trapframe.increment_pc_next(task);
493
494    let handle_id = match abi.get_handle(sockfd as usize) {
495        Some(h) => h,
496        None => {
497            crate::early_println!("[linux socket] accept invalid fd {}", sockfd);
498            return usize::MAX;
499        }
500    };
501
502    let socket_obj = match task.handle_table.get(handle_id) {
503        Some(KernelObject::Socket(socket)) => socket.clone(),
504        _ => {
505            crate::early_println!("[linux socket] accept fd {} not socket", sockfd);
506            return usize::MAX;
507        }
508    };
509
510    // Try LocalSocket first, then TcpSocket
511    let accepted_socket = if let Some(local_socket) = LocalSocket::from_socket_object(&socket_obj) {
512        local_socket.accept_blocking(task.get_id(), trapframe)
513    } else if let Some(tcp_socket) = crate::network::tcp::TcpSocket::from_socket_object(&socket_obj)
514    {
515        tcp_socket.accept_blocking(task.get_id(), trapframe)
516    } else {
517        crate::early_println!("[linux socket] accept not supported socket type");
518        return usize::MAX;
519    };
520
521    let accepted_socket = match accepted_socket {
522        Ok(socket) => socket,
523        Err(_) => {
524            crate::early_println!("[linux socket] accept_blocking failed");
525            return usize::MAX;
526        }
527    };
528
529    let kernel_obj = KernelObject::Socket(accepted_socket);
530    match task.handle_table.insert(kernel_obj) {
531        Ok(handle) => match abi.allocate_fd(handle) {
532            Ok(fd) => fd,
533            Err(_) => {
534                let _ = task.handle_table.remove(handle);
535                usize::MAX
536            }
537        },
538        Err(_) => usize::MAX,
539    }
540}
541
542/// Linux sys_connect implementation (mock)
543///
544/// Connects a socket to an address. This is a mock implementation that
545/// always succeeds to allow applications to proceed.
546///
547/// Arguments:
548/// - abi: LinuxRiscv64Abi context
549/// - trapframe: Trapframe containing syscall arguments
550///   - arg0: sockfd (socket file descriptor)
551///   - arg1: addr (pointer to socket address structure)
552///   - arg2: addrlen (size of address structure)
553///
554/// Returns:
555/// - 0 on success
556/// - usize::MAX (Linux -1) indicating failure
557pub fn sys_connect(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
558    let task = match mytask() {
559        Some(t) => t,
560        None => return usize::MAX,
561    };
562
563    let sockfd = trapframe.get_arg(0) as i32;
564    let addr_ptr = trapframe.get_arg(1);
565    let addrlen = trapframe.get_arg(2) as u32;
566
567    // Increment PC to avoid infinite loop
568    trapframe.increment_pc_next(task);
569
570    let handle_id = match abi.get_handle(sockfd as usize) {
571        Some(h) => h,
572        None => {
573            crate::early_println!("[linux socket] connect invalid fd {}", sockfd);
574            return usize::MAX;
575        }
576    };
577
578    let socket_arc = match task.handle_table.get(handle_id) {
579        Some(KernelObject::Socket(socket)) => socket.clone(),
580        _ => {
581            crate::early_println!("[linux socket] connect fd {} not socket", sockfd);
582            return usize::MAX;
583        }
584    };
585
586    let addr_paddr = match task.vm_manager.translate_vaddr(addr_ptr) {
587        Some(addr) => addr,
588        None => {
589            crate::early_println!("[linux socket] connect bad addr {:x}", addr_ptr);
590            return usize::MAX;
591        }
592    };
593
594    if addrlen < 2 {
595        return usize::MAX;
596    }
597
598    unsafe {
599        let sa_family = *(addr_paddr as *const u16);
600        match sa_family {
601            AF_UNIX_U16 => {
602                let path_start = (addr_paddr + 2) as *const u8;
603                let max_path_len = (addrlen - 2) as usize;
604                let mut path_len = 0;
605                while path_len < max_path_len && *path_start.add(path_len) != 0 {
606                    path_len += 1;
607                }
608
609                if path_len == 0 || path_len > 108 {
610                    crate::early_println!(
611                        "[linux socket] connect invalid path length {}",
612                        path_len
613                    );
614                    return usize::MAX;
615                }
616
617                let path_bytes = core::slice::from_raw_parts(path_start, path_len);
618                let path = match core::str::from_utf8(path_bytes) {
619                    Ok(s) => s,
620                    Err(_) => {
621                        crate::early_println!("[linux socket] connect path utf8 error");
622                        return usize::MAX;
623                    }
624                };
625
626                let socket_addr = match crate::network::LocalSocketAddress::from_path(path) {
627                    Ok(addr) => crate::network::SocketAddress::Local(addr),
628                    Err(_) => return usize::MAX,
629                };
630
631                if socket_arc.connect(&socket_addr).is_err() {
632                    crate::early_println!("[linux socket] connect failed {}", path);
633                    return usize::MAX;
634                }
635            }
636            AF_INET_U16 => {
637                let addr_struct = &*(addr_paddr as *const SockaddrIn);
638                let port = u16::from_be(addr_struct.sin_port);
639                let addr_bytes = u32::to_be(addr_struct.sin_addr).to_be_bytes();
640                let socket_addr = crate::network::SocketAddress::Inet(
641                    crate::network::Inet4SocketAddress::new(addr_bytes, port),
642                );
643
644                if socket_arc.connect(&socket_addr).is_err() {
645                    crate::early_println!("[linux socket] connect failed for INET address");
646                    return usize::MAX;
647                }
648            }
649            _ => {
650                crate::early_println!("[linux socket] connect unsupported family {}", sa_family);
651                return usize::MAX;
652            }
653        }
654    }
655
656    0
657}
658
659/// Linux sys_getsockname implementation
660///
661/// Gets the current address of a socket.
662///
663/// Arguments:
664/// - abi: LinuxRiscv64Abi context
665/// - trapframe: Trapframe containing syscall arguments
666///   - arg0: sockfd (socket file descriptor)
667///   - arg1: addr (pointer to socket address structure)
668///   - arg2: addrlen (pointer to size of address structure)
669///
670/// Returns:
671/// - 0 on success
672/// - usize::MAX (Linux -1) indicating failure
673pub fn sys_getsockname(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
674    let task = match mytask() {
675        Some(t) => t,
676        None => return usize::MAX,
677    };
678
679    let sockfd = trapframe.get_arg(0) as i32;
680    let addr_ptr = trapframe.get_arg(1);
681    let addrlen_ptr = trapframe.get_arg(2);
682
683    // Increment PC to avoid infinite loop
684    trapframe.increment_pc_next(task);
685
686    let handle_id = match abi.get_handle(sockfd as usize) {
687        Some(h) => h,
688        None => {
689            crate::early_println!("[linux socket] getsockname invalid fd {}", sockfd);
690            return usize::MAX;
691        }
692    };
693
694    let socket_arc = match task.handle_table.get(handle_id) {
695        Some(KernelObject::Socket(socket)) => socket.clone(),
696        _ => {
697            crate::early_println!("[linux socket] getsockname fd {} not socket", sockfd);
698            return usize::MAX;
699        }
700    };
701
702    let (addr_paddr, addrlen_paddr) = match (
703        task.vm_manager.translate_vaddr(addr_ptr),
704        task.vm_manager.translate_vaddr(addrlen_ptr),
705    ) {
706        (Some(addr), Some(len)) => (addr, len),
707        _ => {
708            crate::early_println!("[linux socket] getsockname invalid pointers");
709            return usize::MAX;
710        }
711    };
712
713    let socket_addr = match socket_arc.getsockname() {
714        Ok(addr) => addr,
715        Err(_) => {
716            crate::early_println!("[linux socket] getsockname failed");
717            return usize::MAX;
718        }
719    };
720
721    unsafe {
722        let addrlen = *(addrlen_paddr as *const u32);
723
724        match socket_addr {
725            crate::network::SocketAddress::Local(addr) => {
726                if addrlen >= 2 {
727                    let sockaddr = addr_paddr as *mut u16;
728                    *sockaddr = AF_UNIX_U16;
729
730                    let path_start = (addr_paddr + 2) as *mut u8;
731                    let path = addr.path().as_bytes();
732                    let path_len = path.len().min((addrlen - 2) as usize);
733                    core::ptr::copy_nonoverlapping(path.as_ptr(), path_start, path_len);
734                    if path_len < (addrlen - 2) as usize {
735                        *(path_start.add(path_len)) = 0;
736                    }
737
738                    *(addrlen_paddr as *mut u32) = (2 + path_len as u32).min(addrlen);
739                    0
740                } else {
741                    usize::MAX
742                }
743            }
744            crate::network::SocketAddress::Inet(inet) => {
745                if addrlen >= size_of::<SockaddrIn>() as u32 {
746                    let sockaddr = addr_paddr as *mut SockaddrIn;
747                    (*sockaddr).sin_family = AF_INET_U16;
748                    (*sockaddr).sin_port = u16::to_be(inet.port);
749                    (*sockaddr).sin_addr = u32::from_be_bytes(inet.addr);
750
751                    *(addrlen_paddr as *mut u32) = size_of::<SockaddrIn>() as u32;
752                    0
753                } else {
754                    usize::MAX
755                }
756            }
757            _ => usize::MAX,
758        }
759    }
760}
761
762/// Linux sys_getpeername implementation
763///
764/// Gets the address of the peer connected to the socket.
765///
766/// Arguments:
767///   - arg0: sockfd (socket file descriptor)
768///   - arg1: addr (pointer to socket address structure)
769///   - arg2: addrlen (pointer to size of address structure)
770///
771/// Returns:
772/// - 0 on success
773/// - usize::MAX (Linux -1) indicating failure
774pub fn sys_getpeername(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
775    let task = match mytask() {
776        Some(t) => t,
777        None => return usize::MAX,
778    };
779
780    let sockfd = trapframe.get_arg(0) as i32;
781    let addr_ptr = trapframe.get_arg(1);
782    let addrlen_ptr = trapframe.get_arg(2);
783
784    // Increment PC to avoid infinite loop
785    trapframe.increment_pc_next(task);
786
787    let handle_id = match abi.get_handle(sockfd as usize) {
788        Some(h) => h,
789        None => {
790            crate::early_println!("[linux socket] getpeername invalid fd {}", sockfd);
791            return usize::MAX;
792        }
793    };
794
795    let socket_arc = match task.handle_table.get(handle_id) {
796        Some(KernelObject::Socket(socket)) => socket.clone(),
797        _ => {
798            crate::early_println!("[linux socket] getpeername fd {} not socket", sockfd);
799            return usize::MAX;
800        }
801    };
802
803    let (addr_paddr, addrlen_paddr) = match (
804        task.vm_manager.translate_vaddr(addr_ptr),
805        task.vm_manager.translate_vaddr(addrlen_ptr),
806    ) {
807        (Some(addr), Some(len)) => (addr, len),
808        _ => {
809            crate::early_println!("[linux socket] getpeername invalid pointers");
810            return usize::MAX;
811        }
812    };
813
814    let socket_addr = match socket_arc.getpeername() {
815        Ok(addr) => addr,
816        Err(_) => {
817            crate::early_println!("[linux socket] getpeername failed");
818            return usize::MAX;
819        }
820    };
821
822    unsafe {
823        let addrlen = *(addrlen_paddr as *const u32);
824
825        match socket_addr {
826            crate::network::SocketAddress::Local(addr) => {
827                if addrlen >= 2 {
828                    let sockaddr = addr_paddr as *mut u16;
829                    *sockaddr = AF_UNIX_U16;
830
831                    let path_start = (addr_paddr + 2) as *mut u8;
832                    let path = addr.path().as_bytes();
833                    let path_len = path.len().min((addrlen - 2) as usize);
834                    core::ptr::copy_nonoverlapping(path.as_ptr(), path_start, path_len);
835                    if path_len < (addrlen - 2) as usize {
836                        *(path_start.add(path_len)) = 0;
837                    }
838
839                    *(addrlen_paddr as *mut u32) = (2 + path_len as u32).min(addrlen);
840                    0
841                } else {
842                    usize::MAX
843                }
844            }
845            crate::network::SocketAddress::Inet(inet) => {
846                if addrlen >= size_of::<SockaddrIn>() as u32 {
847                    let sockaddr = addr_paddr as *mut SockaddrIn;
848                    (*sockaddr).sin_family = AF_INET_U16;
849                    (*sockaddr).sin_port = u16::to_be(inet.port);
850                    (*sockaddr).sin_addr = u32::from_be_bytes(inet.addr);
851
852                    *(addrlen_paddr as *mut u32) = size_of::<SockaddrIn>() as u32;
853                    0
854                } else {
855                    usize::MAX
856                }
857            }
858            _ => usize::MAX,
859        }
860    }
861}
862
863/// Linux sys_getsockopt implementation (mock)
864///
865/// Gets socket options. This is a mock implementation that
866/// writes dummy data and succeeds to allow applications to proceed.
867///
868/// Arguments:
869/// - abi: LinuxRiscv64Abi context
870/// - trapframe: Trapframe containing syscall arguments
871///   - arg0: sockfd (socket file descriptor)
872///   - arg1: level (protocol level)
873///   - arg2: optname (option name)
874///   - arg3: optval (pointer to option value buffer)
875///   - arg4: optlen (pointer to option length)
876///
877/// Returns:
878/// - 0 on success
879/// - usize::MAX (Linux -1) indicating failure
880pub fn sys_getsockopt(_abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
881    let task = match mytask() {
882        Some(t) => t,
883        None => return usize::MAX,
884    };
885
886    let _sockfd = trapframe.get_arg(0) as i32;
887    let _level = trapframe.get_arg(1) as i32;
888    let _optname = trapframe.get_arg(2) as i32;
889    let optval_ptr = trapframe.get_arg(3);
890    let optlen_ptr = trapframe.get_arg(4);
891
892    // Increment PC to avoid infinite loop
893    trapframe.increment_pc_next(task);
894
895    // Mock implementation - write minimal valid data and return success
896    if let (Some(optval_paddr), Some(optlen_paddr)) = (
897        task.vm_manager.translate_vaddr(optval_ptr),
898        task.vm_manager.translate_vaddr(optlen_ptr),
899    ) {
900        unsafe {
901            // Read the provided length
902            let optlen = *(optlen_paddr as *const u32);
903
904            // Write dummy option value (typically an integer)
905            if optlen >= 4 && optval_ptr != 0 {
906                let optval = optval_paddr as *mut u32;
907                *optval = 1; // Generic "enabled" value
908
909                // Update the actual length used
910                *(optlen_paddr as *mut u32) = 4;
911            }
912        }
913        0 // Success
914    } else {
915        usize::MAX // Invalid pointers
916    }
917}
918
919/// Linux sys_setsockopt implementation (mock)
920///
921/// Sets socket options. This is a mock implementation that
922/// always succeeds to allow applications to proceed.
923///
924/// Arguments:
925/// - abi: LinuxRiscv64Abi context
926/// - trapframe: Trapframe containing syscall arguments
927///   - arg0: sockfd (socket file descriptor)
928///   - arg1: level (protocol level)
929///   - arg2: optname (option name)
930///   - arg3: optval (pointer to option value)
931///   - arg4: optlen (option length)
932///
933/// Returns:
934/// - 0 on success
935/// - usize::MAX (Linux -1) indicating failure
936pub fn sys_setsockopt(_abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
937    let task = match mytask() {
938        Some(t) => t,
939        None => return usize::MAX,
940    };
941
942    let _sockfd = trapframe.get_arg(0) as i32;
943    let _level = trapframe.get_arg(1) as i32;
944    let _optname = trapframe.get_arg(2) as i32;
945    let _optval_ptr = trapframe.get_arg(3);
946    let _optlen = trapframe.get_arg(4) as u32;
947
948    // Increment PC to avoid infinite loop
949    trapframe.increment_pc_next(task);
950
951    // Mock implementation - always succeed
952    0
953}
954
955/// Linux sys_sendmsg implementation (minimal)
956pub fn sys_sendmsg(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
957    let task = match mytask() {
958        Some(t) => t,
959        None => return usize::MAX,
960    };
961
962    let sockfd = trapframe.get_arg(0);
963    let msg_ptr = trapframe.get_arg(1);
964    let flags = trapframe.get_arg(2) as i32;
965
966    trapframe.increment_pc_next(task);
967
968    let handle = match abi.get_handle(sockfd) {
969        Some(h) => h,
970        None => {
971            crate::early_println!("[linux socket] sendmsg bad fd {}", sockfd);
972            return errno::to_result(errno::EBADF);
973        }
974    };
975
976    let kernel_obj = match task.handle_table.get(handle) {
977        Some(obj) => obj,
978        None => {
979            crate::early_println!("[linux socket] sendmsg missing handle {}", sockfd);
980            return errno::to_result(errno::EBADF);
981        }
982    };
983
984    let stream = match kernel_obj.as_stream() {
985        Some(stream) => stream,
986        None => {
987            crate::early_println!("[linux socket] sendmsg not a stream");
988            return errno::to_result(errno::ENOTSOCK);
989        }
990    };
991
992    let nonblocking = (flags & MSG_DONTWAIT) != 0
993        || abi
994            .get_file_status_flags(sockfd)
995            .map(|f| ((f as i32) & O_NONBLOCK) != 0)
996            .unwrap_or(false);
997
998    let msg_addr = match task.vm_manager.translate_vaddr(msg_ptr) {
999        Some(addr) => addr as *const LinuxMsghdr,
1000        None => {
1001            crate::early_println!("[linux socket] sendmsg bad msg ptr {:x}", msg_ptr);
1002            return errno::to_result(errno::EFAULT);
1003        }
1004    };
1005
1006    if msg_addr.is_null() {
1007        crate::early_println!("[linux socket] sendmsg null msg ptr");
1008        return errno::to_result(errno::EFAULT);
1009    }
1010
1011    let msg = unsafe { *msg_addr };
1012    let iovcnt = msg.msg_iovlen as usize;
1013    if iovcnt == 0 {
1014        return 0;
1015    }
1016
1017    const IOV_MAX: usize = 1024;
1018    if iovcnt > IOV_MAX {
1019        return errno::to_result(errno::EINVAL);
1020    }
1021
1022    let iovec_addr = match task.vm_manager.translate_vaddr(msg.msg_iov as usize) {
1023        Some(addr) => addr as *const IoVec,
1024        None => {
1025            crate::early_println!("[linux socket] sendmsg bad iov ptr {:x}", msg.msg_iov);
1026            return errno::to_result(errno::EFAULT);
1027        }
1028    };
1029
1030    if iovec_addr.is_null() {
1031        crate::early_println!("[linux socket] sendmsg null iov ptr");
1032        return errno::to_result(errno::EFAULT);
1033    }
1034
1035    let iovecs = unsafe { core::slice::from_raw_parts(iovec_addr, iovcnt) };
1036
1037    if msg.msg_control != 0 && msg.msg_controllen as usize >= size_of::<LinuxCmsghdr>() {
1038        let socket_arc = match &kernel_obj {
1039            KernelObject::Socket(socket) => Arc::clone(socket),
1040            _ => return errno::to_result(errno::ENOTSOCK),
1041        };
1042
1043        if let Some(local_socket) = LocalSocket::from_socket_object(&socket_arc) {
1044            let cmsg_addr = match task.vm_manager.translate_vaddr(msg.msg_control as usize) {
1045                Some(addr) => addr as *const LinuxCmsghdr,
1046                None => {
1047                    crate::early_println!(
1048                        "[linux socket] sendmsg bad cmsg ptr {:x}",
1049                        msg.msg_control
1050                    );
1051                    return errno::to_result(errno::EFAULT);
1052                }
1053            };
1054
1055            if !cmsg_addr.is_null() {
1056                let cmsg = unsafe { *cmsg_addr };
1057                if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
1058                    let data_len = cmsg.cmsg_len.saturating_sub(size_of::<LinuxCmsghdr>());
1059                    let fd_count = data_len / size_of::<i32>();
1060                    let data_ptr = unsafe { cmsg_addr.add(1) } as *const i32;
1061
1062                    for index in 0..fd_count {
1063                        let fd = unsafe { *data_ptr.add(index) };
1064                        if fd < 0 {
1065                            return errno::to_result(errno::EBADF);
1066                        }
1067                        let send_handle = match abi.get_handle(fd as usize) {
1068                            Some(h) => h,
1069                            None => {
1070                                crate::early_println!(
1071                                    "[linux socket] sendmsg bad fd in cmsg {}",
1072                                    fd
1073                                );
1074                                return errno::to_result(errno::EBADF);
1075                            }
1076                        };
1077                        let dup_obj = match task.handle_table.clone_for_dup(send_handle) {
1078                            Some(obj) => obj,
1079                            None => {
1080                                crate::early_println!(
1081                                    "[linux socket] sendmsg clone_for_dup failed"
1082                                );
1083                                return errno::to_result(errno::EBADF);
1084                            }
1085                        };
1086                        if local_socket.send_handle(dup_obj).is_err() {
1087                            crate::early_println!("[linux socket] sendmsg send_handle failed");
1088                            return errno::to_result(errno::EIO);
1089                        }
1090                    }
1091                }
1092            }
1093        }
1094    }
1095
1096    let mut total_written = 0usize;
1097    struct NonblockGuard<'a> {
1098        sel: Option<&'a dyn Selectable>,
1099        prev: bool,
1100    }
1101
1102    impl<'a> Drop for NonblockGuard<'a> {
1103        fn drop(&mut self) {
1104            if let Some(sel) = self.sel {
1105                sel.set_nonblocking(self.prev);
1106            }
1107        }
1108    }
1109
1110    let _nonblock_guard = if nonblocking {
1111        if let Some(sel) = kernel_obj.as_selectable() {
1112            let prev = sel.is_nonblocking();
1113            if !prev {
1114                sel.set_nonblocking(true);
1115                Some(NonblockGuard {
1116                    sel: Some(sel),
1117                    prev,
1118                })
1119            } else {
1120                None
1121            }
1122        } else {
1123            None
1124        }
1125    } else {
1126        None
1127    };
1128
1129    for iovec in iovecs {
1130        if iovec.iov_len == 0 {
1131            continue;
1132        }
1133
1134        let buf_addr = match task.vm_manager.translate_vaddr(iovec.iov_base as usize) {
1135            Some(addr) => addr as *const u8,
1136            None => {
1137                crate::early_println!(
1138                    "[linux socket] sendmsg bad buf ptr {:x}",
1139                    iovec.iov_base as usize
1140                );
1141                return errno::to_result(errno::EFAULT);
1142            }
1143        };
1144
1145        if buf_addr.is_null() {
1146            crate::early_println!("[linux socket] sendmsg null buf ptr");
1147            return errno::to_result(errno::EFAULT);
1148        }
1149
1150        let buffer = unsafe { core::slice::from_raw_parts(buf_addr, iovec.iov_len) };
1151
1152        match stream.write(buffer) {
1153            Ok(n) => {
1154                total_written = total_written.saturating_add(n);
1155                if n < iovec.iov_len {
1156                    break;
1157                }
1158            }
1159            Err(StreamError::WouldBlock) => {
1160                if nonblocking {
1161                    crate::early_println!("[linux socket] sendmsg would block");
1162                    return if total_written == 0 {
1163                        errno::to_result(errno::EAGAIN)
1164                    } else {
1165                        total_written
1166                    };
1167                }
1168                get_scheduler().schedule(trapframe);
1169                return usize::MAX;
1170            }
1171            Err(_) => {
1172                crate::early_println!("[linux socket] sendmsg write error");
1173                return errno::to_result(errno::EIO);
1174            }
1175        }
1176    }
1177
1178    total_written
1179}
1180
1181/// Linux sys_recvmsg implementation (minimal)
1182pub fn sys_recvmsg(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
1183    let task = match mytask() {
1184        Some(t) => t,
1185        None => return usize::MAX,
1186    };
1187
1188    let sockfd = trapframe.get_arg(0);
1189    let msg_ptr = trapframe.get_arg(1);
1190    let flags = trapframe.get_arg(2) as i32;
1191    // crate::early_println!(
1192    //     "[linux recvmsg] fd={} msg_ptr={:#x} flags={:#x}",
1193    //     sockfd,
1194    //     msg_ptr,
1195    //     flags
1196    // );
1197
1198    trapframe.increment_pc_next(task);
1199
1200    let handle = match abi.get_handle(sockfd) {
1201        Some(h) => h,
1202        None => {
1203            crate::early_println!("[linux socket] recvmsg bad fd {}", sockfd);
1204            return errno::to_result(errno::EBADF);
1205        }
1206    };
1207
1208    // crate::early_println!("[linux recvmsg] handle={}", handle);
1209
1210    let kernel_obj = match task.handle_table.get(handle) {
1211        Some(obj) => obj,
1212        None => {
1213            crate::early_println!("[linux socket] recvmsg missing handle {}", sockfd);
1214            return errno::to_result(errno::EBADF);
1215        }
1216    };
1217
1218    let stream = match kernel_obj.as_stream() {
1219        Some(stream) => stream,
1220        None => {
1221            crate::early_println!("[linux socket] recvmsg not a stream");
1222            return errno::to_result(errno::ENOTSOCK);
1223        }
1224    };
1225
1226    let nonblocking = (flags & MSG_DONTWAIT) != 0
1227        || abi
1228            .get_file_status_flags(sockfd)
1229            .map(|f| ((f as i32) & O_NONBLOCK) != 0)
1230            .unwrap_or(false);
1231
1232    let msg_addr = match task.vm_manager.translate_vaddr(msg_ptr) {
1233        Some(addr) => addr as *mut LinuxMsghdr,
1234        None => {
1235            crate::early_println!("[linux socket] recvmsg bad msg ptr {:x}", msg_ptr);
1236            return errno::to_result(errno::EFAULT);
1237        }
1238    };
1239
1240    if msg_addr.is_null() {
1241        crate::early_println!("[linux socket] recvmsg null msg ptr");
1242        return errno::to_result(errno::EFAULT);
1243    }
1244
1245    let msg = unsafe { *msg_addr };
1246    // crate::early_println!(
1247    //     "[linux recvmsg] iov_ptr={:#x} iovlen={} control_ptr={:#x} controllen={}",
1248    //     msg.msg_iov,
1249    //     msg.msg_iovlen,
1250    //     msg.msg_control,
1251    //     msg.msg_controllen
1252    // );
1253    let iovcnt = msg.msg_iovlen as usize;
1254    if iovcnt == 0 {
1255        return 0;
1256    }
1257
1258    const IOV_MAX: usize = 1024;
1259    if iovcnt > IOV_MAX {
1260        return errno::to_result(errno::EINVAL);
1261    }
1262
1263    let iovec_addr = match task.vm_manager.translate_vaddr(msg.msg_iov as usize) {
1264        Some(addr) => addr as *const IoVec,
1265        None => {
1266            crate::early_println!("[linux socket] recvmsg bad iov ptr {:x}", msg.msg_iov);
1267            return errno::to_result(errno::EFAULT);
1268        }
1269    };
1270
1271    if iovec_addr.is_null() {
1272        crate::early_println!("[linux socket] recvmsg null iov ptr");
1273        return errno::to_result(errno::EFAULT);
1274    }
1275
1276    let iovecs = unsafe { core::slice::from_raw_parts(iovec_addr, iovcnt) };
1277    // crate::early_println!("[linux recvmsg] iovcnt={}", iovecs.len());
1278    let mut total_read = 0usize;
1279    let mut pending_fd: Option<i32> = None;
1280    let mut msg_controllen = 0usize;
1281    struct NonblockGuard<'a> {
1282        sel: Option<&'a dyn Selectable>,
1283        prev: bool,
1284    }
1285
1286    impl<'a> Drop for NonblockGuard<'a> {
1287        fn drop(&mut self) {
1288            if let Some(sel) = self.sel {
1289                sel.set_nonblocking(self.prev);
1290            }
1291        }
1292    }
1293
1294    let _nonblock_guard = if nonblocking {
1295        if let Some(sel) = kernel_obj.as_selectable() {
1296            let prev = sel.is_nonblocking();
1297            if !prev {
1298                sel.set_nonblocking(true);
1299                Some(NonblockGuard {
1300                    sel: Some(sel),
1301                    prev,
1302                })
1303            } else {
1304                None
1305            }
1306        } else {
1307            None
1308        }
1309    } else {
1310        None
1311    };
1312
1313    // Calculate total buffer size for potential handle+data receive
1314    let total_buffer_size: usize = iovecs.iter().map(|i| i.iov_len).sum();
1315    let mut atomic_data: Option<Vec<u8>> = None;
1316
1317    // Try atomic handle+data receive if control buffer is provided
1318    if msg.msg_control != 0
1319        && (msg.msg_controllen as usize) >= size_of::<LinuxCmsghdr>() + size_of::<i32>()
1320    {
1321        let socket_arc = match &kernel_obj {
1322            KernelObject::Socket(socket) => Arc::clone(socket),
1323            _ => {
1324                return errno::to_result(errno::ENOTSOCK);
1325            }
1326        };
1327
1328        if let Some(local_socket) = LocalSocket::from_socket_object(&socket_arc) {
1329            match local_socket.recv_handle_and_data(total_buffer_size) {
1330                Ok((obj, data)) => {
1331                    let new_handle = match task.handle_table.insert(obj) {
1332                        Ok(h) => h,
1333                        Err(_) => return errno::to_result(errno::EMFILE),
1334                    };
1335                    let new_fd = match abi.allocate_fd(new_handle) {
1336                        Ok(fd) => fd,
1337                        Err(_) => {
1338                            let _ = task.handle_table.remove(new_handle);
1339                            return errno::to_result(errno::EMFILE);
1340                        }
1341                    };
1342                    pending_fd = Some(new_fd as i32);
1343                    atomic_data = Some(data);
1344                }
1345                Err(IpcError::ChannelEmpty) => {
1346                    // No handle available; fall back to regular stream read
1347                }
1348                Err(_) => {
1349                    return errno::to_result(errno::EIO);
1350                }
1351            }
1352        }
1353    }
1354
1355    // Copy data from atomic receive or read from stream
1356    if let Some(ref data) = atomic_data {
1357        // Copy atomically received data into iovecs
1358        let mut data_offset = 0;
1359        let data_len = data.len();
1360        for iovec in iovecs {
1361            if data_offset >= data_len {
1362                break;
1363            }
1364            if iovec.iov_len == 0 {
1365                continue;
1366            }
1367
1368            let buf_addr = match task.vm_manager.translate_vaddr(iovec.iov_base as usize) {
1369                Some(addr) => addr as *mut u8,
1370                None => return errno::to_result(errno::EFAULT),
1371            };
1372
1373            if buf_addr.is_null() {
1374                return errno::to_result(errno::EFAULT);
1375            }
1376
1377            let remaining = data.len() - data_offset;
1378            let to_copy = remaining.min(iovec.iov_len);
1379            let buffer = unsafe { core::slice::from_raw_parts_mut(buf_addr, to_copy) };
1380            buffer.copy_from_slice(&data[data_offset..data_offset + to_copy]);
1381            data_offset += to_copy;
1382            total_read += to_copy;
1383        }
1384    } else {
1385        // No handle received; read data from stream
1386        for iovec in iovecs {
1387            if iovec.iov_len == 0 {
1388                continue;
1389            }
1390
1391            let buf_addr = match task.vm_manager.translate_vaddr(iovec.iov_base as usize) {
1392                Some(addr) => addr as *mut u8,
1393                None => {
1394                    return errno::to_result(errno::EFAULT);
1395                }
1396            };
1397
1398            if buf_addr.is_null() {
1399                return errno::to_result(errno::EFAULT);
1400            }
1401
1402            let buffer = unsafe { core::slice::from_raw_parts_mut(buf_addr, iovec.iov_len) };
1403
1404            match stream.read(buffer) {
1405                Ok(n) => {
1406                    total_read = total_read.saturating_add(n);
1407                    if n < iovec.iov_len {
1408                        break;
1409                    }
1410                }
1411                Err(StreamError::WouldBlock) => {
1412                    return if total_read == 0 {
1413                        errno::to_result(errno::EAGAIN)
1414                    } else {
1415                        total_read
1416                    };
1417                }
1418                Err(_) => {
1419                    return errno::to_result(errno::EIO);
1420                }
1421            }
1422        }
1423    }
1424
1425    if let Some(fd_value) = pending_fd {
1426        let cmsg_addr = match task.vm_manager.translate_vaddr(msg.msg_control as usize) {
1427            Some(addr) => addr as *mut LinuxCmsghdr,
1428            None => return errno::to_result(errno::EFAULT),
1429        };
1430
1431        if cmsg_addr.is_null() {
1432            return errno::to_result(errno::EFAULT);
1433        }
1434
1435        unsafe {
1436            (*cmsg_addr).cmsg_len = size_of::<LinuxCmsghdr>() + size_of::<i32>();
1437            (*cmsg_addr).cmsg_level = SOL_SOCKET;
1438            (*cmsg_addr).cmsg_type = SCM_RIGHTS;
1439            let data_ptr = cmsg_addr.add(1) as *mut i32;
1440            *data_ptr = fd_value;
1441        }
1442
1443        msg_controllen = size_of::<LinuxCmsghdr>() + size_of::<i32>();
1444    }
1445
1446    unsafe {
1447        (*msg_addr).msg_flags = 0;
1448        (*msg_addr).msg_controllen = msg_controllen as u64;
1449    }
1450
1451    total_read
1452}
1453
1454/// Linux sys_sendto implementation
1455///
1456/// Send a message on a socket. Unlike sendmsg, this directly takes a buffer and address.
1457///
1458/// Arguments:
1459/// - abi: LinuxRiscv64Abi context
1460/// - trapframe: Trapframe containing syscall arguments
1461///   - arg0: sockfd (socket file descriptor)
1462///   - arg1: buf (pointer to data buffer)
1463///   - arg2: len (length of data)
1464///   - arg3: flags (send flags)
1465///   - arg4: dest_addr (pointer to destination address, may be NULL for connected sockets)
1466///   - arg5: addrlen (size of destination address)
1467///
1468/// Returns:
1469/// - number of bytes sent on success
1470/// - negative errno on error
1471pub fn sys_sendto(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
1472    let task = match mytask() {
1473        Some(t) => t,
1474        None => return errno::to_result(errno::ESRCH),
1475    };
1476
1477    let sockfd = trapframe.get_arg(0);
1478    let buf_ptr = trapframe.get_arg(1);
1479    let len = trapframe.get_arg(2);
1480    let flags = trapframe.get_arg(3) as u32;
1481    let dest_addr_ptr = trapframe.get_arg(4);
1482    let addrlen = trapframe.get_arg(5) as u32;
1483
1484    trapframe.increment_pc_next(task);
1485
1486    // Get socket handle
1487    let handle = match abi.get_handle(sockfd) {
1488        Some(h) => h,
1489        None => return errno::to_result(errno::EBADF),
1490    };
1491
1492    // Get socket object
1493    let socket = match task.handle_table.get(handle) {
1494        Some(KernelObject::Socket(s)) => s.clone(),
1495        _ => return errno::to_result(errno::ENOTSOCK),
1496    };
1497
1498    // Translate buffer pointer
1499    let buf_paddr = match task.vm_manager.translate_vaddr(buf_ptr) {
1500        Some(addr) => addr,
1501        None => return errno::to_result(errno::EFAULT),
1502    };
1503
1504    let data = unsafe { core::slice::from_raw_parts(buf_paddr as *const u8, len) };
1505
1506    // Parse destination address if provided
1507    let dest_addr = if dest_addr_ptr != 0 && addrlen > 0 {
1508        let addr_paddr = match task.vm_manager.translate_vaddr(dest_addr_ptr) {
1509            Some(addr) => addr,
1510            None => return errno::to_result(errno::EFAULT),
1511        };
1512
1513        // Read address family
1514        let sa_family = unsafe { *(addr_paddr as *const u16) };
1515
1516        match sa_family {
1517            AF_INET_U16 => {
1518                if addrlen < size_of::<SockaddrIn>() as u32 {
1519                    return errno::to_result(errno::EINVAL);
1520                }
1521                let sockaddr = unsafe { *(addr_paddr as *const SockaddrIn) };
1522                let port = u16::from_be(sockaddr.sin_port);
1523                let addr_bytes = sockaddr.sin_addr.to_be_bytes();
1524                crate::network::SocketAddress::Inet(crate::network::Inet4SocketAddress::new(
1525                    addr_bytes, port,
1526                ))
1527            }
1528            AF_UNIX_U16 => {
1529                // Unix domain socket sendto - usually not used for stream sockets
1530                crate::network::SocketAddress::Unspecified
1531            }
1532            _ => return errno::to_result(errno::EAFNOSUPPORT),
1533        }
1534    } else {
1535        // No address - use connected peer
1536        crate::network::SocketAddress::Unspecified
1537    };
1538
1539    // Send data
1540    match socket.sendto(data, &dest_addr, flags) {
1541        Ok(n) => n,
1542        Err(crate::network::socket::SocketError::WouldBlock) => errno::to_result(errno::EAGAIN),
1543        Err(crate::network::socket::SocketError::NotConnected) => errno::to_result(errno::ENOTCONN),
1544        Err(crate::network::socket::SocketError::NoRoute) => errno::to_result(errno::ENETUNREACH),
1545        Err(crate::network::socket::SocketError::InvalidAddress) => errno::to_result(errno::EINVAL),
1546        Err(_) => errno::to_result(errno::EIO),
1547    }
1548}
1549
1550/// Linux sys_recvfrom implementation
1551///
1552/// Receive a message from a socket. Returns the source address if provided.
1553///
1554/// Arguments:
1555/// - abi: LinuxRiscv64Abi context
1556/// - trapframe: Trapframe containing syscall arguments
1557///   - arg0: sockfd (socket file descriptor)
1558///   - arg1: buf (pointer to receive buffer)
1559///   - arg2: len (length of buffer)
1560///   - arg3: flags (receive flags)
1561///   - arg4: src_addr (pointer to store source address, may be NULL)
1562///   - arg5: addrlen (pointer to address length, input/output)
1563///
1564/// Returns:
1565/// - number of bytes received on success
1566/// - negative errno on error
1567pub fn sys_recvfrom(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
1568    let task = match mytask() {
1569        Some(t) => t,
1570        None => return errno::to_result(errno::ESRCH),
1571    };
1572
1573    let sockfd = trapframe.get_arg(0);
1574    let buf_ptr = trapframe.get_arg(1);
1575    let len = trapframe.get_arg(2);
1576    let flags = trapframe.get_arg(3) as u32;
1577    let src_addr_ptr = trapframe.get_arg(4);
1578    let addrlen_ptr = trapframe.get_arg(5);
1579
1580    trapframe.increment_pc_next(task);
1581
1582    // Get socket handle
1583    let handle = match abi.get_handle(sockfd) {
1584        Some(h) => h,
1585        None => return errno::to_result(errno::EBADF),
1586    };
1587
1588    // Get socket object
1589    let socket = match task.handle_table.get(handle) {
1590        Some(KernelObject::Socket(s)) => s.clone(),
1591        _ => return errno::to_result(errno::ENOTSOCK),
1592    };
1593
1594    // Translate buffer pointer
1595    let buf_paddr = match task.vm_manager.translate_vaddr(buf_ptr) {
1596        Some(addr) => addr,
1597        None => return errno::to_result(errno::EFAULT),
1598    };
1599
1600    let buffer = unsafe { core::slice::from_raw_parts_mut(buf_paddr as *mut u8, len) };
1601
1602    // Check for non-blocking mode
1603    let nonblocking = (flags & (MSG_DONTWAIT as u32)) != 0
1604        || abi
1605            .get_file_status_flags(sockfd)
1606            .map(|f| ((f as i32) & O_NONBLOCK) != 0)
1607            .unwrap_or(false);
1608
1609    // Set non-blocking if requested
1610    if let Some(selectable) = socket.as_selectable() {
1611        if nonblocking {
1612            selectable.set_nonblocking(true);
1613        }
1614    }
1615
1616    // Receive data
1617    let result = socket.recvfrom(buffer, flags);
1618
1619    // Restore blocking mode if we changed it
1620    if nonblocking {
1621        if let Some(selectable) = socket.as_selectable() {
1622            selectable.set_nonblocking(false);
1623        }
1624    }
1625
1626    match result {
1627        Ok((n, src_addr)) => {
1628            // Store source address if requested
1629            if src_addr_ptr != 0 && addrlen_ptr != 0 {
1630                let addrlen_paddr = match task.vm_manager.translate_vaddr(addrlen_ptr) {
1631                    Some(addr) => addr as *mut u32,
1632                    None => return errno::to_result(errno::EFAULT),
1633                };
1634
1635                let provided_len = unsafe { *addrlen_paddr };
1636
1637                match src_addr {
1638                    crate::network::SocketAddress::Inet(inet) => {
1639                        if provided_len >= size_of::<SockaddrIn>() as u32 {
1640                            let addr_paddr = match task.vm_manager.translate_vaddr(src_addr_ptr) {
1641                                Some(addr) => addr as *mut SockaddrIn,
1642                                None => return errno::to_result(errno::EFAULT),
1643                            };
1644
1645                            let sockaddr = SockaddrIn {
1646                                sin_family: AF_INET as u16,
1647                                sin_port: inet.port.to_be(),
1648                                sin_addr: u32::from_be_bytes(inet.addr),
1649                                sin_zero: [0; 8],
1650                            };
1651                            unsafe {
1652                                *addr_paddr = sockaddr;
1653                                *addrlen_paddr = size_of::<SockaddrIn>() as u32;
1654                            }
1655                        }
1656                    }
1657                    crate::network::SocketAddress::Local(_) => {
1658                        // Unix domain socket - store sockaddr_un
1659                        unsafe {
1660                            *addrlen_paddr = 0;
1661                        }
1662                    }
1663                    crate::network::SocketAddress::Unspecified => unsafe {
1664                        *addrlen_paddr = 0;
1665                    },
1666                    _ => unsafe {
1667                        *addrlen_paddr = 0;
1668                    },
1669                }
1670            }
1671            n
1672        }
1673        Err(crate::network::socket::SocketError::WouldBlock) => errno::to_result(errno::EAGAIN),
1674        Err(crate::network::socket::SocketError::NotConnected) => errno::to_result(errno::ENOTCONN),
1675        Err(_) => errno::to_result(errno::EIO),
1676    }
1677}
1678
1679/// Linux sys_socketpair implementation
1680///
1681/// Create a pair of connected sockets. This is primarily used for AF_UNIX sockets
1682/// to create a bidirectional communication channel between processes.
1683///
1684/// Arguments:
1685/// - abi: LinuxRiscv64Abi context
1686/// - trapframe: Trapframe containing syscall arguments
1687///   - arg0: domain (address family, must be AF_UNIX)
1688///   - arg1: type (socket type, e.g., SOCK_STREAM)
1689///   - arg2: protocol (usually 0)
1690///   - arg3: sv (pointer to int[2] to receive the file descriptors)
1691///
1692/// Returns:
1693/// - 0 on success
1694/// - negative errno on error
1695pub fn sys_socketpair(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
1696    let task = match mytask() {
1697        Some(t) => t,
1698        None => return errno::to_result(errno::ESRCH),
1699    };
1700
1701    let domain = trapframe.get_arg(0) as i32;
1702    let socket_type = trapframe.get_arg(1) as i32;
1703    let _protocol = trapframe.get_arg(2) as i32;
1704    let sv_ptr = trapframe.get_arg(3);
1705
1706    trapframe.increment_pc_next(task);
1707
1708    // Validate domain - socketpair only supports AF_UNIX
1709    if domain != AF_UNIX {
1710        return errno::to_result(errno::EAFNOSUPPORT);
1711    }
1712
1713    // Extract socket type flags
1714    let base_type = socket_type & SOCK_TYPE_MASK;
1715    let flags = socket_type & !SOCK_TYPE_MASK;
1716    let nonblocking = (flags & SOCK_NONBLOCK) != 0;
1717    let cloexec = (flags & SOCK_CLOEXEC) != 0;
1718
1719    // Validate socket type - we support SOCK_STREAM and SOCK_DGRAM
1720    if base_type != SOCK_STREAM && base_type != SOCK_DGRAM {
1721        return errno::to_result(errno::ESOCKTNOSUPPORT);
1722    }
1723
1724    // Translate sv pointer (needs to write 2 i32 values)
1725    let sv_paddr = match task.vm_manager.translate_vaddr(sv_ptr) {
1726        Some(addr) => addr as *mut i32,
1727        None => return errno::to_result(errno::EFAULT),
1728    };
1729
1730    // Create connected socket pair
1731    let (socket1, socket2) = LocalSocket::create_connected_pair(
1732        alloc::string::String::from("socketpair:0"),
1733        alloc::string::String::from("socketpair:1"),
1734    );
1735
1736    // Set non-blocking mode if requested
1737    if nonblocking {
1738        socket1.set_nonblocking(true);
1739        socket2.set_nonblocking(true);
1740    }
1741
1742    // Add first socket to handle table
1743    let kernel_obj1 = KernelObject::Socket(socket1);
1744    let handle1 = match task.handle_table.insert(kernel_obj1) {
1745        Ok(id) => id,
1746        Err(_) => return errno::to_result(errno::EMFILE),
1747    };
1748
1749    // Add second socket to handle table
1750    let kernel_obj2 = KernelObject::Socket(socket2);
1751    let handle2 = match task.handle_table.insert(kernel_obj2) {
1752        Ok(id) => id,
1753        Err(_) => {
1754            // Clean up handle1 if handle2 allocation fails
1755            let _ = task.handle_table.remove(handle1);
1756            return errno::to_result(errno::EMFILE);
1757        }
1758    };
1759
1760    // Allocate file descriptors
1761    let fd1 = match abi.allocate_fd(handle1) {
1762        Ok(fd) => fd,
1763        Err(_) => {
1764            let _ = task.handle_table.remove(handle1);
1765            let _ = task.handle_table.remove(handle2);
1766            return errno::to_result(errno::EMFILE);
1767        }
1768    };
1769
1770    let fd2 = match abi.allocate_fd(handle2) {
1771        Ok(fd) => fd,
1772        Err(_) => {
1773            let _ = abi.remove_fd(fd1);
1774            let _ = task.handle_table.remove(handle1);
1775            let _ = task.handle_table.remove(handle2);
1776            return errno::to_result(errno::EMFILE);
1777        }
1778    };
1779
1780    // Set flags
1781    if nonblocking {
1782        let _ = abi.set_file_status_flags(fd1, O_NONBLOCK as u32);
1783        let _ = abi.set_file_status_flags(fd2, O_NONBLOCK as u32);
1784    }
1785    if cloexec {
1786        let _ = abi.set_fd_flags(fd1, FD_CLOEXEC);
1787        let _ = abi.set_fd_flags(fd2, FD_CLOEXEC);
1788    }
1789
1790    // Write file descriptors to user space
1791    unsafe {
1792        *sv_paddr = fd1 as i32;
1793        *sv_paddr.add(1) = fd2 as i32;
1794    }
1795
1796    0
1797}
1798
1799/// Linux sys_shutdown implementation
1800///
1801/// Shut down part of a full-duplex connection.
1802///
1803/// Arguments:
1804/// - abi: LinuxRiscv64Abi context
1805/// - trapframe: Trapframe containing syscall arguments
1806///   - arg0: sockfd (socket file descriptor)
1807///   - arg1: how (0=SHUT_RD, 1=SHUT_WR, 2=SHUT_RDWR)
1808///
1809/// Returns:
1810/// - 0 on success
1811/// - negative errno on error
1812pub fn sys_shutdown(abi: &mut LinuxRiscv64Abi, trapframe: &mut Trapframe) -> usize {
1813    let task = match mytask() {
1814        Some(t) => t,
1815        None => return errno::to_result(errno::ESRCH),
1816    };
1817
1818    let sockfd = trapframe.get_arg(0);
1819    let how = trapframe.get_arg(1) as u32;
1820
1821    trapframe.increment_pc_next(task);
1822
1823    // Get socket handle
1824    let handle = match abi.get_handle(sockfd) {
1825        Some(h) => h,
1826        None => return errno::to_result(errno::EBADF),
1827    };
1828
1829    // Get socket object
1830    let socket = match task.handle_table.get(handle) {
1831        Some(KernelObject::Socket(s)) => s.clone(),
1832        _ => return errno::to_result(errno::ENOTSOCK),
1833    };
1834
1835    // Convert how to ShutdownHow
1836    let shutdown_how = match how {
1837        0 => crate::network::socket::ShutdownHow::Read,
1838        1 => crate::network::socket::ShutdownHow::Write,
1839        2 => crate::network::socket::ShutdownHow::Both,
1840        _ => return errno::to_result(errno::EINVAL),
1841    };
1842
1843    match socket.shutdown(shutdown_how) {
1844        Ok(()) => 0,
1845        Err(crate::network::socket::SocketError::NotConnected) => errno::to_result(errno::ENOTCONN),
1846        Err(_) => errno::to_result(errno::EIO),
1847    }
1848}