kernel/network/
mod.rs

1//! Network functionality for Scarlet
2//!
3//! This module provides network capabilities including sockets and protocol stacks.
4//! It follows the existing patterns established by VfsManager and DeviceManager.
5//!
6//! # Architecture
7//!
8//! - **SocketObject**: KernelObject type representing network endpoints
9//! - **NetworkManager**: Global manager handling socket lifecycle and connections
10//! - **Socket Implementations**: Provided by ABI modules (Linux, xv6, etc.)
11//!
12//! # Design Philosophy
13//!
14//! Scarlet's core provides only abstract socket infrastructure. Specific socket
15//! implementations (Unix domain sockets, TCP/IP, etc.) are delegated to ABI modules.
16//! This maintains Scarlet's OS-agnostic nature while allowing each ABI to provide
17//! the socket semantics expected by its applications.
18//!
19//! # Usage
20//!
//! ABI modules implement the `SocketObject` trait and register their socket
//! implementations with the `NetworkManager`.
23
24use alloc::{
25    collections::BTreeMap,
26    string::{String, ToString},
27    sync::{Arc, Weak},
28    vec::Vec,
29};
30use core::sync::atomic::{AtomicUsize, Ordering};
31use spin::Once;
32
33pub mod arp;
34pub mod config;
35pub mod ethernet;
36pub mod ethernet_interface;
37pub mod icmp;
38pub mod ipv4;
39pub mod local;
40pub mod protocol_stack;
41pub mod socket;
42pub mod syscall;
43pub mod tcp;
44pub mod udp;
45#[cfg(all(test, target_arch = "riscv64"))]
46pub mod virtio_net_tests;
47
48// Re-export commonly used types
49pub use protocol_stack::{
50    LayerContext, NetworkLayer, NetworkLayerStats, ProtocolStack, ProtocolStackManager,
51    ProtocolStackStats, SocketConfig,
52};
53pub use socket::{
54    Inet4SocketAddress,
55    Inet6SocketAddress,
56    LocalSocketAddress,
57    ShutdownHow,
58    SocketAddress,
59    SocketControl,
60    SocketDomain,
61    SocketError,
62    SocketObject,
63    SocketProtocol,
64    SocketState,
65    SocketType,
66    UnixSocketAddress, // Keep for backwards compatibility
67};
68
69use crate::device::network::{DevicePacket, MacAddress};
70use crate::early_println;
71use crate::network::arp::ArpCacheEntry;
72use crate::network::ipv4::Ipv4Address;
73use crate::object::KernelObject;
74
/// Unique socket identifier.
///
/// Used as the key of `NetworkManager::connections` and as the value of the
/// reverse `socket_to_id` map. The manager's ID counter starts at 1.
pub type SocketId = usize;
77
/// Socket factory function type.
///
/// ABI modules register one factory per `SocketDomain` through
/// `NetworkManager::register_socket_factory`. `NetworkManager::create_socket`
/// calls the factory registered for the requested domain with the requested
/// socket type and protocol, and propagates any `SocketError` it returns.
pub type SocketFactory =
    fn(SocketType, SocketProtocol) -> Result<Arc<dyn SocketObject>, SocketError>;
83
/// Interface statistics, as returned by `NetworkInterface::stats`.
///
/// All counters are zero in the `Default` value. Nothing in this module
/// updates them; the exact accounting is up to each interface implementation.
/// Field meanings below are read off the names (tx = transmit, rx = receive)
/// — NOTE(review): confirm against the interface implementations.
#[derive(Debug, Clone, Default)]
pub struct InterfaceStats {
    /// Transmitted packet count.
    pub tx_packets: u64,
    /// Transmitted byte count.
    pub tx_bytes: u64,
    /// Received packet count.
    pub rx_packets: u64,
    /// Received byte count.
    pub rx_bytes: u64,
    /// Dropped packet count.
    pub drops: u64,
    /// Error count.
    pub errors: u64,
}
94
/// Network configuration held by the `NetworkManager`.
///
/// Read as a snapshot with `NetworkManager::get_config` and replaced wholesale
/// with `NetworkManager::set_config`.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Default gateway address; written by `NetworkManager::set_default_gateway`.
    pub default_gateway: Option<Ipv4Address>,
    /// MAC address of the gateway. Reset to `None` whenever the gateway is
    /// changed through `NetworkManager::set_default_gateway`.
    pub gateway_mac: Option<MacAddress>,
    /// DNS server address (not read anywhere in this module).
    pub dns_server: Option<Ipv4Address>,
    /// Subnet mask; `255.255.255.0` in the `Default` value.
    pub subnet_mask: Ipv4Address,
}
103
104impl Default for NetworkConfig {
105    fn default() -> Self {
106        Self {
107            default_gateway: None,
108            gateway_mac: None,
109            dns_server: None,
110            subnet_mask: Ipv4Address::new(255, 255, 255, 0),
111        }
112    }
113}
114
/// Network interface trait.
///
/// Implementations are stored as `Arc<dyn NetworkInterface>` inside the
/// `NetworkManager`, hence the `Send + Sync` bound. Every method takes
/// `&self`, so setters such as `set_ip_address` require interior mutability
/// in the implementation.
pub trait NetworkInterface: Send + Sync {
    /// Interface name. The manager passes it to the IPv4 layer when adding an
    /// address or a default route.
    fn name(&self) -> &str;
    /// Hardware address; used as the Ethernet source address of ARP requests.
    fn mac_address(&self) -> MacAddress;
    /// IPv4 address currently assigned, or `None` when unconfigured.
    fn ip_address(&self) -> Option<Ipv4Address>;
    /// Assign the interface's IPv4 address (not called from this module).
    fn set_ip_address(&self, ip: Ipv4Address);
    /// Transmit one packet; the error is a static description of the failure.
    fn send(&self, packet: DevicePacket) -> Result<(), &'static str>;
    /// Return a batch of packets — presumably those received since the last
    /// call; not called from this module, confirm against implementations.
    fn poll(&self) -> Result<Vec<DevicePacket>, &'static str>;
    /// Snapshot of the interface counters, returned by value.
    fn stats(&self) -> InterfaceStats;
}
125
/// Network Manager - Global socket and connection manager
///
/// Every field is individually guarded (spin `RwLock` or atomic), so all
/// methods take `&self`. Locks are spin locks: code holding one must not call
/// back into a method that takes the same lock.
pub struct NetworkManager {
    /// Socket factories per domain (registered by ABI modules)
    socket_factories: spin::RwLock<BTreeMap<SocketDomain, SocketFactory>>,

    /// Protocol stacks for network protocols (TCP/IP, UDP, etc.)
    protocol_stacks: protocol_stack::ProtocolStackManager,

    /// Protocol layers registry (shared instances like VFS filesystems),
    /// keyed by layer name (this module looks up "ip" and "arp")
    protocol_layers: spin::RwLock<BTreeMap<String, Arc<dyn NetworkLayer>>>,

    /// Named sockets namespace (path/name -> socket). Entries are weak, so a
    /// name does not keep its socket alive.
    named_sockets: spin::RwLock<BTreeMap<String, Weak<dyn SocketObject>>>,

    /// Active socket connections by ID (strong references)
    connections: spin::RwLock<BTreeMap<SocketId, Arc<dyn SocketObject>>>,

    /// Reverse mapping: socket pointer address -> socket ID. Backed by a
    /// `BTreeMap`, so lookups are O(log n) rather than a scan of `connections`.
    socket_to_id: spin::RwLock<BTreeMap<usize, SocketId>>,

    /// Next socket ID counter (starts at 1)
    next_socket_id: AtomicUsize,

    /// Registered network interfaces, keyed by the name given at registration
    interfaces: spin::RwLock<BTreeMap<String, Arc<dyn NetworkInterface>>>,

    /// Default interface name
    default_interface: spin::RwLock<Option<String>>,

    /// ARP cache, keyed by the IPv4 address as a big-endian `u32`
    arp_cache: spin::RwLock<BTreeMap<u32, ArpCacheEntry>>,

    /// Network configuration
    network_config: spin::RwLock<NetworkConfig>,
}
161
162impl NetworkManager {
    /// Create a new NetworkManager instance
    ///
    /// All registries start empty, the socket ID counter starts at 1, there is
    /// no default interface, and the configuration is `NetworkConfig::default()`.
    /// Private: the only instance is the global one built by `get_manager`/`init`.
    fn new() -> Self {
        Self {
            socket_factories: spin::RwLock::new(BTreeMap::new()),
            protocol_stacks: protocol_stack::ProtocolStackManager::new(),
            protocol_layers: spin::RwLock::new(BTreeMap::new()),
            named_sockets: spin::RwLock::new(BTreeMap::new()),
            connections: spin::RwLock::new(BTreeMap::new()),
            socket_to_id: spin::RwLock::new(BTreeMap::new()),
            next_socket_id: AtomicUsize::new(1),
            interfaces: spin::RwLock::new(BTreeMap::new()),
            default_interface: spin::RwLock::new(None),
            arp_cache: spin::RwLock::new(BTreeMap::new()),
            network_config: spin::RwLock::new(NetworkConfig::default()),
        }
    }
179
180    /// Get the global NetworkManager instance
181    pub fn get_manager() -> &'static NetworkManager {
182        GLOBAL_NETWORK_MANAGER.call_once(|| NetworkManager::new())
183    }
184
185    /// Initialize the global NetworkManager
186    ///
187    /// Initializes all protocol layers in dependency order:
188    /// 1. Ethernet (no dependencies)
189    /// 2. IPv4, ARP (depend on Ethernet)
190    /// 3. ICMP, UDP, TCP (depend on IPv4)
191    pub fn init() -> &'static NetworkManager {
192        let manager = GLOBAL_NETWORK_MANAGER.call_once(|| NetworkManager::new());
193
194        // Initialize protocol layers in dependency order
195        // Layer 1: Ethernet (no dependencies)
196        crate::network::ethernet::EthernetLayer::init(manager);
197
198        // Layer 2: IPv4 and ARP (depend on Ethernet)
199        crate::network::ipv4::Ipv4Layer::init(manager);
200        crate::network::arp::ArpLayer::init(manager);
201
202        // Layer 3: ICMP, UDP, TCP (depend on IPv4)
203        crate::network::icmp::IcmpLayer::init(manager);
204        crate::network::udp::UdpLayer::init(manager);
205        crate::network::tcp::TcpLayer::init(manager);
206
207        manager
208    }
209
210    // ===================================================================
211    // Interface Management
212    // ===================================================================
213
214    pub fn register_interface(
215        &self,
216        name: &str,
217        interface: Arc<dyn NetworkInterface>,
218    ) -> Result<(), &'static str> {
219        let mut default = self.default_interface.write();
220        if default.is_none() {
221            *default = Some(String::from(name));
222        }
223
224        let interface_clone = interface.clone();
225        self.interfaces
226            .write()
227            .insert(String::from(name), interface);
228
229        // Configure protocol layers when first interface is registered
230        if self.interfaces.read().len() == 1 {
231            self.configure_protocol_layers_with_interface(interface_clone);
232        }
233
234        Ok(())
235    }
236
237    pub fn get_interface(&self, name: &str) -> Option<Arc<dyn NetworkInterface>> {
238        self.interfaces.read().get(name).cloned()
239    }
240
241    pub fn get_default_interface(&self) -> Option<Arc<dyn NetworkInterface>> {
242        self.default_interface
243            .read()
244            .as_ref()
245            .and_then(|name| self.get_interface(name))
246    }
247
248    pub fn set_default_interface(&self, name: &str) {
249        *self.default_interface.write() = Some(String::from(name));
250    }
251
252    pub fn list_interfaces(&self) -> Vec<String> {
253        self.interfaces.read().keys().cloned().collect()
254    }
255
256    // ===================================================================
257    // Protocol Layer Configuration
258    // ===================================================================
259
260    fn configure_protocol_layers_with_interface(&self, interface: Arc<dyn NetworkInterface>) {
261        let local_ip = interface.ip_address();
262        let interface_name = interface.name();
263
264        // Configure IP layer with local IP address
265        if let Some(ip_layer) = self.get_layer("ip") {
266            if let Some(ip) = ip_layer
267                .as_any()
268                .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
269            {
270                if let Some(local_ip_addr) = local_ip {
271                    ip.add_address(
272                        interface_name,
273                        crate::network::ipv4::Ipv4AddressInfo {
274                            address: local_ip_addr,
275                            netmask: crate::network::ipv4::Ipv4Address::new(255, 255, 255, 0),
276                            broadcast: None,
277                            is_primary: true,
278                        },
279                    );
280                }
281            }
282        }
283    }
284
285    // ===================================================================
286    // ARP Cache Management
287    // ===================================================================
288
289    pub fn arp_lookup(&self, ip: &Ipv4Address) -> Option<MacAddress> {
290        let ip_u32 = u32::from_be_bytes(ip.as_bytes());
291        let cache = self.arp_cache.read();
292        cache.get(&ip_u32).and_then(|entry| {
293            if entry.is_expired() {
294                None
295            } else {
296                Some(MacAddress::new(entry.mac_address))
297            }
298        })
299    }
300
301    pub fn arp_cache_add(&self, ip: Ipv4Address, mac: MacAddress) {
302        let ip_u32 = u32::from_be_bytes(ip.as_bytes());
303        let mut cache = self.arp_cache.write();
304        cache.insert(ip_u32, ArpCacheEntry::new(ip, *mac.as_bytes()));
305    }
306
307    pub fn send_arp_request(&self, target_ip: Ipv4Address) -> Result<(), &'static str> {
308        let interface = self.get_default_interface().ok_or("No default interface")?;
309
310        let local_ip = interface.ip_address().ok_or("Interface has no IP")?;
311        let local_mac = interface.mac_address();
312
313        let arp_request =
314            crate::network::arp::ArpPacket::request(local_ip.as_bytes(), target_ip.as_bytes());
315
316        let eth_header = crate::network::ethernet::EthernetHeader::new(
317            [0xFF; 6],
318            *local_mac.as_bytes(),
319            crate::network::ethernet::ether_type::ARP,
320        );
321
322        let mut packet_data = Vec::new();
323        packet_data.extend_from_slice(&eth_header.to_bytes());
324        packet_data.extend_from_slice(&arp_request.to_bytes());
325
326        let packet = DevicePacket::with_data(packet_data);
327        interface.send(packet)
328    }
329
330    pub fn resolve_mac(&self, ip: Ipv4Address) -> Result<MacAddress, &'static str> {
331        if let Some(mac) = self.arp_lookup(&ip) {
332            return Ok(mac);
333        }
334        self.send_arp_request(ip)?;
335        Err("MAC not in cache, ARP request sent")
336    }
337
338    // ===================================================================
339    // Network Configuration
340    // ===================================================================
341
342    pub fn get_config(&self) -> NetworkConfig {
343        self.network_config.read().clone()
344    }
345
346    pub fn set_config(&self, config: NetworkConfig) {
347        *self.network_config.write() = config;
348    }
349
350    pub fn set_default_gateway(&self, gateway: Ipv4Address) {
351        self.network_config.write().default_gateway = Some(gateway);
352        self.network_config.write().gateway_mac = None;
353
354        // Add default route to Ipv4Layer's routing table
355        if let Some(default_iface) = self.get_default_interface() {
356            if let Some(ip_layer) = self.get_layer("ip") {
357                if let Some(ipv4) = ip_layer
358                    .as_any()
359                    .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
360                {
361                    ipv4.set_default_gateway(gateway, default_iface.name());
362                }
363            }
364        }
365    }
366
367    pub fn get_default_gateway(&self) -> Option<Ipv4Address> {
368        self.network_config.read().default_gateway
369    }
370
371    pub fn handle_received_packet(&self, _interface_name: &str, packet: &DevicePacket) {
372        if packet.len < 14 {
373            return;
374        }
375
376        let eth_type = u16::from_be_bytes([packet.data[12], packet.data[13]]);
377        early_println!(
378            "[net] recv frame len={} eth_type=0x{:04X}",
379            packet.len,
380            eth_type
381        );
382        match eth_type {
383            0x0806 => self.handle_arp_packet(packet),
384            0x0800 => self.handle_ipv4_packet(packet),
385            _ => {}
386        }
387    }
388
389    fn handle_arp_packet(&self, packet: &DevicePacket) {
390        if packet.len < 14 + 28 {
391            return;
392        }
393
394        let arp_data = &packet.data[14..];
395        if let Some(arp_layer) = self.get_layer("arp") {
396            if let Some(arp) = arp_layer
397                .as_any()
398                .downcast_ref::<crate::network::arp::ArpLayer>()
399            {
400                let _ = arp.receive_packet(arp_data);
401            }
402        }
403    }
404
405    fn handle_ipv4_packet(&self, packet: &DevicePacket) {
406        if packet.len < 14 + 20 {
407            return;
408        }
409
410        let ip_bytes = &packet.data[14..packet.len];
411        let header = match crate::network::ipv4::Ipv4Header::from_bytes(ip_bytes) {
412            Some(h) => h,
413            None => return,
414        };
415
416        early_println!(
417            "[IPv4] Recv frame: ip_len={} src={}.{}.{}.{} dst={}.{}.{}.{} proto={}",
418            ip_bytes.len(),
419            header.source_ip[0],
420            header.source_ip[1],
421            header.source_ip[2],
422            header.source_ip[3],
423            header.dest_ip[0],
424            header.dest_ip[1],
425            header.dest_ip[2],
426            header.dest_ip[3],
427            header.protocol
428        );
429
430        let header_len = header.header_length();
431        let total_length = usize::from(header.total_length);
432        if ip_bytes.len() < header_len {
433            return;
434        }
435        if total_length < header_len || total_length > ip_bytes.len() {
436            return;
437        }
438
439        let payload = &ip_bytes[header_len..total_length];
440        let protocol = header.protocol;
441
442        if let Some(ip_layer) = self.get_layer("ip") {
443            if let Some(ip) = ip_layer
444                .as_any()
445                .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
446            {
447                if let Some(handler) = ip.get_protocol_handler(protocol) {
448                    let src_ip = crate::network::ipv4::Ipv4Address::from_bytes(header.source_ip);
449                    let dst_ip = crate::network::ipv4::Ipv4Address::from_bytes(header.dest_ip);
450                    let _ = match protocol {
451                        crate::network::ipv4::protocol::ICMP => handler
452                            .as_any()
453                            .downcast_ref::<crate::network::icmp::IcmpLayer>()
454                            .map(|icmp| icmp.receive_packet(payload, src_ip, dst_ip)),
455                        crate::network::ipv4::protocol::TCP => handler
456                            .as_any()
457                            .downcast_ref::<crate::network::tcp::TcpLayer>()
458                            .map(|tcp| tcp.receive_packet(src_ip, dst_ip, payload)),
459                        crate::network::ipv4::protocol::UDP => handler
460                            .as_any()
461                            .downcast_ref::<crate::network::udp::UdpLayer>()
462                            .map(|udp| udp.receive_packet(src_ip, dst_ip, payload)),
463                        _ => Some(handler.receive(payload, None)),
464                    };
465                }
466            }
467        }
468    }
469
470    // ===================================================================
471    // Socket Management (existing behavior)
472    // ===================================================================
473
474    pub fn register_socket_factory(&self, domain: SocketDomain, factory: SocketFactory) {
475        self.socket_factories.write().insert(domain, factory);
476    }
477
478    pub fn create_socket(
479        &self,
480        domain: SocketDomain,
481        socket_type: SocketType,
482        protocol: SocketProtocol,
483    ) -> Result<KernelObject, SocketError> {
484        let factories = self.socket_factories.read();
485        if let Some(factory) = factories.get(&domain) {
486            let socket = factory(socket_type, protocol)?;
487            let socket_id = self.next_socket_id.fetch_add(1, Ordering::SeqCst);
488            self.connections.write().insert(socket_id, socket.clone());
489            return Ok(KernelObject::Socket(socket));
490        }
491        drop(factories);
492
493        if let Some(stack) = self.protocol_stacks.get_stack(domain) {
494            let socket = stack.create_socket(socket_type, protocol)?;
495            let socket_id = self.next_socket_id.fetch_add(1, Ordering::SeqCst);
496            self.connections.write().insert(socket_id, socket.clone());
497            return Ok(KernelObject::Socket(socket));
498        }
499
500        Err(SocketError::NotSupported)
501    }
502
    /// Register a protocol stack with the manager's `ProtocolStackManager`.
    ///
    /// `create_socket` falls back to these stacks for domains that have no
    /// registered socket factory.
    pub fn register_protocol_stack(&self, stack: Arc<dyn protocol_stack::ProtocolStack>) {
        self.protocol_stacks.register_stack(stack);
    }
506
507    pub fn register_layer(&self, name: &str, layer: Arc<dyn NetworkLayer>) {
508        self.protocol_layers.write().insert(name.to_string(), layer);
509    }
510
511    pub fn unregister_layer(&self, name: &str) -> Option<Arc<dyn NetworkLayer>> {
512        self.protocol_layers.write().remove(name)
513    }
514
515    pub fn get_layer(&self, name: &str) -> Option<Arc<dyn NetworkLayer>> {
516        self.protocol_layers.read().get(name).cloned()
517    }
518
519    pub fn list_layers(&self) -> Vec<String> {
520        self.protocol_layers.read().keys().cloned().collect()
521    }
522
523    pub fn layer_count(&self) -> usize {
524        self.protocol_layers.read().len()
525    }
526
527    pub fn has_layer(&self, name: &str) -> bool {
528        self.protocol_layers.read().contains_key(name)
529    }
530
531    pub fn register_named_socket(
532        &self,
533        name: &str,
534        socket: Arc<dyn SocketObject>,
535    ) -> Result<(), SocketError> {
536        let mut sockets = self.named_sockets.write();
537        if let Some(weak_socket) = sockets.get(name) {
538            if weak_socket.upgrade().is_some() {
539                return Err(SocketError::AddressInUse);
540            }
541        }
542        sockets.insert(name.into(), Arc::downgrade(&socket));
543        Ok(())
544    }
545
546    pub fn lookup_named_socket(&self, name: &str) -> Result<Arc<dyn SocketObject>, SocketError> {
547        let sockets = self.named_sockets.read();
548        match sockets.get(name) {
549            Some(weak_socket) => weak_socket.upgrade().ok_or(SocketError::ConnectionRefused),
550            None => Err(SocketError::ConnectionRefused),
551        }
552    }
553
554    pub fn unregister_named_socket(&self, name: &str) {
555        self.named_sockets.write().remove(name);
556    }
557
558    pub fn get_socket(&self, socket_id: SocketId) -> Option<Arc<dyn SocketObject>> {
559        self.connections.read().get(&socket_id).cloned()
560    }
561
562    pub fn register_socket_with_id(
563        &self,
564        socket_id: SocketId,
565        socket: Arc<dyn SocketObject>,
566    ) -> Result<(), SocketError> {
567        let mut connections = self.connections.write();
568        if connections.contains_key(&socket_id) {
569            return Err(SocketError::AddressInUse);
570        }
571        let socket_ptr = Arc::as_ptr(&socket) as *const () as usize;
572        connections.insert(socket_id, socket);
573        drop(connections);
574        self.socket_to_id.write().insert(socket_ptr, socket_id);
575        Ok(())
576    }
577
578    pub fn remove_socket(&self, socket_id: SocketId) {
579        let mut connections = self.connections.write();
580        if let Some(socket) = connections.get(&socket_id) {
581            let socket_ptr = Arc::as_ptr(socket) as *const () as usize;
582            drop(connections);
583            self.connections.write().remove(&socket_id);
584            self.socket_to_id.write().remove(&socket_ptr);
585        }
586    }
587
588    pub fn allocate_socket_id(
589        &self,
590        socket: Arc<dyn SocketObject>,
591    ) -> Result<SocketId, SocketError> {
592        let socket = socket;
593        let start_id = self.next_socket_id.load(Ordering::SeqCst);
594        let mut current_id = start_id;
595        loop {
596            match self.register_socket_with_id(current_id, Arc::clone(&socket)) {
597                Ok(()) => {
598                    let mut observed = self.next_socket_id.load(Ordering::SeqCst);
599                    loop {
600                        if observed > current_id {
601                            break;
602                        }
603                        match self.next_socket_id.compare_exchange(
604                            observed,
605                            current_id.wrapping_add(1),
606                            Ordering::SeqCst,
607                            Ordering::SeqCst,
608                        ) {
609                            Ok(_) => break,
610                            Err(actual) => {
611                                if actual > current_id {
612                                    break;
613                                }
614                                observed = actual;
615                            }
616                        }
617                    }
618                    return Ok(current_id);
619                }
620                Err(e) => {
621                    let next_id = current_id.wrapping_add(1);
622                    if next_id == start_id {
623                        return Err(e);
624                    }
625                    current_id = next_id;
626                }
627            }
628        }
629    }
630
631    pub fn get_socket_id(&self, socket: &Arc<dyn SocketObject>) -> Option<SocketId> {
632        let socket_ptr = Arc::as_ptr(socket) as *const () as usize;
633        self.socket_to_id.read().get(&socket_ptr).copied()
634    }
635
636    pub fn connection_count(&self) -> usize {
637        self.connections.read().len()
638    }
639
640    pub fn named_socket_count(&self) -> usize {
641        let sockets = self.named_sockets.read();
642        sockets.values().filter(|s| s.upgrade().is_some()).count()
643    }
644}
645
/// Global network manager instance.
///
/// Lazily constructed on the first call to `NetworkManager::get_manager` or
/// `NetworkManager::init`; access it through those functions or
/// `get_network_manager` rather than directly.
static GLOBAL_NETWORK_MANAGER: Once<NetworkManager> = Once::new();
648
/// Get the global network manager.
///
/// Free-function shorthand for `NetworkManager::get_manager()`: constructs the
/// manager on first use but does not initialize the protocol layers.
pub fn get_network_manager() -> &'static NetworkManager {
    NetworkManager::get_manager()
}