kernel/network/
syscall.rs

1//! Socket System Calls for Scarlet Native
2//!
3//! This module implements socket system calls specifically for Scarlet Native.
4//! Unlike POSIX sockets, these are designed around Scarlet's handle-based architecture.
5//!
6//! # Design Principles
7//!
8//! 1. **Handle-based**: Returns handle IDs instead of file descriptors
9//! 2. **Scarlet-native**: Uses LocalSocket for IPC, not POSIX Unix domain sockets
10//! 3. **Path-based naming**: Sockets are identified by filesystem-like paths
11//! 4. **Simple and direct**: Minimal abstraction for kernel IPC
12//!
13//! # System Call Interface
14//!
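//! - `sys_socket_create()` - Create a new socket (returns handle ID)
//! - `sys_socket_bind()` - Bind socket to a path or an IPv4 socket address
//! - `sys_socket_listen()` - Start listening for connections
//! - `sys_socket_connect()` - Connect to a named socket or an IPv4 socket address
//! - `sys_socket_accept()` - Accept an incoming connection (returns new handle)
//! - `sys_socketpair()` - Create a connected socket pair (for IPC)
//! - `sys_socket_shutdown()` - Shutdown socket (read, write, or both)
//! - `sys_socket_sendto()` / `sys_socket_recvfrom()` - Send/receive a datagram with an explicit peer address
//! - `sys_network_set_ipv4()`, `sys_network_set_gateway()`, `sys_network_set_dns()`,
//!   `sys_network_set_netmask()`, `sys_network_list_interfaces()` - Network interface configuration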
22//!
23//! # Usage Example
24//!
25//! ```rust,ignore
26//! // Server side
27//! let server_handle = sys_socket_create();
28//! sys_socket_bind(server_handle, "/tmp/server.sock");
29//! sys_socket_listen(server_handle, 5);
30//! let client_handle = sys_socket_accept(server_handle);
31//!
32//! // Client side
33//! let client_handle = sys_socket_create();
34//! sys_socket_connect(client_handle, "/tmp/server.sock");
35//!
36//! // IPC pair (simpler)
37//! let [handle1, handle2] = sys_socketpair();
38//! ```
39
40use alloc::string::String;
41use alloc::sync::Arc;
42use alloc::vec;
43use alloc::vec::Vec;
44
45use crate::arch::Trapframe;
46use crate::network::{
47    Inet4SocketAddress, Ipv4Address, LocalSocketAddress, NetworkManager, ShutdownHow,
48    SocketAddress, SocketDomain, SocketObject, SocketProtocol, SocketType, local::LocalSocket,
49};
50use crate::object::KernelObject;
51use crate::object::handle::{AccessMode, HandleMetadata, HandleType};
52use crate::task::mytask;
53
/// Request block that user space passes (by pointer) to `sys_network_set_ipv4`.
///
/// The C representation pins the field order so that user space and the
/// kernel agree on the byte layout: an interface-name pointer/length pair
/// followed by the four IPv4 address octets.
#[derive(Clone, Copy)]
#[repr(C)]
struct NetworkSetIpv4Request {
    /// User virtual address of the interface name bytes.
    iface_ptr: usize,
    /// Number of bytes in the interface name (read as exactly this many bytes).
    iface_len: usize,
    /// IPv4 address octets, handed unchanged to `Ipv4Address::from_bytes`.
    addr: [u8; 4],
}
61
62fn read_user_string(ptr: usize, len: usize) -> Option<String> {
63    let task = mytask()?;
64    if len == 0 {
65        return None;
66    }
67    let addr = task.vm_manager.translate_vaddr(ptr)? as *const u8;
68    if len > 256 {
69        return None;
70    }
71    let mut bytes = Vec::with_capacity(len);
72    unsafe {
73        for i in 0..len {
74            bytes.push(*addr.add(i));
75        }
76    }
77    String::from_utf8(bytes).ok()
78}
79
80fn read_user_ipv4(ptr: usize) -> Option<Ipv4Address> {
81    let task = mytask()?;
82    let addr = task.vm_manager.translate_vaddr(ptr)? as *const u8;
83    unsafe {
84        let bytes = [*addr, *addr.add(1), *addr.add(2), *addr.add(3)];
85        Some(Ipv4Address::from_bytes(bytes))
86    }
87}
88
89pub fn sys_network_set_ipv4(tf: &mut Trapframe) -> usize {
90    let task = match mytask() {
91        Some(task) => task,
92        None => return usize::MAX,
93    };
94    tf.increment_pc_next(task);
95
96    let req_ptr = tf.get_arg(0);
97    let req_addr = match task.vm_manager.translate_vaddr(req_ptr) {
98        Some(addr) => addr as *const NetworkSetIpv4Request,
99        None => return usize::MAX,
100    };
101
102    let req = unsafe { *req_addr };
103    let iface = match read_user_string(req.iface_ptr, req.iface_len) {
104        Some(name) => name,
105        None => return usize::MAX,
106    };
107    let ip = Ipv4Address::from_bytes(req.addr);
108
109    if crate::network::get_network_manager()
110        .get_interface(&iface)
111        .is_none()
112    {
113        return usize::MAX;
114    }
115
116    match crate::network::config::set_interface_ip(&iface, ip) {
117        Ok(()) => 0,
118        Err(_) => usize::MAX,
119    }
120}
121
122pub fn sys_network_set_gateway(tf: &mut Trapframe) -> usize {
123    let task = match mytask() {
124        Some(task) => task,
125        None => return usize::MAX,
126    };
127    tf.increment_pc_next(task);
128
129    let addr_ptr = tf.get_arg(0);
130    let gateway = match read_user_ipv4(addr_ptr) {
131        Some(addr) => addr,
132        None => return usize::MAX,
133    };
134    crate::network::get_network_manager().set_default_gateway(gateway);
135    0
136}
137
138pub fn sys_network_set_dns(tf: &mut Trapframe) -> usize {
139    let task = match mytask() {
140        Some(task) => task,
141        None => return usize::MAX,
142    };
143    tf.increment_pc_next(task);
144
145    let addr_ptr = tf.get_arg(0);
146    let dns = match read_user_ipv4(addr_ptr) {
147        Some(addr) => addr,
148        None => return usize::MAX,
149    };
150    let manager = crate::network::get_network_manager();
151    let mut config = manager.get_config();
152    config.dns_server = Some(dns);
153    manager.set_config(config);
154    0
155}
156
157pub fn sys_network_set_netmask(tf: &mut Trapframe) -> usize {
158    let task = match mytask() {
159        Some(task) => task,
160        None => return usize::MAX,
161    };
162    tf.increment_pc_next(task);
163
164    let addr_ptr = tf.get_arg(0);
165    let mask = match read_user_ipv4(addr_ptr) {
166        Some(addr) => addr,
167        None => return usize::MAX,
168    };
169    let manager = crate::network::get_network_manager();
170    let mut config = manager.get_config();
171    config.subnet_mask = mask;
172    manager.set_config(config);
173    0
174}
175
176pub fn sys_network_list_interfaces(tf: &mut Trapframe) -> usize {
177    let task = match mytask() {
178        Some(task) => task,
179        None => return usize::MAX,
180    };
181    tf.increment_pc_next(task);
182
183    let buf_ptr = tf.get_arg(0);
184    let buf_len = tf.get_arg(1);
185    if buf_ptr == 0 || buf_len == 0 {
186        return usize::MAX;
187    }
188
189    let buf_addr = match task.vm_manager.translate_vaddr(buf_ptr) {
190        Some(addr) => addr as *mut u8,
191        None => return usize::MAX,
192    };
193
194    let interfaces = crate::network::get_network_manager().list_interfaces();
195    let mut output = String::new();
196    for (idx, name) in interfaces.iter().enumerate() {
197        if idx > 0 {
198            output.push('\n');
199        }
200        output.push_str(name);
201    }
202
203    let bytes = output.as_bytes();
204    let copy_len = bytes.len().min(buf_len);
205    unsafe {
206        core::ptr::copy_nonoverlapping(bytes.as_ptr(), buf_addr, copy_len);
207    }
208    copy_len
209}
210
211/// System call: Create a new socket
212///
213/// Creates a Scarlet Native local socket for IPC.
214///
215/// # Arguments (via trapframe)
216///
217/// - `a0`: Socket domain (SocketDomain)
218/// - `a1`: Socket type (SocketType)
219/// - `a2`: Socket protocol (SocketProtocol)
220///
221/// # Returns
222///
223/// Handle ID of the newly created socket (> 0), or error code (usize::MAX for -1).
224///
225/// # Errors
226///
227/// Returns usize::MAX (-1) if:
228/// - Failed to allocate handle
229/// - Internal error creating socket
230pub fn sys_socket_create(tf: &mut Trapframe) -> usize {
231    let task = match mytask() {
232        Some(task) => task,
233        None => return usize::MAX,
234    };
235
236    tf.increment_pc_next(task);
237
238    let domain = tf.get_arg(0) as u32;
239    let socket_type = tf.get_arg(1) as u32;
240    let protocol = tf.get_arg(2) as u32;
241
242    let domain = match domain {
243        0 | 1 => SocketDomain::Local,
244        2 => SocketDomain::Inet4,
245        3 => SocketDomain::Inet6,
246        _ => return usize::MAX,
247    };
248
249    let socket_type = match socket_type {
250        0 | 1 => SocketType::Stream,
251        2 => SocketType::Datagram,
252        3 => SocketType::Raw,
253        4 => SocketType::SeqPacket,
254        _ => return usize::MAX,
255    };
256
257    let protocol = match protocol {
258        0 => SocketProtocol::Default,
259        1 => SocketProtocol::Icmp,
260        6 => SocketProtocol::Tcp,
261        17 => SocketProtocol::Udp,
262        value => SocketProtocol::Raw(value as u16),
263    };
264
265    let protocol = match (socket_type, protocol) {
266        (SocketType::Stream, SocketProtocol::Default) => SocketProtocol::Tcp,
267        (SocketType::Datagram, SocketProtocol::Default) => SocketProtocol::Udp,
268        (SocketType::Raw, SocketProtocol::Default) => SocketProtocol::Raw(0),
269        _ => protocol,
270    };
271
272    let socket = match domain {
273        SocketDomain::Local => {
274            let socket = Arc::new(LocalSocket::new(socket_type, protocol));
275            LocalSocket::init_self_weak(&socket);
276            socket as Arc<dyn SocketObject>
277        }
278        SocketDomain::Inet4 | SocketDomain::Inet6 => {
279            let manager = NetworkManager::get_manager();
280            let socket = match protocol {
281                SocketProtocol::Tcp => manager.get_layer("tcp").map(|layer| {
282                    let tcp = layer
283                        .as_any()
284                        .downcast_ref::<crate::network::tcp::TcpLayer>()
285                        .expect("tcp layer type mismatch");
286                    tcp.create_socket() as Arc<dyn SocketObject>
287                }),
288                SocketProtocol::Udp => manager.get_layer("udp").map(|layer| {
289                    let udp = layer
290                        .as_any()
291                        .downcast_ref::<crate::network::udp::UdpLayer>()
292                        .expect("udp layer type mismatch");
293                    udp.create_socket() as Arc<dyn SocketObject>
294                }),
295                SocketProtocol::Icmp => manager.get_layer("icmp").map(|layer| {
296                    let icmp = layer
297                        .as_any()
298                        .downcast_ref::<crate::network::icmp::IcmpLayer>()
299                        .expect("icmp layer type mismatch");
300                    icmp.create_socket() as Arc<dyn SocketObject>
301                }),
302                _ => None,
303            };
304
305            match socket {
306                Some(socket) => socket,
307                None => return usize::MAX,
308            }
309        }
310        SocketDomain::Packet => return usize::MAX,
311    };
312
313    // Register socket with NetworkManager to get a socket ID for VFS integration
314    let socket_id = match NetworkManager::get_manager().allocate_socket_id(socket.clone()) {
315        Ok(id) => id,
316        Err(_) => return usize::MAX,
317    };
318
319    // Wrap in KernelObject
320    let kernel_obj = KernelObject::Socket(socket);
321
322    // Create metadata for the socket handle
323    let metadata = HandleMetadata {
324        handle_type: HandleType::IpcChannel,
325        access_mode: AccessMode::ReadWrite,
326        special_semantics: None,
327    };
328
329    // Add to handle table with metadata
330    let handle_id = match task.handle_table.insert_with_metadata(kernel_obj, metadata) {
331        Ok(id) => id as usize,
332        Err(_) => {
333            // Clean up on error
334            NetworkManager::get_manager().remove_socket(socket_id);
335            return usize::MAX;
336        }
337    };
338
339    handle_id
340}
341
342/// System call: Bind socket to a path
343///
344/// Binds a socket to a named path in the socket namespace.
345/// This allows other processes to connect to this socket by name.
346///
347/// # Arguments (via trapframe)
348///
349/// - `a0`: Socket handle ID
350/// - `a1`: Pointer to path string (null-terminated)
351/// - `a2`: Length of path string (excluding null terminator)
352///
353/// # Returns
354///
355/// 0 on success, usize::MAX (-1) on error
356///
357/// # Errors
358///
359/// Returns usize::MAX (-1) if:
360/// - Invalid handle ID
361/// - Invalid path pointer or length
362/// - Path already in use
363/// - Socket already bound
364pub fn sys_socket_bind(tf: &mut Trapframe) -> usize {
365    let task = match mytask() {
366        Some(task) => task,
367        None => return usize::MAX,
368    };
369
370    tf.increment_pc_next(task);
371
372    let handle_id = tf.get_arg(0) as u32;
373    let path_ptr = tf.get_arg(1);
374    let path_len = tf.get_arg(2);
375
376    // Get the socket from handle table
377    let socket_arc = match task.handle_table.get(handle_id) {
378        Some(KernelObject::Socket(socket)) => socket.clone(),
379        _ => return usize::MAX,
380    };
381
382    // Translate pointer to physical address
383    let path_physical = match task.vm_manager.translate_vaddr(path_ptr) {
384        Some(addr) => addr as *const u8,
385        None => return usize::MAX,
386    };
387
388    if path_len == core::mem::size_of::<Inet4SocketAddress>() {
389        let addr = unsafe { *(path_physical as *const Inet4SocketAddress) };
390        if socket_arc.bind(&SocketAddress::Inet(addr)).is_err() {
391            return usize::MAX;
392        }
393        return 0;
394    }
395
396    // Read path string from user space (up to path_len bytes)
397    let path = unsafe {
398        let mut bytes = alloc::vec::Vec::with_capacity(path_len.min(108)); // Socket path limit
399        for i in 0..path_len.min(108) {
400            let byte = *path_physical.add(i);
401            if byte == 0 {
402                break;
403            }
404            bytes.push(byte);
405        }
406        match alloc::string::String::from_utf8(bytes) {
407            Ok(s) => s,
408            Err(_) => return usize::MAX,
409        }
410    };
411
412    // Bind the socket to the path
413    let local_addr = match LocalSocketAddress::from_path(path.clone()) {
414        Ok(addr) => addr,
415        Err(_) => return usize::MAX,
416    };
417
418    // Bind updates the socket's internal state
419    if socket_arc.bind(&SocketAddress::Local(local_addr)).is_err() {
420        return usize::MAX;
421    }
422
423    // Register the same Arc in NetworkManager's named socket namespace
424    // This ensures the registered socket and the one in handle_table are identical
425    if NetworkManager::get_manager()
426        .register_named_socket(&path, socket_arc.clone())
427        .is_err()
428    {
429        return usize::MAX;
430    }
431
432    // Get the socket ID from NetworkManager for VFS integration
433    let socket_id = match NetworkManager::get_manager().get_socket_id(&socket_arc) {
434        Some(id) => id,
435        None => return usize::MAX, // Socket not found in NetworkManager
436    };
437
438    // Create socket file in VFS for filesystem visibility
439    // Note: This is optional - the socket is already functional via named_sockets
440    let vfs_guard = task.vfs.read();
441    let vfs = match vfs_guard.as_ref() {
442        Some(vfs) => vfs.clone(),
443        None => {
444            // Use global VFS if task doesn't have its own
445            crate::fs::vfs_v2::manager::get_global_vfs_manager()
446        }
447    };
448
449    let socket_file_type = crate::fs::FileType::Socket(crate::fs::SocketFileInfo { socket_id });
450
451    // Attempt to create the socket file in VFS
452    // This may fail if:
453    // - Parent directory doesn't exist
454    // - File already exists
455    // - Path is invalid
456    // - Filesystem doesn't support socket files
457    // Since the socket is already bound and registered in named_sockets,
458    // we treat VFS file creation as optional and don't fail the bind operation
459    if let Err(e) = vfs.create_file(&path, socket_file_type) {
460        // Log the error for debugging but continue - socket is still usable
461        crate::early_println!(
462            "[socket_bind] Warning: Failed to create VFS socket file at '{}': {:?}",
463            path,
464            e
465        );
466    }
467
468    0
469}
470
471/// System call: Listen for connections
472///
473/// Marks a socket as passive (listening for connections).
474///
475/// # Arguments (via trapframe)
476///
477/// - `a0`: Socket handle ID
478/// - `a1`: Maximum backlog size (number of pending connections)
479///
480/// # Returns
481///
482/// 0 on success, usize::MAX (-1) on error
483///
484/// # Errors
485///
486/// Returns usize::MAX (-1) if:
487/// - Invalid handle ID
488/// - Socket not bound
489/// - Socket already listening or connected
490pub fn sys_socket_listen(tf: &mut Trapframe) -> usize {
491    let task = match mytask() {
492        Some(task) => task,
493        None => return usize::MAX,
494    };
495
496    tf.increment_pc_next(task);
497
498    let handle_id = tf.get_arg(0) as u32;
499    let backlog = tf.get_arg(1);
500
501    // Get the socket from handle table
502    let socket = match task.handle_table.get(handle_id) {
503        Some(KernelObject::Socket(socket)) => socket.clone(),
504        _ => {
505            crate::println!("[sys_socket_listen] Invalid handle {}", handle_id);
506            return usize::MAX;
507        }
508    };
509
510    // Start listening
511    match socket.listen(backlog) {
512        Ok(()) => {
513            crate::println!("[sys_socket_listen] Socket {} now listening", handle_id);
514            0
515        }
516        Err(e) => {
517            crate::println!("[sys_socket_listen] listen() failed: {:?}", e);
518            usize::MAX
519        }
520    }
521}
522
523/// System call: Connect to a named socket
524///
525/// Connects a socket to another socket identified by path.
526///
527/// # Arguments (via trapframe)
528///
529/// - `a0`: Socket handle ID
530/// - `a1`: Pointer to path string (null-terminated)
531/// - `a2`: Length of path string (excluding null terminator)
532///
533/// # Returns
534///
535/// 0 on success, usize::MAX (-1) on error
536///
537/// # Errors
538///
539/// Returns usize::MAX (-1) if:
540/// - Invalid handle ID
541/// - Invalid path pointer or length
542/// - Target socket not found or not listening
543/// - Socket already connected
544pub fn sys_socket_connect(tf: &mut Trapframe) -> usize {
545    let task = match mytask() {
546        Some(task) => task,
547        None => return usize::MAX,
548    };
549
550    tf.increment_pc_next(task);
551
552    let handle_id = tf.get_arg(0) as u32;
553    let path_ptr = tf.get_arg(1);
554    let path_len = tf.get_arg(2);
555
556    // Get the socket from handle table
557    let socket = match task.handle_table.get(handle_id) {
558        Some(KernelObject::Socket(socket)) => socket.clone(),
559        _ => return usize::MAX,
560    };
561
562    // Translate pointer to physical address
563    let path_physical = match task.vm_manager.translate_vaddr(path_ptr) {
564        Some(addr) => addr as *const u8,
565        None => return usize::MAX,
566    };
567
568    if path_len == core::mem::size_of::<Inet4SocketAddress>() {
569        let addr = unsafe { *(path_physical as *const Inet4SocketAddress) };
570        if socket.connect(&SocketAddress::Inet(addr)).is_err() {
571            return usize::MAX;
572        }
573        return 0;
574    }
575
576    // Read path string from user space (up to path_len bytes)
577    let path = unsafe {
578        let mut bytes = alloc::vec::Vec::with_capacity(path_len.min(108)); // Socket path limit
579        for i in 0..path_len.min(108) {
580            let byte = *path_physical.add(i);
581            if byte == 0 {
582                break;
583            }
584            bytes.push(byte);
585        }
586        match alloc::string::String::from_utf8(bytes) {
587            Ok(s) => s,
588            Err(_) => return usize::MAX,
589        }
590    };
591
592    // Create socket address and connect
593    let peer_addr = match LocalSocketAddress::from_path(&path) {
594        Ok(addr) => addr,
595        Err(_) => return usize::MAX,
596    };
597
598    // Connect the socket - this updates its internal state
599    if socket.connect(&SocketAddress::Local(peer_addr)).is_err() {
600        return usize::MAX;
601    }
602
603    0
604}
605
606/// System call: Accept an incoming connection
607///
608/// Accepts a connection from the socket's backlog queue.
609/// This blocks if no connections are pending (in a real implementation,
610/// should return WouldBlock for non-blocking sockets).
611///
612/// # Arguments (via trapframe)
613///
614/// - `a0`: Listening socket handle ID
615///
616/// # Returns
617///
618/// Handle ID of the accepted connection socket (> 0), or usize::MAX (-1) on error
619///
620/// # Errors
621///
622/// Returns usize::MAX (-1) if:
623/// - Invalid handle ID
624/// - Socket not in listening state
625/// - No pending connections (would block)
626/// - Failed to allocate handle for new socket
627pub fn sys_socket_accept(tf: &mut Trapframe) -> usize {
628    let task = match mytask() {
629        Some(task) => task,
630        None => return usize::MAX,
631    };
632
633    tf.increment_pc_next(task);
634
635    let handle_id = tf.get_arg(0) as u32;
636
637    // Get the listening socket from handle table
638    let socket_obj = match task.handle_table.get(handle_id) {
639        Some(KernelObject::Socket(socket)) => socket.clone(),
640        Some(_) => return usize::MAX,
641        None => return usize::MAX,
642    };
643
644    // Try to downcast to LocalSocket or TcpSocket
645    use crate::network::local::LocalSocket;
646
647    let accepted_socket = if let Some(local_socket) = LocalSocket::from_socket_object(&socket_obj) {
648        // LocalSocket accept
649        match local_socket.accept_blocking(task.get_id(), tf) {
650            Ok(socket) => socket,
651            Err(e) => {
652                crate::println!(
653                    "[sys_socket_accept] LocalSocket accept_blocking failed: {:?}",
654                    e
655                );
656                return usize::MAX;
657            }
658        }
659    } else if let Some(tcp_socket) = crate::network::tcp::TcpSocket::from_socket_object(&socket_obj)
660    {
661        // TcpSocket accept
662        match tcp_socket.accept_blocking(task.get_id(), tf) {
663            Ok(socket) => socket,
664            Err(_) => return usize::MAX,
665        }
666    } else {
667        crate::println!("[sys_socket_accept] Not a supported socket type");
668        return usize::MAX;
669    };
670
671    // Add the accepted socket to handle table
672    let kernel_obj = KernelObject::Socket(accepted_socket);
673    let metadata = HandleMetadata {
674        handle_type: HandleType::IpcChannel,
675        access_mode: AccessMode::ReadWrite,
676        special_semantics: None,
677    };
678
679    match task.handle_table.insert_with_metadata(kernel_obj, metadata) {
680        Ok(id) => id as usize,
681        Err(_) => usize::MAX,
682    }
683}
684
685/// System call: Create a connected socket pair
686///
687/// Creates two connected local sockets for IPC.
688/// This is more efficient than bind/connect for simple bidirectional communication.
689///
690/// # Arguments (via trapframe)
691///
692/// - `a0`: Pointer to array[2] for storing handle IDs
693///
694/// # Returns
695///
696/// 0 on success, usize::MAX (-1) on error.
697/// On success, the handle IDs are written to the array pointed to by a0.
698///
699/// # Errors
700///
701/// Returns usize::MAX (-1) if:
702/// - Invalid array pointer
703/// - Failed to create socket pair
704/// - Failed to allocate handles
705pub fn sys_socketpair(tf: &mut Trapframe) -> usize {
706    let task = match mytask() {
707        Some(task) => task,
708        None => return usize::MAX,
709    };
710
711    tf.increment_pc_next(task);
712
713    let array_ptr = tf.get_arg(0);
714
715    // Validate pointer (check if we can write 2 usizes = 16 bytes)
716    let array_vaddr = match task.vm_manager.translate_vaddr(array_ptr) {
717        Some(addr) => addr as *mut usize,
718        None => return usize::MAX,
719    };
720
721    // Create a connected socket pair using LocalSocket::create_connected_pair
722    let (socket1, socket2) = LocalSocket::create_connected_pair(
723        String::from("socketpair:0"),
724        String::from("socketpair:1"),
725    );
726
727    // Add both sockets to handle table
728    let kernel_obj1 = KernelObject::Socket(socket1);
729    let metadata = HandleMetadata {
730        handle_type: HandleType::IpcChannel,
731        access_mode: AccessMode::ReadWrite,
732        special_semantics: None,
733    };
734
735    let handle1 = match task
736        .handle_table
737        .insert_with_metadata(kernel_obj1, metadata.clone())
738    {
739        Ok(id) => id as usize,
740        Err(_) => return usize::MAX,
741    };
742
743    let kernel_obj2 = KernelObject::Socket(socket2);
744    let handle2 = match task
745        .handle_table
746        .insert_with_metadata(kernel_obj2, metadata)
747    {
748        Ok(id) => id as usize,
749        Err(_) => {
750            // Clean up handle1 if handle2 allocation fails
751            let _ = task.handle_table.remove(handle1 as u32);
752            return usize::MAX;
753        }
754    };
755
756    // Write handle IDs to user space array
757    unsafe {
758        array_vaddr.write(handle1);
759        array_vaddr.add(1).write(handle2);
760    }
761
762    0
763}
764
765/// System call: Shutdown socket
766///
767/// Shuts down part or all of a socket connection.
768///
769/// # Arguments (via trapframe)
770///
771/// - `a0`: Socket handle ID
772/// - `a1`: How to shutdown (0 = read, 1 = write, 2 = both)
773///
774/// # Returns
775///
776/// 0 on success, usize::MAX (-1) on error
777///
778/// # Errors
779///
780/// Returns usize::MAX (-1) if:
781/// - Invalid handle ID
782/// - Invalid shutdown mode
783/// - Socket not connected
784pub fn sys_socket_shutdown(tf: &mut Trapframe) -> usize {
785    let task = match mytask() {
786        Some(task) => task,
787        None => return usize::MAX,
788    };
789
790    tf.increment_pc_next(task);
791
792    let handle_id = tf.get_arg(0) as u32;
793    let how_value = tf.get_arg(1);
794
795    // Get the socket from handle table
796    let socket = match task.handle_table.get(handle_id) {
797        Some(KernelObject::Socket(socket)) => socket.clone(),
798        _ => return usize::MAX,
799    };
800
801    // Parse shutdown mode
802    let how = match how_value {
803        0 => ShutdownHow::Read,
804        1 => ShutdownHow::Write,
805        2 => ShutdownHow::Both,
806        _ => return usize::MAX,
807    };
808
809    // Shutdown the socket
810    if socket.shutdown(how).is_err() {
811        return usize::MAX;
812    }
813
814    0
815}
816
817/// System call: Receive datagram with sender address
818///
819/// Receives a datagram from a socket and returns the sender's address.
820/// Used for UDP and Local datagram sockets.
821///
822/// # Arguments (via trapframe)
823///
824/// - `a0`: Socket handle ID
825/// - `a1`: Pointer to buffer for receiving data
826/// - `a2`: Buffer length
827/// - `a3`: Pointer to SocketAddress structure for storing sender address (can be null)
828///
829/// # Returns
830///
831/// Number of bytes received on success, usize::MAX (-1) on error
832///
833/// # Errors
834///
835/// Returns usize::MAX (-1) if:
836/// - Invalid handle ID
837/// - Invalid buffer pointer
838/// - Socket error
839pub fn sys_socket_recvfrom(tf: &mut Trapframe) -> usize {
840    let task = match mytask() {
841        Some(task) => task,
842        None => return usize::MAX,
843    };
844
845    tf.increment_pc_next(task);
846
847    let handle_id = tf.get_arg(0) as u32;
848    let buf_ptr = tf.get_arg(1);
849    let buf_len = tf.get_arg(2);
850    let addr_ptr = tf.get_arg(3);
851
852    // Validate buffer pointer
853    let buf_vaddr = match task.vm_manager.translate_vaddr(buf_ptr) {
854        Some(addr) => addr as *mut u8,
855        None => return usize::MAX,
856    };
857
858    // Get the socket from handle table
859    let socket = match task.handle_table.get(handle_id) {
860        Some(KernelObject::Socket(socket)) => socket.clone(),
861        _ => return usize::MAX,
862    };
863
864    // Create a temporary buffer
865    let mut temp_buf = vec![0u8; buf_len];
866
867    // Receive datagram
868    match socket.recvfrom(&mut temp_buf, 0) {
869        Ok((len, addr)) => {
870            // Copy data to user buffer
871            unsafe {
872                core::ptr::copy_nonoverlapping(temp_buf.as_ptr(), buf_vaddr, len);
873            }
874
875            // Store sender address if pointer is provided
876            if addr_ptr != 0 {
877                if let Some(addr_vaddr) = task.vm_manager.translate_vaddr(addr_ptr) {
878                    unsafe {
879                        match addr {
880                            SocketAddress::Inet(inet) => {
881                                let addr_bytes = inet.addr;
882                                let port_bytes = inet.port.to_be_bytes();
883                                let ptr = addr_vaddr as *mut u8;
884                                *ptr = 2; // AF_INET
885                                *ptr.add(1) = 0;
886                                core::ptr::copy_nonoverlapping(addr_bytes.as_ptr(), ptr.add(2), 4);
887                                core::ptr::copy_nonoverlapping(port_bytes.as_ptr(), ptr.add(6), 2);
888                            }
889                            _ => {}
890                        }
891                    }
892                }
893            }
894
895            len
896        }
897        Err(crate::network::socket::SocketError::WouldBlock) => (-(11i32)) as usize,
898        Err(_) => usize::MAX,
899    }
900}
901
902/// System call: Send datagram to specified address
903///
904/// Sends a datagram to a specific address.
905/// Used for UDP and Local datagram sockets.
906///
907/// # Arguments (via trapframe)
908///
909/// - `a0`: Socket handle ID
910/// - `a1`: Pointer to data buffer
911/// - `a2`: Data length
912/// - `a3`: Pointer to SocketAddress structure (destination address)
913///
914/// # Returns
915///
916/// Number of bytes sent on success, usize::MAX (-1) on error
917///
918/// # Errors
919///
920/// Returns usize::MAX (-1) if:
921/// - Invalid handle ID
922/// - Invalid buffer pointer
923/// - Invalid address
924/// - Socket error
925pub fn sys_socket_sendto(tf: &mut Trapframe) -> usize {
926    let task = match mytask() {
927        Some(task) => task,
928        None => return usize::MAX,
929    };
930
931    tf.increment_pc_next(task);
932
933    let handle_id = tf.get_arg(0) as u32;
934    let buf_ptr = tf.get_arg(1);
935    let buf_len = tf.get_arg(2);
936    let addr_ptr = tf.get_arg(3);
937
938    // Validate buffer pointer
939    let buf_vaddr = match task.vm_manager.translate_vaddr(buf_ptr) {
940        Some(addr) => addr as *const u8,
941        None => return usize::MAX,
942    };
943
944    // Read data from user buffer
945    let data: Vec<u8> = unsafe { core::slice::from_raw_parts(buf_vaddr, buf_len).to_vec() };
946
947    // Get the socket from handle table
948    let socket = match task.handle_table.get(handle_id) {
949        Some(KernelObject::Socket(socket)) => socket.clone(),
950        _ => return usize::MAX,
951    };
952
953    // Parse destination address
954    let addr = if addr_ptr != 0 {
955        match task.vm_manager.translate_vaddr(addr_ptr) {
956            Some(addr_vaddr) => {
957                unsafe {
958                    let ptr = addr_vaddr as *const u8;
959                    let family = *ptr;
960                    match family {
961                        2 => {
962                            // AF_INET
963                            let ip_bytes = [*ptr.add(2), *ptr.add(3), *ptr.add(4), *ptr.add(5)];
964                            let port = u16::from_be_bytes([*ptr.add(6), *ptr.add(7)]);
965                            SocketAddress::Inet(Inet4SocketAddress::new(ip_bytes, port))
966                        }
967                        _ => return usize::MAX,
968                    }
969                }
970            }
971            None => return usize::MAX,
972        }
973    } else {
974        return usize::MAX;
975    };
976
977    // Send datagram
978    match socket.sendto(&data, &addr, 0) {
979        Ok(len) => len,
980        Err(_) => usize::MAX,
981    }
982}