kernel/network/
ipv4.rs

1//! IPv4 protocol layer
2//!
3//! This module provides IPv4 packet handling for the network stack.
4//! It implements the NetworkLayer trait for IPv4 encapsulation/decapsulation.
5//!
6//! # Design
7//!
8//! The Ipv4Layer manages:
9//! - Multiple IPv4 addresses per interface (primary + secondary)
10//! - Routing table for destination-based forwarding
11//! - Source IP selection based on routing decisions
12//!
13//! This design supports multiple network interfaces with multiple IP addresses each.
14
15use alloc::collections::BTreeMap;
16use alloc::string::{String, ToString};
17use alloc::vec::Vec;
18use spin::RwLock;
19
20use crate::early_println;
21use crate::network::protocol_stack::{
22    LayerContext, NetworkLayer, NetworkLayerStats, get_network_manager,
23};
24use crate::network::socket::SocketError;
25
/// An IPv4 address held as four octets in network (big-endian) order.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Ipv4Address(pub [u8; 4]);

impl Ipv4Address {
    /// Build an address from its four dotted-quad octets, e.g. `new(10, 0, 0, 1)`.
    pub fn new(a: u8, b: u8, c: u8, d: u8) -> Self {
        Self::from_bytes([a, b, c, d])
    }

    /// Wrap an octet array as-is (first element is the most significant octet).
    pub fn from_bytes(bytes: [u8; 4]) -> Self {
        Ipv4Address(bytes)
    }

    /// Return the four octets in network order.
    pub fn as_bytes(&self) -> [u8; 4] {
        let Ipv4Address(octets) = *self;
        octets
    }

    /// Interpret the octets as a big-endian `u32`.
    pub fn to_u32_be(&self) -> u32 {
        u32::from_be_bytes(self.as_bytes())
    }

    /// Build an address from a `u32` whose most significant byte is the first octet.
    pub fn from_u32_be(addr: u32) -> Self {
        Self::from_bytes(addr.to_be_bytes())
    }

    /// `true` only for the limited broadcast address 255.255.255.255.
    pub fn is_broadcast(&self) -> bool {
        self.to_u32_be() == u32::MAX
    }

    /// `true` for any address in 127.0.0.0/8.
    pub fn is_loopback(&self) -> bool {
        matches!(self.0, [127, ..])
    }

    /// `true` only for the unspecified address 0.0.0.0.
    pub fn is_any(&self) -> bool {
        self.0.iter().all(|&octet| octet == 0)
    }
}
71
72/// IPv4 header (minimum 20 bytes)
73#[derive(Debug, Clone, Copy)]
74#[repr(C, packed)]
75pub struct Ipv4Header {
76    /// Version (4 bits) + IHL (4 bits)
77    pub version_ihl: u8,
78    /// Type of Service
79    pub tos: u8,
80    /// Total Length (16 bits)
81    pub total_length: u16,
82    /// Identification (16 bits)
83    pub identification: u16,
84    /// Flags (3 bits) + Fragment Offset (13 bits)
85    pub flags_fragment: u16,
86    /// Time to Live
87    pub ttl: u8,
88    /// Protocol (8 bits)
89    pub protocol: u8,
90    /// Header Checksum (16 bits)
91    pub checksum: u16,
92    /// Source IP (32 bits)
93    pub source_ip: [u8; 4],
94    /// Destination IP (32 bits)
95    pub dest_ip: [u8; 4],
96}
97
98impl Ipv4Header {
99    /// Create a new IPv4 header
100    pub fn new() -> Self {
101        Self {
102            version_ihl: 0x45, // Version=4, IHL=5 (20 bytes)
103            tos: 0,
104            total_length: 0,
105            identification: 0,
106            flags_fragment: 0,
107            ttl: 64,
108            protocol: 0,
109            checksum: 0,
110            source_ip: [0, 0, 0, 0],
111            dest_ip: [0, 0, 0, 0],
112        }
113    }
114
115    /// Get IP version (always 4)
116    pub fn version(&self) -> u8 {
117        self.version_ihl >> 4
118    }
119
120    /// Get IHL (Internet Header Length) in 32-bit words
121    pub fn ihl(&self) -> u8 {
122        self.version_ihl & 0x0F
123    }
124
125    /// Get header length in bytes
126    pub fn header_length(&self) -> usize {
127        (self.ihl() as usize) * 4
128    }
129
130    /// Calculate checksum
131    pub fn calculate_checksum(&self) -> u16 {
132        let mut bytes = self.to_bytes();
133        if bytes.len() >= 12 {
134            bytes[10] = 0;
135            bytes[11] = 0;
136        }
137        checksum_from_bytes(&bytes)
138    }
139
140    /// Serialize header to bytes
141    pub fn to_bytes(&self) -> Vec<u8> {
142        let mut bytes = Vec::with_capacity(20);
143        bytes.push(self.version_ihl);
144        bytes.push(self.tos);
145        bytes.extend_from_slice(&self.total_length.to_be_bytes());
146        bytes.extend_from_slice(&self.identification.to_be_bytes());
147        bytes.extend_from_slice(&self.flags_fragment.to_be_bytes());
148        bytes.push(self.ttl);
149        bytes.push(self.protocol);
150        bytes.extend_from_slice(&self.checksum.to_be_bytes());
151        bytes.extend_from_slice(&self.source_ip);
152        bytes.extend_from_slice(&self.dest_ip);
153        bytes
154    }
155
156    /// Parse header from bytes
157    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
158        if bytes.len() < 20 {
159            return None;
160        }
161
162        let version_ihl = bytes[0];
163        let version = version_ihl >> 4;
164        if version != 4 {
165            return None;
166        }
167
168        let ihl = version_ihl & 0x0F;
169        let header_len = (ihl as usize) * 4;
170        if bytes.len() < header_len {
171            return None;
172        }
173
174        Some(Self {
175            version_ihl,
176            tos: bytes[1],
177            total_length: u16::from_be_bytes([bytes[2], bytes[3]]),
178            identification: u16::from_be_bytes([bytes[4], bytes[5]]),
179            flags_fragment: u16::from_be_bytes([bytes[6], bytes[7]]),
180            ttl: bytes[8],
181            protocol: bytes[9],
182            checksum: u16::from_be_bytes([bytes[10], bytes[11]]),
183            source_ip: [bytes[12], bytes[13], bytes[14], bytes[15]],
184            dest_ip: [bytes[16], bytes[17], bytes[18], bytes[19]],
185        })
186    }
187}
188
/// Values of the IPv4 header `protocol` field understood by this stack.
pub mod protocol {
    /// Internet Control Message Protocol.
    pub const ICMP: u8 = 0x01;
    /// Transmission Control Protocol.
    pub const TCP: u8 = 0x06;
    /// User Datagram Protocol.
    pub const UDP: u8 = 0x11;
    /// IPv6 packet encapsulated in IPv4.
    pub const IPV6: u8 = 0x29;
}
200
201/// IPv4 address information for an interface
202#[derive(Debug, Clone)]
203pub struct Ipv4AddressInfo {
204    /// The IPv4 address
205    pub address: Ipv4Address,
206    /// Network mask
207    pub netmask: Ipv4Address,
208    /// Broadcast address (optional)
209    pub broadcast: Option<Ipv4Address>,
210    /// Whether this is the primary address for the interface
211    pub is_primary: bool,
212}
213
214/// Routing table entry
215#[derive(Debug, Clone)]
216pub struct RouteEntry {
217    /// Destination network
218    pub destination: Ipv4Address,
219    /// Network mask
220    pub netmask: Ipv4Address,
221    /// Gateway (None for directly connected networks)
222    pub gateway: Option<Ipv4Address>,
223    /// Outgoing interface name
224    pub interface: String,
225    /// Route metric (lower is preferred)
226    pub metric: u32,
227}
228
229/// IPv4 layer
230///
231/// Handles IPv4 packet encapsulation and decapsulation.
232/// Manages multiple addresses per interface and routing table.
233pub struct Ipv4Layer {
234    /// Interface name -> list of IPv4 addresses
235    addresses: RwLock<BTreeMap<String, Vec<Ipv4AddressInfo>>>,
236    /// Routing table (ordered by specificity)
237    routing_table: RwLock<Vec<RouteEntry>>,
238    /// Protocol handlers registered by protocol number
239    protocols: RwLock<BTreeMap<u8, alloc::sync::Arc<dyn NetworkLayer>>>,
240    /// Statistics
241    stats: RwLock<NetworkLayerStats>,
242    /// Default TTL
243    default_ttl: u8,
244}
245
246impl Ipv4Layer {
247    /// Create a new IPv4 layer
248    pub fn new() -> alloc::sync::Arc<Self> {
249        alloc::sync::Arc::new(Self {
250            addresses: RwLock::new(BTreeMap::new()),
251            routing_table: RwLock::new(Vec::new()),
252            protocols: RwLock::new(BTreeMap::new()),
253            stats: RwLock::new(NetworkLayerStats::default()),
254            default_ttl: 64,
255        })
256    }
257
258    /// Initialize and register the IPv4 layer with NetworkManager
259    ///
260    /// Registers with NetworkManager and registers itself with EthernetLayer
261    /// for EtherType 0x0800 (IPv4).
262    ///
263    /// # Panics
264    ///
265    /// Panics if EthernetLayer is not registered (must be initialized first).
266    pub fn init(network_manager: &crate::network::NetworkManager) {
267        let layer = Self::new();
268        network_manager.register_layer("ip", layer.clone());
269
270        // Register with Ethernet layer for IPv4 packets (EtherType 0x0800)
271        let ethernet = network_manager
272            .get_layer("ethernet")
273            .expect("EthernetLayer must be initialized before Ipv4Layer");
274        ethernet.register_protocol(crate::network::ethernet::ether_type::IPV4, layer);
275    }
276
277    /// Add an IPv4 address to an interface
278    pub fn add_address(&self, interface: &str, info: Ipv4AddressInfo) {
279        let mut addrs = self.addresses.write();
280        addrs
281            .entry(interface.to_string())
282            .or_insert_with(Vec::new)
283            .push(info);
284    }
285
286    /// Remove an IPv4 address from an interface
287    pub fn remove_address(&self, interface: &str, ip: Ipv4Address) {
288        let mut addrs = self.addresses.write();
289        if let Some(list) = addrs.get_mut(interface) {
290            list.retain(|a| a.address != ip);
291        }
292    }
293
294    /// Get all addresses for an interface
295    pub fn get_addresses(&self, interface: &str) -> Vec<Ipv4AddressInfo> {
296        self.addresses
297            .read()
298            .get(interface)
299            .cloned()
300            .unwrap_or_default()
301    }
302
303    /// Get primary IP address for an interface
304    pub fn get_primary_ip(&self, interface: &str) -> Option<Ipv4Address> {
305        self.addresses
306            .read()
307            .get(interface)?
308            .iter()
309            .find(|a| a.is_primary)
310            .map(|a| a.address)
311    }
312
313    /// Add a route to the routing table
314    pub fn add_route(&self, entry: RouteEntry) {
315        let mut table = self.routing_table.write();
316        table.push(entry);
317        // Sort by netmask specificity (more specific routes first)
318        table.sort_by(|a, b| {
319            let a_bits = a.netmask.to_u32_be().count_ones();
320            let b_bits = b.netmask.to_u32_be().count_ones();
321            b_bits.cmp(&a_bits).then(a.metric.cmp(&b.metric))
322        });
323    }
324
325    /// Remove a route from the routing table
326    pub fn remove_route(&self, destination: Ipv4Address, netmask: Ipv4Address) {
327        let mut table = self.routing_table.write();
328        table.retain(|r| r.destination != destination || r.netmask != netmask);
329    }
330
331    /// Set default gateway
332    pub fn set_default_gateway(&self, gateway: Ipv4Address, interface: &str) {
333        self.add_route(RouteEntry {
334            destination: Ipv4Address::new(0, 0, 0, 0),
335            netmask: Ipv4Address::new(0, 0, 0, 0),
336            gateway: Some(gateway),
337            interface: interface.to_string(),
338            metric: 100,
339        });
340    }
341
342    /// Select source IP and interface for a destination
343    ///
344    /// Returns (interface_name, source_ip, optional_gateway)
345    pub fn select_source(
346        &self,
347        dest: Ipv4Address,
348    ) -> Option<(String, Ipv4Address, Option<Ipv4Address>)> {
349        let table = self.routing_table.read();
350
351        // Find matching route
352        for route in table.iter() {
353            if self.ip_matches_route(dest, route) {
354                if let Some(src_ip) = self.get_primary_ip(&route.interface) {
355                    return Some((route.interface.clone(), src_ip, route.gateway));
356                }
357            }
358        }
359
360        // Fallback: check if destination is on a directly connected network
361        let addrs = self.addresses.read();
362        for (iface, ips) in addrs.iter() {
363            for ip_info in ips {
364                if self.same_subnet(dest, ip_info.address, ip_info.netmask) {
365                    return Some((iface.clone(), ip_info.address, None));
366                }
367            }
368        }
369
370        // Last resort: use any available primary IP
371        for (iface, ips) in addrs.iter() {
372            if let Some(primary) = ips.iter().find(|a| a.is_primary) {
373                return Some((iface.clone(), primary.address, None));
374            }
375        }
376
377        None
378    }
379
380    /// Check if an IP matches a route
381    fn ip_matches_route(&self, ip: Ipv4Address, route: &RouteEntry) -> bool {
382        self.same_subnet(ip, route.destination, route.netmask)
383    }
384
385    /// Check if two IPs are in the same subnet
386    fn same_subnet(&self, ip1: Ipv4Address, ip2: Ipv4Address, mask: Ipv4Address) -> bool {
387        let ip1_u32 = ip1.to_u32_be();
388        let ip2_u32 = ip2.to_u32_be();
389        let mask_u32 = mask.to_u32_be();
390        (ip1_u32 & mask_u32) == (ip2_u32 & mask_u32)
391    }
392
393    /// Check if an IP is local (assigned to any interface)
394    pub fn is_local_ip(&self, ip: Ipv4Address) -> bool {
395        self.addresses
396            .read()
397            .values()
398            .any(|ips| ips.iter().any(|a| a.address == ip))
399    }
400
401    /// Get protocol handler for a protocol number
402    pub fn get_protocol_handler(
403        &self,
404        proto_num: u8,
405    ) -> Option<alloc::sync::Arc<dyn NetworkLayer>> {
406        self.protocols.read().get(&proto_num).cloned()
407    }
408}
409
410impl NetworkLayer for Ipv4Layer {
411    fn register_protocol(&self, proto_num: u16, handler: alloc::sync::Arc<dyn NetworkLayer>) {
412        self.protocols.write().insert(proto_num as u8, handler);
413    }
414
415    fn send(
416        &self,
417        packet: &[u8],
418        context: &LayerContext,
419        next_layers: &[alloc::sync::Arc<dyn NetworkLayer>],
420    ) -> Result<(), SocketError> {
421        // Get destination IP from context (required)
422        let dest_ip_bytes = context
423            .get("ip_dst")
424            .and_then(|ip| {
425                if ip.len() >= 4 {
426                    Some([ip[0], ip[1], ip[2], ip[3]])
427                } else {
428                    None
429                }
430            })
431            .ok_or(SocketError::InvalidPacket)?;
432        let dest_ip = Ipv4Address::from_bytes(dest_ip_bytes);
433
434        // Get protocol number from context
435        let protocol = context
436            .get("ip_protocol")
437            .and_then(|p| if !p.is_empty() { Some(p[0]) } else { None })
438            .unwrap_or(protocol::TCP);
439
440        // Get source IP from context, or select based on routing
441        let (interface_name, src_ip_bytes, gateway) = if let Some(ip_src) = context.get("ip_src") {
442            if ip_src.len() >= 4 {
443                // Source IP explicitly set - still need to check routing for gateway
444                let iface = context
445                    .get("interface")
446                    .and_then(|b| core::str::from_utf8(b).ok())
447                    .map(String::from)
448                    .or_else(|| {
449                        get_network_manager()
450                            .get_default_interface()
451                            .map(|i| String::from(i.name()))
452                    })
453                    .ok_or(SocketError::NoRoute)?;
454
455                // Look up gateway from routing table for this destination
456                let gateway = self.select_source(dest_ip).and_then(|(_, _, gw)| gw);
457
458                (iface, [ip_src[0], ip_src[1], ip_src[2], ip_src[3]], gateway)
459            } else {
460                return Err(SocketError::InvalidAddress);
461            }
462        } else {
463            // Select source IP based on routing table
464            let (iface, src_ip, gw) = self.select_source(dest_ip).ok_or(SocketError::NoRoute)?;
465            (iface, src_ip.0, gw)
466        };
467
468        // Build IPv4 header
469        let mut header = Ipv4Header::new();
470        header.source_ip = src_ip_bytes;
471        header.dest_ip = dest_ip_bytes;
472        header.protocol = protocol;
473        header.ttl = self.default_ttl;
474
475        // Calculate total length (header + packet)
476        let total_length = (20 + packet.len()) as u16;
477        header.total_length = total_length;
478
479        // Calculate and set checksum
480        header.checksum = header.calculate_checksum();
481
482        // Serialize header
483        let mut ip_packet = header.to_bytes();
484
485        // Create IP packet: header + payload
486        ip_packet.extend_from_slice(packet);
487
488        early_println!(
489            "[IPv4] Send: {} bytes (src: {}.{}.{}.{}, dst: {}.{}.{}.{}, proto: {}, iface: {})",
490            ip_packet.len(),
491            src_ip_bytes[0],
492            src_ip_bytes[1],
493            src_ip_bytes[2],
494            src_ip_bytes[3],
495            dest_ip_bytes[0],
496            dest_ip_bytes[1],
497            dest_ip_bytes[2],
498            dest_ip_bytes[3],
499            protocol,
500            interface_name
501        );
502
503        // Prepare context for Ethernet layer
504        let mut eth_context = context.clone();
505        eth_context.set(
506            "eth_type",
507            &crate::network::ethernet::ether_type::IPV4.to_be_bytes(),
508        );
509        eth_context.set("interface", interface_name.as_bytes());
510        eth_context.set("ip_src", &src_ip_bytes);
511
512        // If we have a gateway, set that as the next-hop for ARP resolution
513        if let Some(gw) = gateway {
514            eth_context.set("next_hop", &gw.0);
515        } else {
516            eth_context.set("next_hop", &dest_ip_bytes);
517        }
518
519        // Forward to Ethernet layer
520        if !next_layers.is_empty() {
521            next_layers[0].send(&ip_packet, &eth_context, &next_layers[1..])?;
522        } else if let Some(eth_layer) = get_network_manager().get_layer("ethernet") {
523            eth_layer.send(&ip_packet, &eth_context, &[])?;
524        }
525
526        // Update statistics
527        let mut stats = self.stats.write();
528        stats.packets_sent += 1;
529        stats.bytes_sent += ip_packet.len() as u64;
530
531        Ok(())
532    }
533
534    fn receive(&self, packet: &[u8], _context: Option<&LayerContext>) -> Result<(), SocketError> {
535        // Parse IPv4 header
536        let header = Ipv4Header::from_bytes(packet).ok_or(SocketError::InvalidPacket)?;
537
538        let header_len = header.header_length();
539        let total_length = usize::from(header.total_length);
540
541        if packet.len() < header_len {
542            return Err(SocketError::InvalidPacket);
543        }
544        if total_length < header_len || total_length > packet.len() {
545            return Err(SocketError::InvalidPacket);
546        }
547
548        early_println!(
549            "[IPv4] RX: total_len={} src={}.{}.{}.{} dst={}.{}.{}.{} proto={}",
550            total_length,
551            header.source_ip[0],
552            header.source_ip[1],
553            header.source_ip[2],
554            header.source_ip[3],
555            header.dest_ip[0],
556            header.dest_ip[1],
557            header.dest_ip[2],
558            header.dest_ip[3],
559            header.protocol
560        );
561
562        // Verify checksum (header.checksum is already in host order)
563        let calculated_checksum = checksum_from_bytes(&packet[..header_len]);
564        let header_checksum = unsafe { core::ptr::addr_of!(header.checksum).read_unaligned() };
565        if calculated_checksum != header_checksum {
566            early_println!(
567                "[IPv4] Checksum mismatch: calculated=0x{:04X}, header=0x{:04X}",
568                calculated_checksum,
569                header_checksum
570            );
571            let mut stats = self.stats.write();
572            stats.protocol_errors += 1;
573            return Err(SocketError::InvalidPacket);
574        }
575
576        let payload = &packet[header_len..total_length];
577
578        early_println!(
579            "[IPv4] Recv: {} bytes (src: {}.{}.{}.{}, dst: {}.{}.{}.{}, proto: {})",
580            packet.len(),
581            header.source_ip[0],
582            header.source_ip[1],
583            header.source_ip[2],
584            header.source_ip[3],
585            header.dest_ip[0],
586            header.dest_ip[1],
587            header.dest_ip[2],
588            header.dest_ip[3],
589            header.protocol
590        );
591
592        // Update statistics
593        let mut stats = self.stats.write();
594        stats.packets_received += 1;
595        stats.bytes_received += total_length as u64;
596
597        // Route to protocol handler based on protocol field
598        let protocols = self.protocols.read();
599        if let Some(handler) = protocols.get(&header.protocol) {
600            let mut proto_context = LayerContext::new();
601            proto_context.set("ip_src", &header.source_ip);
602            proto_context.set("ip_dst", &header.dest_ip);
603            handler.receive(payload, Some(&proto_context))
604        } else {
605            // No handler for this protocol - log and drop
606            Err(SocketError::ProtocolNotSupported)
607        }
608    }
609
610    fn name(&self) -> &'static str {
611        "IPv4"
612    }
613
614    fn stats(&self) -> NetworkLayerStats {
615        self.stats.read().clone()
616    }
617
618    fn as_any(&self) -> &dyn core::any::Any {
619        self
620    }
621}
622
/// Compute the 16-bit one's-complement checksum of an IPv4 header.
///
/// The bytes are summed as big-endian 16-bit words, with an odd trailing byte padded with
/// zero. The word at bytes 10..12 (the header's checksum field) is skipped, but only when
/// both of its bytes are present. The carry-folded sum is then complemented.
///
/// Returns the value to store in the checksum field. On receive, it equals the stored field
/// when the header is intact.
fn checksum_from_bytes(header_bytes: &[u8]) -> u16 {
    // Word index 5 covers bytes 10..12, the IPv4 header checksum field.
    const CHECKSUM_WORD: usize = 5;

    let mut folded: u32 = header_bytes
        .chunks(2)
        .enumerate()
        .filter(|&(index, chunk)| !(index == CHECKSUM_WORD && chunk.len() == 2))
        .map(|(_, chunk)| {
            let high = chunk[0];
            let low = chunk.get(1).copied().unwrap_or(0);
            u32::from(u16::from_be_bytes([high, low]))
        })
        .sum();

    // Fold carries back into the low 16 bits (end-around carry).
    while folded > 0xFFFF {
        folded = (folded & 0xFFFF) + (folded >> 16);
    }

    !(folded as u16)
}
648
#[cfg(test)]
mod tests {
    use alloc::string::ToString;
    use alloc::vec;

    use super::*;

    #[test_case]
    fn test_ipv4_address() {
        let addr = Ipv4Address::new(192, 168, 1, 100);
        assert_eq!(addr.as_bytes(), [192, 168, 1, 100]);
        assert!(!addr.is_broadcast());
        assert!(!addr.is_loopback());
        assert!(!addr.is_any());

        let broadcast = Ipv4Address::new(255, 255, 255, 255);
        assert!(broadcast.is_broadcast());

        let loopback = Ipv4Address::new(127, 0, 0, 1);
        assert!(loopback.is_loopback());

        let any = Ipv4Address::new(0, 0, 0, 0);
        assert!(any.is_any());
    }

    #[test_case]
    fn test_ipv4_address_u32_conversion() {
        let addr = Ipv4Address::new(192, 168, 1, 100);
        assert_eq!(addr.to_u32_be(), u32::from_be_bytes([192, 168, 1, 100]));

        let from_u32 = Ipv4Address::from_u32_be(0xC0A80164u32);
        assert_eq!(from_u32, addr);
    }

    #[test_case]
    fn test_ipv4_header_creation() {
        let mut header = Ipv4Header::new();
        header.source_ip = [192, 168, 1, 100];
        header.dest_ip = [192, 168, 1, 1];
        header.protocol = protocol::TCP;
        header.total_length = (20 + 10) as u16;

        assert_eq!(header.version(), 4);
        assert_eq!(header.ihl(), 5);
        assert_eq!(header.header_length(), 20);
        assert_eq!(header.protocol, protocol::TCP);
    }

    #[test_case]
    fn test_ipv4_header_serialization() {
        let mut header = Ipv4Header::new();
        header.source_ip = [192, 168, 1, 100];
        header.dest_ip = [192, 168, 1, 1];
        header.protocol = protocol::TCP;
        header.total_length = 30;

        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), 20);
        assert_eq!(bytes[0], 0x45); // Version=4, IHL=5
        assert_eq!(u16::from_be_bytes([bytes[2], bytes[3]]), 30);
        assert_eq!(&bytes[12..16], [192, 168, 1, 100]);
        assert_eq!(&bytes[16..20], [192, 168, 1, 1]);
    }

    #[test_case]
    fn test_ipv4_header_parsing() {
        let bytes = vec![
            0x45, // Version=4, IHL=5
            0x00, // TOS
            0x00, 0x1E, // Total length = 30
            0x00, 0x01, // Identification
            0x00, 0x00, // Flags+Fragment
            0x40, // TTL = 64
            0x06, // Protocol = TCP
            0x00, 0x00, // Checksum (placeholder)
            0xC0, 0xA8, 0x01, 0x64, // Source IP = 192.168.1.100
            0xC0, 0xA8, 0x01, 0x01, // Dest IP = 192.168.1.1
        ];

        let header = Ipv4Header::from_bytes(&bytes).unwrap();
        assert_eq!(header.version(), 4);
        assert_eq!(header.ihl(), 5);
        // Copy the packed u16 out by value; `assert_eq!` would otherwise take a reference to
        // an unaligned field.
        let total_length = header.total_length;
        assert_eq!(total_length, 30);
        assert_eq!(header.protocol, protocol::TCP);
        assert_eq!(header.source_ip, [192, 168, 1, 100]);
        assert_eq!(header.dest_ip, [192, 168, 1, 1]);
        assert_eq!(header.ttl, 64);
    }

    #[test_case]
    fn test_ipv4_header_invalid_version() {
        let bytes = alloc::vec![0x55u8; 20]; // Invalid version (5)
        assert!(Ipv4Header::from_bytes(&bytes).is_none());
    }

    #[test_case]
    fn test_ipv4_header_too_short() {
        let bytes = [0u8; 10];
        assert!(Ipv4Header::from_bytes(&bytes).is_none());
    }

    #[test_case]
    fn test_ipv4_layer_creation() {
        let ip_layer = Ipv4Layer::new();
        // New layer has no addresses
        assert!(ip_layer.get_addresses("eth0").is_empty());
    }

    #[test_case]
    fn test_ipv4_layer_add_address() {
        let ip_layer = Ipv4Layer::new();

        let ip = Ipv4Address::new(192, 168, 1, 100);
        ip_layer.add_address(
            "eth0",
            Ipv4AddressInfo {
                address: ip,
                netmask: Ipv4Address::new(255, 255, 255, 0),
                broadcast: Some(Ipv4Address::new(192, 168, 1, 255)),
                is_primary: true,
            },
        );

        let addrs = ip_layer.get_addresses("eth0");
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0].address, ip);
        assert!(addrs[0].is_primary);
    }

    #[test_case]
    fn test_ipv4_layer_multiple_addresses() {
        let ip_layer = Ipv4Layer::new();

        // Add primary address
        ip_layer.add_address(
            "eth0",
            Ipv4AddressInfo {
                address: Ipv4Address::new(192, 168, 1, 100),
                netmask: Ipv4Address::new(255, 255, 255, 0),
                broadcast: None,
                is_primary: true,
            },
        );

        // Add secondary address
        ip_layer.add_address(
            "eth0",
            Ipv4AddressInfo {
                address: Ipv4Address::new(192, 168, 1, 101),
                netmask: Ipv4Address::new(255, 255, 255, 0),
                broadcast: None,
                is_primary: false,
            },
        );

        let addrs = ip_layer.get_addresses("eth0");
        assert_eq!(addrs.len(), 2);
        assert_eq!(
            ip_layer.get_primary_ip("eth0"),
            Some(Ipv4Address::new(192, 168, 1, 100))
        );
    }

    #[test_case]
    fn test_ipv4_layer_routing() {
        let ip_layer = Ipv4Layer::new();

        // Add address to eth0
        ip_layer.add_address(
            "eth0",
            Ipv4AddressInfo {
                address: Ipv4Address::new(192, 168, 1, 100),
                netmask: Ipv4Address::new(255, 255, 255, 0),
                broadcast: None,
                is_primary: true,
            },
        );

        // Add route for local subnet
        ip_layer.add_route(RouteEntry {
            destination: Ipv4Address::new(192, 168, 1, 0),
            netmask: Ipv4Address::new(255, 255, 255, 0),
            gateway: None,
            interface: "eth0".to_string(),
            metric: 0,
        });

        // Add default route
        ip_layer.set_default_gateway(Ipv4Address::new(192, 168, 1, 1), "eth0");

        // Test routing to local subnet - should use direct route
        let result = ip_layer.select_source(Ipv4Address::new(192, 168, 1, 50));
        assert!(result.is_some());
        let (iface, src_ip, gw) = result.unwrap();
        assert_eq!(iface, "eth0");
        assert_eq!(src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert!(gw.is_none()); // Direct route, no gateway

        // Test routing to external address - should use default gateway
        let result = ip_layer.select_source(Ipv4Address::new(8, 8, 8, 8));
        assert!(result.is_some());
        let (iface, src_ip, gw) = result.unwrap();
        assert_eq!(iface, "eth0");
        assert_eq!(src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert_eq!(gw, Some(Ipv4Address::new(192, 168, 1, 1)));
    }

    #[test_case]
    fn test_ipv4_is_local_ip() {
        let ip_layer = Ipv4Layer::new();

        ip_layer.add_address(
            "eth0",
            Ipv4AddressInfo {
                address: Ipv4Address::new(192, 168, 1, 100),
                netmask: Ipv4Address::new(255, 255, 255, 0),
                broadcast: None,
                is_primary: true,
            },
        );

        assert!(ip_layer.is_local_ip(Ipv4Address::new(192, 168, 1, 100)));
        assert!(!ip_layer.is_local_ip(Ipv4Address::new(192, 168, 1, 101)));
    }

    #[test_case]
    fn test_ipv4_checksum() {
        let mut header = Ipv4Header::new();
        header.source_ip = [192, 168, 1, 100];
        header.dest_ip = [192, 168, 1, 1];
        header.protocol = protocol::TCP;
        header.ttl = 64;
        header.total_length = 20;
        header.identification = 0;
        header.flags_fragment = 0;
        header.tos = 0;

        let checksum = header.calculate_checksum();
        assert_ne!(checksum, 0);

        // Round trip: once stored, the checksum must verify the serialized header both with
        // the receive-side helper and with the standard check that all header words,
        // including the checksum, sum to 0xFFFF in one's-complement arithmetic.
        header.checksum = checksum;
        let bytes = header.to_bytes();
        assert_eq!(checksum_from_bytes(&bytes), checksum);

        let mut sum: u32 = bytes
            .chunks(2)
            .map(|word| u32::from(u16::from_be_bytes([word[0], word[1]])))
            .sum();
        while sum > 0xFFFF {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }
        assert_eq!(sum, 0xFFFF);
    }

    #[test_case]
    fn test_ipv4_checksum_known_vector() {
        let header = Ipv4Header {
            version_ihl: 0x45,
            tos: 0x00,
            total_length: 0x003C,
            identification: 0x1C46,
            flags_fragment: 0x4000,
            ttl: 0x40,
            protocol: 0x06,
            checksum: 0x0000,
            source_ip: [192, 168, 0, 1],
            dest_ip: [192, 168, 0, 199],
        };

        let checksum = header.calculate_checksum();
        assert_eq!(checksum, 0x9C5D);
    }

    #[test_case]
    fn test_protocol_constants() {
        assert_eq!(protocol::ICMP, 1);
        assert_eq!(protocol::TCP, 6);
        assert_eq!(protocol::UDP, 17);
        assert_eq!(protocol::IPV6, 41);
    }
}