1use alloc::{
25 collections::BTreeMap,
26 string::{String, ToString},
27 sync::{Arc, Weak},
28 vec::Vec,
29};
30use core::sync::atomic::{AtomicUsize, Ordering};
31use spin::Once;
32
33pub mod arp;
34pub mod config;
35pub mod ethernet;
36pub mod ethernet_interface;
37pub mod icmp;
38pub mod ipv4;
39pub mod local;
40pub mod protocol_stack;
41pub mod socket;
42pub mod syscall;
43pub mod tcp;
44pub mod udp;
45#[cfg(all(test, target_arch = "riscv64"))]
46pub mod virtio_net_tests;
47
48pub use protocol_stack::{
50 LayerContext, NetworkLayer, NetworkLayerStats, ProtocolStack, ProtocolStackManager,
51 ProtocolStackStats, SocketConfig,
52};
53pub use socket::{
54 Inet4SocketAddress,
55 Inet6SocketAddress,
56 LocalSocketAddress,
57 ShutdownHow,
58 SocketAddress,
59 SocketControl,
60 SocketDomain,
61 SocketError,
62 SocketObject,
63 SocketProtocol,
64 SocketState,
65 SocketType,
66 UnixSocketAddress, };
68
69use crate::device::network::{DevicePacket, MacAddress};
70use crate::early_println;
71use crate::network::arp::ArpCacheEntry;
72use crate::network::ipv4::Ipv4Address;
73use crate::object::KernelObject;
74
/// Identifier under which the [`NetworkManager`] tracks a registered socket.
///
/// The manager's ID counter starts at 1; explicit IDs can also be supplied through
/// `NetworkManager::register_socket_with_id`.
pub type SocketId = usize;

/// Constructor for sockets of one [`SocketDomain`].
///
/// Registered with `NetworkManager::register_socket_factory`; `NetworkManager::create_socket`
/// consults the factory for a domain before falling back to a registered protocol stack.
pub type SocketFactory =
    fn(SocketType, SocketProtocol) -> Result<Arc<dyn SocketObject>, SocketError>;
83
/// Traffic counters returned by [`NetworkInterface::stats`].
///
/// What exactly each counter includes (e.g. whether it is cumulative since device start)
/// is up to the implementing driver — NOTE(review): confirm against the drivers.
#[derive(Debug, Clone, Default)]
pub struct InterfaceStats {
    /// Transmitted packet count.
    pub tx_packets: u64,
    /// Transmitted byte count.
    pub tx_bytes: u64,
    /// Received packet count.
    pub rx_packets: u64,
    /// Received byte count.
    pub rx_bytes: u64,
    /// Dropped packet count.
    pub drops: u64,
    /// Error count.
    pub errors: u64,
}
94
/// Global IPv4 settings held by the [`NetworkManager`].
///
/// Read with `NetworkManager::get_config` (which returns a copy) and replaced wholesale with
/// `NetworkManager::set_config`.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Default gateway; set by `NetworkManager::set_default_gateway`.
    pub default_gateway: Option<Ipv4Address>,
    /// Hardware address of the gateway, if known. `NetworkManager::set_default_gateway`
    /// resets it to `None` whenever the gateway changes.
    pub gateway_mac: Option<MacAddress>,
    /// DNS server address, if configured.
    pub dns_server: Option<Ipv4Address>,
    /// Subnet mask; 255.255.255.0 in the `Default` configuration.
    pub subnet_mask: Ipv4Address,
}
103
104impl Default for NetworkConfig {
105 fn default() -> Self {
106 Self {
107 default_gateway: None,
108 gateway_mac: None,
109 dns_server: None,
110 subnet_mask: Ipv4Address::new(255, 255, 255, 0),
111 }
112 }
113}
114
/// A network device that the [`NetworkManager`] can send frames through and poll.
///
/// Implementations are stored as `Arc<dyn NetworkInterface>` inside the global manager, hence
/// the `Send + Sync` bound. Every method takes `&self`, so any mutable state (such as the
/// address assigned through `set_ip_address`) needs interior mutability.
pub trait NetworkInterface: Send + Sync {
    /// Interface name. The manager passes it to the IPv4 layer when registering the
    /// interface's address and when setting the default gateway.
    fn name(&self) -> &str;
    /// Hardware address; the manager uses it as the source MAC of outgoing ARP requests.
    fn mac_address(&self) -> MacAddress;
    /// Currently assigned IPv4 address, or `None` if the interface has none.
    fn ip_address(&self) -> Option<Ipv4Address>;
    /// Sets the interface's IPv4 address (exact semantics are up to the implementation).
    fn set_ip_address(&self, ip: Ipv4Address);
    /// Transmits one packet. The manager hands it complete Ethernet frames (header included).
    fn send(&self, packet: DevicePacket) -> Result<(), &'static str>;
    /// Returns packets waiting on the device — presumably frames received since the previous
    /// poll; NOTE(review): confirm against the driver implementations.
    fn poll(&self) -> Result<Vec<DevicePacket>, &'static str>;
    /// Current traffic counters, returned by value.
    fn stats(&self) -> InterfaceStats;
}
125
/// Central registry of the kernel network subsystem.
///
/// Tracks network interfaces, protocol layers and stacks, socket factories, registered
/// sockets, the ARP cache, and the global [`NetworkConfig`]. Each table sits behind its own
/// `spin::RwLock`. The global instance lives in `GLOBAL_NETWORK_MANAGER`.
pub struct NetworkManager {
    /// Per-domain socket constructors; `create_socket` tries these before protocol stacks.
    socket_factories: spin::RwLock<BTreeMap<SocketDomain, SocketFactory>>,
    /// Protocol stacks; `create_socket` falls back to these when no factory matches.
    protocol_stacks: protocol_stack::ProtocolStackManager,
    /// Protocol layers keyed by name (the receive path looks up `"ip"` and `"arp"`).
    protocol_layers: spin::RwLock<BTreeMap<String, Arc<dyn NetworkLayer>>>,
    /// Named sockets, held weakly so that registration alone does not keep a socket alive.
    named_sockets: spin::RwLock<BTreeMap<String, Weak<dyn SocketObject>>>,
    /// Registered sockets keyed by ID; these strong references keep the sockets alive.
    connections: spin::RwLock<BTreeMap<SocketId, Arc<dyn SocketObject>>>,
    /// Reverse map from a socket's data address (`Arc::as_ptr` cast to `usize`) to its ID.
    socket_to_id: spin::RwLock<BTreeMap<usize, SocketId>>,
    /// Next candidate socket ID; starts at 1.
    next_socket_id: AtomicUsize,
    /// Registered interfaces keyed by the name passed to `register_interface`.
    interfaces: spin::RwLock<BTreeMap<String, Arc<dyn NetworkInterface>>>,
    /// Name of the default interface. `set_default_interface` does not validate it, so it
    /// may name an interface that is not registered.
    default_interface: spin::RwLock<Option<String>>,
    /// ARP cache keyed by the IPv4 address interpreted as a big-endian `u32`.
    arp_cache: spin::RwLock<BTreeMap<u32, ArpCacheEntry>>,
    /// Gateway, DNS, and subnet settings.
    network_config: spin::RwLock<NetworkConfig>,
}
161
162impl NetworkManager {
163 fn new() -> Self {
165 Self {
166 socket_factories: spin::RwLock::new(BTreeMap::new()),
167 protocol_stacks: protocol_stack::ProtocolStackManager::new(),
168 protocol_layers: spin::RwLock::new(BTreeMap::new()),
169 named_sockets: spin::RwLock::new(BTreeMap::new()),
170 connections: spin::RwLock::new(BTreeMap::new()),
171 socket_to_id: spin::RwLock::new(BTreeMap::new()),
172 next_socket_id: AtomicUsize::new(1),
173 interfaces: spin::RwLock::new(BTreeMap::new()),
174 default_interface: spin::RwLock::new(None),
175 arp_cache: spin::RwLock::new(BTreeMap::new()),
176 network_config: spin::RwLock::new(NetworkConfig::default()),
177 }
178 }
179
180 pub fn get_manager() -> &'static NetworkManager {
182 GLOBAL_NETWORK_MANAGER.call_once(|| NetworkManager::new())
183 }
184
185 pub fn init() -> &'static NetworkManager {
192 let manager = GLOBAL_NETWORK_MANAGER.call_once(|| NetworkManager::new());
193
194 crate::network::ethernet::EthernetLayer::init(manager);
197
198 crate::network::ipv4::Ipv4Layer::init(manager);
200 crate::network::arp::ArpLayer::init(manager);
201
202 crate::network::icmp::IcmpLayer::init(manager);
204 crate::network::udp::UdpLayer::init(manager);
205 crate::network::tcp::TcpLayer::init(manager);
206
207 manager
208 }
209
210 pub fn register_interface(
215 &self,
216 name: &str,
217 interface: Arc<dyn NetworkInterface>,
218 ) -> Result<(), &'static str> {
219 let mut default = self.default_interface.write();
220 if default.is_none() {
221 *default = Some(String::from(name));
222 }
223
224 let interface_clone = interface.clone();
225 self.interfaces
226 .write()
227 .insert(String::from(name), interface);
228
229 if self.interfaces.read().len() == 1 {
231 self.configure_protocol_layers_with_interface(interface_clone);
232 }
233
234 Ok(())
235 }
236
237 pub fn get_interface(&self, name: &str) -> Option<Arc<dyn NetworkInterface>> {
238 self.interfaces.read().get(name).cloned()
239 }
240
241 pub fn get_default_interface(&self) -> Option<Arc<dyn NetworkInterface>> {
242 self.default_interface
243 .read()
244 .as_ref()
245 .and_then(|name| self.get_interface(name))
246 }
247
248 pub fn set_default_interface(&self, name: &str) {
249 *self.default_interface.write() = Some(String::from(name));
250 }
251
252 pub fn list_interfaces(&self) -> Vec<String> {
253 self.interfaces.read().keys().cloned().collect()
254 }
255
256 fn configure_protocol_layers_with_interface(&self, interface: Arc<dyn NetworkInterface>) {
261 let local_ip = interface.ip_address();
262 let interface_name = interface.name();
263
264 if let Some(ip_layer) = self.get_layer("ip") {
266 if let Some(ip) = ip_layer
267 .as_any()
268 .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
269 {
270 if let Some(local_ip_addr) = local_ip {
271 ip.add_address(
272 interface_name,
273 crate::network::ipv4::Ipv4AddressInfo {
274 address: local_ip_addr,
275 netmask: crate::network::ipv4::Ipv4Address::new(255, 255, 255, 0),
276 broadcast: None,
277 is_primary: true,
278 },
279 );
280 }
281 }
282 }
283 }
284
285 pub fn arp_lookup(&self, ip: &Ipv4Address) -> Option<MacAddress> {
290 let ip_u32 = u32::from_be_bytes(ip.as_bytes());
291 let cache = self.arp_cache.read();
292 cache.get(&ip_u32).and_then(|entry| {
293 if entry.is_expired() {
294 None
295 } else {
296 Some(MacAddress::new(entry.mac_address))
297 }
298 })
299 }
300
301 pub fn arp_cache_add(&self, ip: Ipv4Address, mac: MacAddress) {
302 let ip_u32 = u32::from_be_bytes(ip.as_bytes());
303 let mut cache = self.arp_cache.write();
304 cache.insert(ip_u32, ArpCacheEntry::new(ip, *mac.as_bytes()));
305 }
306
307 pub fn send_arp_request(&self, target_ip: Ipv4Address) -> Result<(), &'static str> {
308 let interface = self.get_default_interface().ok_or("No default interface")?;
309
310 let local_ip = interface.ip_address().ok_or("Interface has no IP")?;
311 let local_mac = interface.mac_address();
312
313 let arp_request =
314 crate::network::arp::ArpPacket::request(local_ip.as_bytes(), target_ip.as_bytes());
315
316 let eth_header = crate::network::ethernet::EthernetHeader::new(
317 [0xFF; 6],
318 *local_mac.as_bytes(),
319 crate::network::ethernet::ether_type::ARP,
320 );
321
322 let mut packet_data = Vec::new();
323 packet_data.extend_from_slice(ð_header.to_bytes());
324 packet_data.extend_from_slice(&arp_request.to_bytes());
325
326 let packet = DevicePacket::with_data(packet_data);
327 interface.send(packet)
328 }
329
330 pub fn resolve_mac(&self, ip: Ipv4Address) -> Result<MacAddress, &'static str> {
331 if let Some(mac) = self.arp_lookup(&ip) {
332 return Ok(mac);
333 }
334 self.send_arp_request(ip)?;
335 Err("MAC not in cache, ARP request sent")
336 }
337
338 pub fn get_config(&self) -> NetworkConfig {
343 self.network_config.read().clone()
344 }
345
346 pub fn set_config(&self, config: NetworkConfig) {
347 *self.network_config.write() = config;
348 }
349
350 pub fn set_default_gateway(&self, gateway: Ipv4Address) {
351 self.network_config.write().default_gateway = Some(gateway);
352 self.network_config.write().gateway_mac = None;
353
354 if let Some(default_iface) = self.get_default_interface() {
356 if let Some(ip_layer) = self.get_layer("ip") {
357 if let Some(ipv4) = ip_layer
358 .as_any()
359 .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
360 {
361 ipv4.set_default_gateway(gateway, default_iface.name());
362 }
363 }
364 }
365 }
366
367 pub fn get_default_gateway(&self) -> Option<Ipv4Address> {
368 self.network_config.read().default_gateway
369 }
370
371 pub fn handle_received_packet(&self, _interface_name: &str, packet: &DevicePacket) {
372 if packet.len < 14 {
373 return;
374 }
375
376 let eth_type = u16::from_be_bytes([packet.data[12], packet.data[13]]);
377 early_println!(
378 "[net] recv frame len={} eth_type=0x{:04X}",
379 packet.len,
380 eth_type
381 );
382 match eth_type {
383 0x0806 => self.handle_arp_packet(packet),
384 0x0800 => self.handle_ipv4_packet(packet),
385 _ => {}
386 }
387 }
388
389 fn handle_arp_packet(&self, packet: &DevicePacket) {
390 if packet.len < 14 + 28 {
391 return;
392 }
393
394 let arp_data = &packet.data[14..];
395 if let Some(arp_layer) = self.get_layer("arp") {
396 if let Some(arp) = arp_layer
397 .as_any()
398 .downcast_ref::<crate::network::arp::ArpLayer>()
399 {
400 let _ = arp.receive_packet(arp_data);
401 }
402 }
403 }
404
405 fn handle_ipv4_packet(&self, packet: &DevicePacket) {
406 if packet.len < 14 + 20 {
407 return;
408 }
409
410 let ip_bytes = &packet.data[14..packet.len];
411 let header = match crate::network::ipv4::Ipv4Header::from_bytes(ip_bytes) {
412 Some(h) => h,
413 None => return,
414 };
415
416 early_println!(
417 "[IPv4] Recv frame: ip_len={} src={}.{}.{}.{} dst={}.{}.{}.{} proto={}",
418 ip_bytes.len(),
419 header.source_ip[0],
420 header.source_ip[1],
421 header.source_ip[2],
422 header.source_ip[3],
423 header.dest_ip[0],
424 header.dest_ip[1],
425 header.dest_ip[2],
426 header.dest_ip[3],
427 header.protocol
428 );
429
430 let header_len = header.header_length();
431 let total_length = usize::from(header.total_length);
432 if ip_bytes.len() < header_len {
433 return;
434 }
435 if total_length < header_len || total_length > ip_bytes.len() {
436 return;
437 }
438
439 let payload = &ip_bytes[header_len..total_length];
440 let protocol = header.protocol;
441
442 if let Some(ip_layer) = self.get_layer("ip") {
443 if let Some(ip) = ip_layer
444 .as_any()
445 .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
446 {
447 if let Some(handler) = ip.get_protocol_handler(protocol) {
448 let src_ip = crate::network::ipv4::Ipv4Address::from_bytes(header.source_ip);
449 let dst_ip = crate::network::ipv4::Ipv4Address::from_bytes(header.dest_ip);
450 let _ = match protocol {
451 crate::network::ipv4::protocol::ICMP => handler
452 .as_any()
453 .downcast_ref::<crate::network::icmp::IcmpLayer>()
454 .map(|icmp| icmp.receive_packet(payload, src_ip, dst_ip)),
455 crate::network::ipv4::protocol::TCP => handler
456 .as_any()
457 .downcast_ref::<crate::network::tcp::TcpLayer>()
458 .map(|tcp| tcp.receive_packet(src_ip, dst_ip, payload)),
459 crate::network::ipv4::protocol::UDP => handler
460 .as_any()
461 .downcast_ref::<crate::network::udp::UdpLayer>()
462 .map(|udp| udp.receive_packet(src_ip, dst_ip, payload)),
463 _ => Some(handler.receive(payload, None)),
464 };
465 }
466 }
467 }
468 }
469
470 pub fn register_socket_factory(&self, domain: SocketDomain, factory: SocketFactory) {
475 self.socket_factories.write().insert(domain, factory);
476 }
477
478 pub fn create_socket(
479 &self,
480 domain: SocketDomain,
481 socket_type: SocketType,
482 protocol: SocketProtocol,
483 ) -> Result<KernelObject, SocketError> {
484 let factories = self.socket_factories.read();
485 if let Some(factory) = factories.get(&domain) {
486 let socket = factory(socket_type, protocol)?;
487 let socket_id = self.next_socket_id.fetch_add(1, Ordering::SeqCst);
488 self.connections.write().insert(socket_id, socket.clone());
489 return Ok(KernelObject::Socket(socket));
490 }
491 drop(factories);
492
493 if let Some(stack) = self.protocol_stacks.get_stack(domain) {
494 let socket = stack.create_socket(socket_type, protocol)?;
495 let socket_id = self.next_socket_id.fetch_add(1, Ordering::SeqCst);
496 self.connections.write().insert(socket_id, socket.clone());
497 return Ok(KernelObject::Socket(socket));
498 }
499
500 Err(SocketError::NotSupported)
501 }
502
503 pub fn register_protocol_stack(&self, stack: Arc<dyn protocol_stack::ProtocolStack>) {
504 self.protocol_stacks.register_stack(stack);
505 }
506
507 pub fn register_layer(&self, name: &str, layer: Arc<dyn NetworkLayer>) {
508 self.protocol_layers.write().insert(name.to_string(), layer);
509 }
510
511 pub fn unregister_layer(&self, name: &str) -> Option<Arc<dyn NetworkLayer>> {
512 self.protocol_layers.write().remove(name)
513 }
514
515 pub fn get_layer(&self, name: &str) -> Option<Arc<dyn NetworkLayer>> {
516 self.protocol_layers.read().get(name).cloned()
517 }
518
519 pub fn list_layers(&self) -> Vec<String> {
520 self.protocol_layers.read().keys().cloned().collect()
521 }
522
523 pub fn layer_count(&self) -> usize {
524 self.protocol_layers.read().len()
525 }
526
527 pub fn has_layer(&self, name: &str) -> bool {
528 self.protocol_layers.read().contains_key(name)
529 }
530
531 pub fn register_named_socket(
532 &self,
533 name: &str,
534 socket: Arc<dyn SocketObject>,
535 ) -> Result<(), SocketError> {
536 let mut sockets = self.named_sockets.write();
537 if let Some(weak_socket) = sockets.get(name) {
538 if weak_socket.upgrade().is_some() {
539 return Err(SocketError::AddressInUse);
540 }
541 }
542 sockets.insert(name.into(), Arc::downgrade(&socket));
543 Ok(())
544 }
545
546 pub fn lookup_named_socket(&self, name: &str) -> Result<Arc<dyn SocketObject>, SocketError> {
547 let sockets = self.named_sockets.read();
548 match sockets.get(name) {
549 Some(weak_socket) => weak_socket.upgrade().ok_or(SocketError::ConnectionRefused),
550 None => Err(SocketError::ConnectionRefused),
551 }
552 }
553
554 pub fn unregister_named_socket(&self, name: &str) {
555 self.named_sockets.write().remove(name);
556 }
557
558 pub fn get_socket(&self, socket_id: SocketId) -> Option<Arc<dyn SocketObject>> {
559 self.connections.read().get(&socket_id).cloned()
560 }
561
562 pub fn register_socket_with_id(
563 &self,
564 socket_id: SocketId,
565 socket: Arc<dyn SocketObject>,
566 ) -> Result<(), SocketError> {
567 let mut connections = self.connections.write();
568 if connections.contains_key(&socket_id) {
569 return Err(SocketError::AddressInUse);
570 }
571 let socket_ptr = Arc::as_ptr(&socket) as *const () as usize;
572 connections.insert(socket_id, socket);
573 drop(connections);
574 self.socket_to_id.write().insert(socket_ptr, socket_id);
575 Ok(())
576 }
577
578 pub fn remove_socket(&self, socket_id: SocketId) {
579 let mut connections = self.connections.write();
580 if let Some(socket) = connections.get(&socket_id) {
581 let socket_ptr = Arc::as_ptr(socket) as *const () as usize;
582 drop(connections);
583 self.connections.write().remove(&socket_id);
584 self.socket_to_id.write().remove(&socket_ptr);
585 }
586 }
587
588 pub fn allocate_socket_id(
589 &self,
590 socket: Arc<dyn SocketObject>,
591 ) -> Result<SocketId, SocketError> {
592 let socket = socket;
593 let start_id = self.next_socket_id.load(Ordering::SeqCst);
594 let mut current_id = start_id;
595 loop {
596 match self.register_socket_with_id(current_id, Arc::clone(&socket)) {
597 Ok(()) => {
598 let mut observed = self.next_socket_id.load(Ordering::SeqCst);
599 loop {
600 if observed > current_id {
601 break;
602 }
603 match self.next_socket_id.compare_exchange(
604 observed,
605 current_id.wrapping_add(1),
606 Ordering::SeqCst,
607 Ordering::SeqCst,
608 ) {
609 Ok(_) => break,
610 Err(actual) => {
611 if actual > current_id {
612 break;
613 }
614 observed = actual;
615 }
616 }
617 }
618 return Ok(current_id);
619 }
620 Err(e) => {
621 let next_id = current_id.wrapping_add(1);
622 if next_id == start_id {
623 return Err(e);
624 }
625 current_id = next_id;
626 }
627 }
628 }
629 }
630
631 pub fn get_socket_id(&self, socket: &Arc<dyn SocketObject>) -> Option<SocketId> {
632 let socket_ptr = Arc::as_ptr(socket) as *const () as usize;
633 self.socket_to_id.read().get(&socket_ptr).copied()
634 }
635
636 pub fn connection_count(&self) -> usize {
637 self.connections.read().len()
638 }
639
640 pub fn named_socket_count(&self) -> usize {
641 let sockets = self.named_sockets.read();
642 sockets.values().filter(|s| s.upgrade().is_some()).count()
643 }
644}
645
/// Storage for the global [`NetworkManager`]; filled on the first call to
/// `NetworkManager::get_manager` or `NetworkManager::init`.
static GLOBAL_NETWORK_MANAGER: Once<NetworkManager> = Once::new();

/// Free-function alias for `NetworkManager::get_manager`.
///
/// Creates the manager on first use but does not initialize any protocol layers; that is
/// done by `NetworkManager::init`.
pub fn get_network_manager() -> &'static NetworkManager {
    NetworkManager::get_manager()
}