1use alloc::collections::BTreeMap;
16use alloc::string::{String, ToString};
17use alloc::vec::Vec;
18use spin::RwLock;
19
20use crate::early_println;
21use crate::network::protocol_stack::{
22 LayerContext, NetworkLayer, NetworkLayerStats, get_network_manager,
23};
24use crate::network::socket::SocketError;
25
/// An IPv4 address stored as four octets in network (big-endian) order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Ipv4Address(pub [u8; 4]);

impl Ipv4Address {
    /// Builds an address from dotted-quad components `a.b.c.d`.
    pub fn new(a: u8, b: u8, c: u8, d: u8) -> Self {
        Self::from_bytes([a, b, c, d])
    }

    /// Wraps four octets, first octet most significant.
    pub fn from_bytes(bytes: [u8; 4]) -> Self {
        Ipv4Address(bytes)
    }

    /// Returns a copy of the four octets.
    pub fn as_bytes(&self) -> [u8; 4] {
        let Ipv4Address(octets) = *self;
        octets
    }

    /// Interprets the octets as a big-endian `u32` (e.g. `192.168.1.100` → `0xC0A80164`).
    pub fn to_u32_be(&self) -> u32 {
        u32::from_be_bytes(self.as_bytes())
    }

    /// Inverse of [`Ipv4Address::to_u32_be`].
    pub fn from_u32_be(addr: u32) -> Self {
        Self::from_bytes(addr.to_be_bytes())
    }

    /// True only for the limited broadcast address `255.255.255.255`.
    pub fn is_broadcast(&self) -> bool {
        self.to_u32_be() == u32::MAX
    }

    /// True for any address in `127.0.0.0/8`.
    pub fn is_loopback(&self) -> bool {
        matches!(self.0, [127, _, _, _])
    }

    /// True for the unspecified address `0.0.0.0`.
    pub fn is_any(&self) -> bool {
        self.0.iter().all(|&octet| octet == 0)
    }
}
71
72#[derive(Debug, Clone, Copy)]
74#[repr(C, packed)]
75pub struct Ipv4Header {
76 pub version_ihl: u8,
78 pub tos: u8,
80 pub total_length: u16,
82 pub identification: u16,
84 pub flags_fragment: u16,
86 pub ttl: u8,
88 pub protocol: u8,
90 pub checksum: u16,
92 pub source_ip: [u8; 4],
94 pub dest_ip: [u8; 4],
96}
97
98impl Ipv4Header {
99 pub fn new() -> Self {
101 Self {
102 version_ihl: 0x45, tos: 0,
104 total_length: 0,
105 identification: 0,
106 flags_fragment: 0,
107 ttl: 64,
108 protocol: 0,
109 checksum: 0,
110 source_ip: [0, 0, 0, 0],
111 dest_ip: [0, 0, 0, 0],
112 }
113 }
114
115 pub fn version(&self) -> u8 {
117 self.version_ihl >> 4
118 }
119
120 pub fn ihl(&self) -> u8 {
122 self.version_ihl & 0x0F
123 }
124
125 pub fn header_length(&self) -> usize {
127 (self.ihl() as usize) * 4
128 }
129
130 pub fn calculate_checksum(&self) -> u16 {
132 let mut bytes = self.to_bytes();
133 if bytes.len() >= 12 {
134 bytes[10] = 0;
135 bytes[11] = 0;
136 }
137 checksum_from_bytes(&bytes)
138 }
139
140 pub fn to_bytes(&self) -> Vec<u8> {
142 let mut bytes = Vec::with_capacity(20);
143 bytes.push(self.version_ihl);
144 bytes.push(self.tos);
145 bytes.extend_from_slice(&self.total_length.to_be_bytes());
146 bytes.extend_from_slice(&self.identification.to_be_bytes());
147 bytes.extend_from_slice(&self.flags_fragment.to_be_bytes());
148 bytes.push(self.ttl);
149 bytes.push(self.protocol);
150 bytes.extend_from_slice(&self.checksum.to_be_bytes());
151 bytes.extend_from_slice(&self.source_ip);
152 bytes.extend_from_slice(&self.dest_ip);
153 bytes
154 }
155
156 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
158 if bytes.len() < 20 {
159 return None;
160 }
161
162 let version_ihl = bytes[0];
163 let version = version_ihl >> 4;
164 if version != 4 {
165 return None;
166 }
167
168 let ihl = version_ihl & 0x0F;
169 let header_len = (ihl as usize) * 4;
170 if bytes.len() < header_len {
171 return None;
172 }
173
174 Some(Self {
175 version_ihl,
176 tos: bytes[1],
177 total_length: u16::from_be_bytes([bytes[2], bytes[3]]),
178 identification: u16::from_be_bytes([bytes[4], bytes[5]]),
179 flags_fragment: u16::from_be_bytes([bytes[6], bytes[7]]),
180 ttl: bytes[8],
181 protocol: bytes[9],
182 checksum: u16::from_be_bytes([bytes[10], bytes[11]]),
183 source_ip: [bytes[12], bytes[13], bytes[14], bytes[15]],
184 dest_ip: [bytes[16], bytes[17], bytes[18], bytes[19]],
185 })
186 }
187}
188
/// IANA-assigned values for the IPv4 header's protocol field.
pub mod protocol {
    /// Internet Control Message Protocol.
    pub const ICMP: u8 = 0x01;
    /// Transmission Control Protocol.
    pub const TCP: u8 = 0x06;
    /// User Datagram Protocol.
    pub const UDP: u8 = 0x11;
    /// IPv6 encapsulated in IPv4.
    pub const IPV6: u8 = 0x29;
}
200
201#[derive(Debug, Clone)]
203pub struct Ipv4AddressInfo {
204 pub address: Ipv4Address,
206 pub netmask: Ipv4Address,
208 pub broadcast: Option<Ipv4Address>,
210 pub is_primary: bool,
212}
213
214#[derive(Debug, Clone)]
216pub struct RouteEntry {
217 pub destination: Ipv4Address,
219 pub netmask: Ipv4Address,
221 pub gateway: Option<Ipv4Address>,
223 pub interface: String,
225 pub metric: u32,
227}
228
229pub struct Ipv4Layer {
234 addresses: RwLock<BTreeMap<String, Vec<Ipv4AddressInfo>>>,
236 routing_table: RwLock<Vec<RouteEntry>>,
238 protocols: RwLock<BTreeMap<u8, alloc::sync::Arc<dyn NetworkLayer>>>,
240 stats: RwLock<NetworkLayerStats>,
242 default_ttl: u8,
244}
245
246impl Ipv4Layer {
247 pub fn new() -> alloc::sync::Arc<Self> {
249 alloc::sync::Arc::new(Self {
250 addresses: RwLock::new(BTreeMap::new()),
251 routing_table: RwLock::new(Vec::new()),
252 protocols: RwLock::new(BTreeMap::new()),
253 stats: RwLock::new(NetworkLayerStats::default()),
254 default_ttl: 64,
255 })
256 }
257
258 pub fn init(network_manager: &crate::network::NetworkManager) {
267 let layer = Self::new();
268 network_manager.register_layer("ip", layer.clone());
269
270 let ethernet = network_manager
272 .get_layer("ethernet")
273 .expect("EthernetLayer must be initialized before Ipv4Layer");
274 ethernet.register_protocol(crate::network::ethernet::ether_type::IPV4, layer);
275 }
276
277 pub fn add_address(&self, interface: &str, info: Ipv4AddressInfo) {
279 let mut addrs = self.addresses.write();
280 addrs
281 .entry(interface.to_string())
282 .or_insert_with(Vec::new)
283 .push(info);
284 }
285
286 pub fn remove_address(&self, interface: &str, ip: Ipv4Address) {
288 let mut addrs = self.addresses.write();
289 if let Some(list) = addrs.get_mut(interface) {
290 list.retain(|a| a.address != ip);
291 }
292 }
293
294 pub fn get_addresses(&self, interface: &str) -> Vec<Ipv4AddressInfo> {
296 self.addresses
297 .read()
298 .get(interface)
299 .cloned()
300 .unwrap_or_default()
301 }
302
303 pub fn get_primary_ip(&self, interface: &str) -> Option<Ipv4Address> {
305 self.addresses
306 .read()
307 .get(interface)?
308 .iter()
309 .find(|a| a.is_primary)
310 .map(|a| a.address)
311 }
312
313 pub fn add_route(&self, entry: RouteEntry) {
315 let mut table = self.routing_table.write();
316 table.push(entry);
317 table.sort_by(|a, b| {
319 let a_bits = a.netmask.to_u32_be().count_ones();
320 let b_bits = b.netmask.to_u32_be().count_ones();
321 b_bits.cmp(&a_bits).then(a.metric.cmp(&b.metric))
322 });
323 }
324
325 pub fn remove_route(&self, destination: Ipv4Address, netmask: Ipv4Address) {
327 let mut table = self.routing_table.write();
328 table.retain(|r| r.destination != destination || r.netmask != netmask);
329 }
330
331 pub fn set_default_gateway(&self, gateway: Ipv4Address, interface: &str) {
333 self.add_route(RouteEntry {
334 destination: Ipv4Address::new(0, 0, 0, 0),
335 netmask: Ipv4Address::new(0, 0, 0, 0),
336 gateway: Some(gateway),
337 interface: interface.to_string(),
338 metric: 100,
339 });
340 }
341
342 pub fn select_source(
346 &self,
347 dest: Ipv4Address,
348 ) -> Option<(String, Ipv4Address, Option<Ipv4Address>)> {
349 let table = self.routing_table.read();
350
351 for route in table.iter() {
353 if self.ip_matches_route(dest, route) {
354 if let Some(src_ip) = self.get_primary_ip(&route.interface) {
355 return Some((route.interface.clone(), src_ip, route.gateway));
356 }
357 }
358 }
359
360 let addrs = self.addresses.read();
362 for (iface, ips) in addrs.iter() {
363 for ip_info in ips {
364 if self.same_subnet(dest, ip_info.address, ip_info.netmask) {
365 return Some((iface.clone(), ip_info.address, None));
366 }
367 }
368 }
369
370 for (iface, ips) in addrs.iter() {
372 if let Some(primary) = ips.iter().find(|a| a.is_primary) {
373 return Some((iface.clone(), primary.address, None));
374 }
375 }
376
377 None
378 }
379
380 fn ip_matches_route(&self, ip: Ipv4Address, route: &RouteEntry) -> bool {
382 self.same_subnet(ip, route.destination, route.netmask)
383 }
384
385 fn same_subnet(&self, ip1: Ipv4Address, ip2: Ipv4Address, mask: Ipv4Address) -> bool {
387 let ip1_u32 = ip1.to_u32_be();
388 let ip2_u32 = ip2.to_u32_be();
389 let mask_u32 = mask.to_u32_be();
390 (ip1_u32 & mask_u32) == (ip2_u32 & mask_u32)
391 }
392
393 pub fn is_local_ip(&self, ip: Ipv4Address) -> bool {
395 self.addresses
396 .read()
397 .values()
398 .any(|ips| ips.iter().any(|a| a.address == ip))
399 }
400
401 pub fn get_protocol_handler(
403 &self,
404 proto_num: u8,
405 ) -> Option<alloc::sync::Arc<dyn NetworkLayer>> {
406 self.protocols.read().get(&proto_num).cloned()
407 }
408}
409
410impl NetworkLayer for Ipv4Layer {
411 fn register_protocol(&self, proto_num: u16, handler: alloc::sync::Arc<dyn NetworkLayer>) {
412 self.protocols.write().insert(proto_num as u8, handler);
413 }
414
415 fn send(
416 &self,
417 packet: &[u8],
418 context: &LayerContext,
419 next_layers: &[alloc::sync::Arc<dyn NetworkLayer>],
420 ) -> Result<(), SocketError> {
421 let dest_ip_bytes = context
423 .get("ip_dst")
424 .and_then(|ip| {
425 if ip.len() >= 4 {
426 Some([ip[0], ip[1], ip[2], ip[3]])
427 } else {
428 None
429 }
430 })
431 .ok_or(SocketError::InvalidPacket)?;
432 let dest_ip = Ipv4Address::from_bytes(dest_ip_bytes);
433
434 let protocol = context
436 .get("ip_protocol")
437 .and_then(|p| if !p.is_empty() { Some(p[0]) } else { None })
438 .unwrap_or(protocol::TCP);
439
440 let (interface_name, src_ip_bytes, gateway) = if let Some(ip_src) = context.get("ip_src") {
442 if ip_src.len() >= 4 {
443 let iface = context
445 .get("interface")
446 .and_then(|b| core::str::from_utf8(b).ok())
447 .map(String::from)
448 .or_else(|| {
449 get_network_manager()
450 .get_default_interface()
451 .map(|i| String::from(i.name()))
452 })
453 .ok_or(SocketError::NoRoute)?;
454
455 let gateway = self.select_source(dest_ip).and_then(|(_, _, gw)| gw);
457
458 (iface, [ip_src[0], ip_src[1], ip_src[2], ip_src[3]], gateway)
459 } else {
460 return Err(SocketError::InvalidAddress);
461 }
462 } else {
463 let (iface, src_ip, gw) = self.select_source(dest_ip).ok_or(SocketError::NoRoute)?;
465 (iface, src_ip.0, gw)
466 };
467
468 let mut header = Ipv4Header::new();
470 header.source_ip = src_ip_bytes;
471 header.dest_ip = dest_ip_bytes;
472 header.protocol = protocol;
473 header.ttl = self.default_ttl;
474
475 let total_length = (20 + packet.len()) as u16;
477 header.total_length = total_length;
478
479 header.checksum = header.calculate_checksum();
481
482 let mut ip_packet = header.to_bytes();
484
485 ip_packet.extend_from_slice(packet);
487
488 early_println!(
489 "[IPv4] Send: {} bytes (src: {}.{}.{}.{}, dst: {}.{}.{}.{}, proto: {}, iface: {})",
490 ip_packet.len(),
491 src_ip_bytes[0],
492 src_ip_bytes[1],
493 src_ip_bytes[2],
494 src_ip_bytes[3],
495 dest_ip_bytes[0],
496 dest_ip_bytes[1],
497 dest_ip_bytes[2],
498 dest_ip_bytes[3],
499 protocol,
500 interface_name
501 );
502
503 let mut eth_context = context.clone();
505 eth_context.set(
506 "eth_type",
507 &crate::network::ethernet::ether_type::IPV4.to_be_bytes(),
508 );
509 eth_context.set("interface", interface_name.as_bytes());
510 eth_context.set("ip_src", &src_ip_bytes);
511
512 if let Some(gw) = gateway {
514 eth_context.set("next_hop", &gw.0);
515 } else {
516 eth_context.set("next_hop", &dest_ip_bytes);
517 }
518
519 if !next_layers.is_empty() {
521 next_layers[0].send(&ip_packet, ð_context, &next_layers[1..])?;
522 } else if let Some(eth_layer) = get_network_manager().get_layer("ethernet") {
523 eth_layer.send(&ip_packet, ð_context, &[])?;
524 }
525
526 let mut stats = self.stats.write();
528 stats.packets_sent += 1;
529 stats.bytes_sent += ip_packet.len() as u64;
530
531 Ok(())
532 }
533
534 fn receive(&self, packet: &[u8], _context: Option<&LayerContext>) -> Result<(), SocketError> {
535 let header = Ipv4Header::from_bytes(packet).ok_or(SocketError::InvalidPacket)?;
537
538 let header_len = header.header_length();
539 let total_length = usize::from(header.total_length);
540
541 if packet.len() < header_len {
542 return Err(SocketError::InvalidPacket);
543 }
544 if total_length < header_len || total_length > packet.len() {
545 return Err(SocketError::InvalidPacket);
546 }
547
548 early_println!(
549 "[IPv4] RX: total_len={} src={}.{}.{}.{} dst={}.{}.{}.{} proto={}",
550 total_length,
551 header.source_ip[0],
552 header.source_ip[1],
553 header.source_ip[2],
554 header.source_ip[3],
555 header.dest_ip[0],
556 header.dest_ip[1],
557 header.dest_ip[2],
558 header.dest_ip[3],
559 header.protocol
560 );
561
562 let calculated_checksum = checksum_from_bytes(&packet[..header_len]);
564 let header_checksum = unsafe { core::ptr::addr_of!(header.checksum).read_unaligned() };
565 if calculated_checksum != header_checksum {
566 early_println!(
567 "[IPv4] Checksum mismatch: calculated=0x{:04X}, header=0x{:04X}",
568 calculated_checksum,
569 header_checksum
570 );
571 let mut stats = self.stats.write();
572 stats.protocol_errors += 1;
573 return Err(SocketError::InvalidPacket);
574 }
575
576 let payload = &packet[header_len..total_length];
577
578 early_println!(
579 "[IPv4] Recv: {} bytes (src: {}.{}.{}.{}, dst: {}.{}.{}.{}, proto: {})",
580 packet.len(),
581 header.source_ip[0],
582 header.source_ip[1],
583 header.source_ip[2],
584 header.source_ip[3],
585 header.dest_ip[0],
586 header.dest_ip[1],
587 header.dest_ip[2],
588 header.dest_ip[3],
589 header.protocol
590 );
591
592 let mut stats = self.stats.write();
594 stats.packets_received += 1;
595 stats.bytes_received += total_length as u64;
596
597 let protocols = self.protocols.read();
599 if let Some(handler) = protocols.get(&header.protocol) {
600 let mut proto_context = LayerContext::new();
601 proto_context.set("ip_src", &header.source_ip);
602 proto_context.set("ip_dst", &header.dest_ip);
603 handler.receive(payload, Some(&proto_context))
604 } else {
605 Err(SocketError::ProtocolNotSupported)
607 }
608 }
609
610 fn name(&self) -> &'static str {
611 "IPv4"
612 }
613
614 fn stats(&self) -> NetworkLayerStats {
615 self.stats.read().clone()
616 }
617
618 fn as_any(&self) -> &dyn core::any::Any {
619 self
620 }
621}
622
/// Internet checksum (RFC 1071 ones'-complement sum) over `header_bytes`,
/// skipping the 16-bit word at byte offset 10 (the IPv4 checksum field).
///
/// Bytes are paired big-endian. A trailing odd byte is padded with a zero
/// low byte. Offset 10 is skipped only when it begins a complete 16-bit
/// word.
fn checksum_from_bytes(header_bytes: &[u8]) -> u16 {
    // Byte offset of the IPv4 checksum field, excluded from its own sum.
    const CHECKSUM_OFFSET: usize = 10;

    let mut acc: u32 = header_bytes
        .chunks(2)
        .enumerate()
        .filter(|&(idx, word)| !(word.len() == 2 && idx * 2 == CHECKSUM_OFFSET))
        .map(|(_, word)| {
            let hi = word[0];
            let lo = word.get(1).copied().unwrap_or(0);
            u32::from(u16::from_be_bytes([hi, lo]))
        })
        .sum();

    // Fold carries back into the low 16 bits (end-around carry).
    while acc > 0xFFFF {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }

    !(acc as u16)
}
648
#[cfg(test)]
mod tests {
    use alloc::string::ToString;
    use alloc::vec;

    use super::*;

    /// Builds a primary 192.168.1.100/24 address for tests.
    fn primary_192_168_1_100() -> Ipv4AddressInfo {
        Ipv4AddressInfo {
            address: Ipv4Address::new(192, 168, 1, 100),
            netmask: Ipv4Address::new(255, 255, 255, 0),
            broadcast: None,
            is_primary: true,
        }
    }

    #[test_case]
    fn test_ipv4_address() {
        let addr = Ipv4Address::new(192, 168, 1, 100);
        assert_eq!(addr.as_bytes(), [192, 168, 1, 100]);
        assert!(!addr.is_broadcast());
        assert!(!addr.is_loopback());
        assert!(!addr.is_any());

        let broadcast = Ipv4Address::new(255, 255, 255, 255);
        assert!(broadcast.is_broadcast());

        let loopback = Ipv4Address::new(127, 0, 0, 1);
        assert!(loopback.is_loopback());

        let any = Ipv4Address::new(0, 0, 0, 0);
        assert!(any.is_any());
    }

    #[test_case]
    fn test_ipv4_address_u32_conversion() {
        let addr = Ipv4Address::new(192, 168, 1, 100);
        assert_eq!(addr.to_u32_be(), u32::from_be_bytes([192, 168, 1, 100]));

        let from_u32 = Ipv4Address::from_u32_be(0xC0A80164u32);
        assert_eq!(from_u32, addr);
    }

    #[test_case]
    fn test_ipv4_header_creation() {
        let mut header = Ipv4Header::new();
        header.source_ip = [192, 168, 1, 100];
        header.dest_ip = [192, 168, 1, 1];
        header.protocol = protocol::TCP;
        header.total_length = (20 + 10) as u16;

        assert_eq!(header.version(), 4);
        assert_eq!(header.ihl(), 5);
        assert_eq!(header.header_length(), 20);
        assert_eq!(header.protocol, protocol::TCP);
    }

    #[test_case]
    fn test_ipv4_header_serialization() {
        let mut header = Ipv4Header::new();
        header.source_ip = [192, 168, 1, 100];
        header.dest_ip = [192, 168, 1, 1];
        header.protocol = protocol::TCP;
        header.total_length = 30;

        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), 20);
        assert_eq!(bytes[0], 0x45);
        assert_eq!(u16::from_be_bytes([bytes[2], bytes[3]]), 30);
        assert_eq!(&bytes[12..16], [192, 168, 1, 100]);
        assert_eq!(&bytes[16..20], [192, 168, 1, 1]);
    }

    #[test_case]
    fn test_ipv4_header_parsing() {
        let bytes = vec![
            0x45, 0x00, // version 4 / IHL 5, TOS
            0x00, 0x1E, // total length = 30
            0x00, 0x01, // identification
            0x00, 0x00, // flags / fragment offset
            0x40, 0x06, // TTL = 64, protocol = TCP
            0x00, 0x00, // checksum (not checked by from_bytes)
            0xC0, 0xA8, 0x01, 0x64, // source 192.168.1.100
            0xC0, 0xA8, 0x01, 0x01, // destination 192.168.1.1
        ];

        let header = Ipv4Header::from_bytes(&bytes).unwrap();
        assert_eq!(header.version(), 4);
        assert_eq!(header.ihl(), 5);
        // Copy the packed u16 out by value rather than borrowing it.
        let total_length = header.total_length;
        assert_eq!(total_length, 30);
        assert_eq!(header.protocol, protocol::TCP);
        assert_eq!(header.source_ip, [192, 168, 1, 100]);
        assert_eq!(header.dest_ip, [192, 168, 1, 1]);
        assert_eq!(header.ttl, 64);
    }

    #[test_case]
    fn test_ipv4_header_roundtrip() {
        let mut header = Ipv4Header::new();
        header.tos = 0x10;
        header.total_length = 40;
        header.identification = 0xBEEF;
        header.flags_fragment = 0x4000;
        header.protocol = protocol::UDP;
        header.source_ip = [10, 0, 0, 1];
        header.dest_ip = [10, 0, 0, 2];
        header.checksum = header.calculate_checksum();

        let parsed = Ipv4Header::from_bytes(&header.to_bytes()).unwrap();
        assert_eq!(parsed.to_bytes(), header.to_bytes());
    }

    #[test_case]
    fn test_ipv4_header_invalid_version() {
        // Version nibble 5, so this must not parse as IPv4.
        let bytes = [0x55u8; 20];
        assert!(Ipv4Header::from_bytes(&bytes).is_none());
    }

    #[test_case]
    fn test_ipv4_header_too_short() {
        let bytes = [0u8; 10];
        assert!(Ipv4Header::from_bytes(&bytes).is_none());
    }

    #[test_case]
    fn test_ipv4_layer_creation() {
        let ip_layer = Ipv4Layer::new();
        assert!(ip_layer.get_addresses("eth0").is_empty());
    }

    #[test_case]
    fn test_ipv4_layer_add_address() {
        let ip_layer = Ipv4Layer::new();

        let ip = Ipv4Address::new(192, 168, 1, 100);
        ip_layer.add_address(
            "eth0",
            Ipv4AddressInfo {
                address: ip,
                netmask: Ipv4Address::new(255, 255, 255, 0),
                broadcast: Some(Ipv4Address::new(192, 168, 1, 255)),
                is_primary: true,
            },
        );

        let addrs = ip_layer.get_addresses("eth0");
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0].address, ip);
        assert!(addrs[0].is_primary);
    }

    #[test_case]
    fn test_ipv4_layer_multiple_addresses() {
        let ip_layer = Ipv4Layer::new();

        ip_layer.add_address("eth0", primary_192_168_1_100());
        ip_layer.add_address(
            "eth0",
            Ipv4AddressInfo {
                address: Ipv4Address::new(192, 168, 1, 101),
                netmask: Ipv4Address::new(255, 255, 255, 0),
                broadcast: None,
                is_primary: false,
            },
        );

        let addrs = ip_layer.get_addresses("eth0");
        assert_eq!(addrs.len(), 2);
        assert_eq!(
            ip_layer.get_primary_ip("eth0"),
            Some(Ipv4Address::new(192, 168, 1, 100))
        );
    }

    #[test_case]
    fn test_ipv4_layer_routing() {
        let ip_layer = Ipv4Layer::new();

        ip_layer.add_address("eth0", primary_192_168_1_100());

        ip_layer.add_route(RouteEntry {
            destination: Ipv4Address::new(192, 168, 1, 0),
            netmask: Ipv4Address::new(255, 255, 255, 0),
            gateway: None,
            interface: "eth0".to_string(),
            metric: 0,
        });

        ip_layer.set_default_gateway(Ipv4Address::new(192, 168, 1, 1), "eth0");

        // On-link destination: the /24 route wins, so there is no gateway.
        let result = ip_layer.select_source(Ipv4Address::new(192, 168, 1, 50));
        assert!(result.is_some());
        let (iface, src_ip, gw) = result.unwrap();
        assert_eq!(iface, "eth0");
        assert_eq!(src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert!(gw.is_none());

        // Off-link destination: only the default route matches.
        let result = ip_layer.select_source(Ipv4Address::new(8, 8, 8, 8));
        assert!(result.is_some());
        let (iface, src_ip, gw) = result.unwrap();
        assert_eq!(iface, "eth0");
        assert_eq!(src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert_eq!(gw, Some(Ipv4Address::new(192, 168, 1, 1)));
    }

    #[test_case]
    fn test_ipv4_select_source_without_routes() {
        let ip_layer = Ipv4Layer::new();
        assert!(ip_layer
            .select_source(Ipv4Address::new(192, 168, 1, 7))
            .is_none());

        ip_layer.add_address("eth0", primary_192_168_1_100());

        // Same subnet as a configured address.
        let (iface, src_ip, gw) = ip_layer
            .select_source(Ipv4Address::new(192, 168, 1, 7))
            .unwrap();
        assert_eq!(iface, "eth0");
        assert_eq!(src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert!(gw.is_none());

        // No route and no subnet match: falls back to a primary address.
        let (iface, src_ip, gw) = ip_layer
            .select_source(Ipv4Address::new(8, 8, 8, 8))
            .unwrap();
        assert_eq!(iface, "eth0");
        assert_eq!(src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert!(gw.is_none());
    }

    #[test_case]
    fn test_ipv4_is_local_ip() {
        let ip_layer = Ipv4Layer::new();

        ip_layer.add_address("eth0", primary_192_168_1_100());

        assert!(ip_layer.is_local_ip(Ipv4Address::new(192, 168, 1, 100)));
        assert!(!ip_layer.is_local_ip(Ipv4Address::new(192, 168, 1, 101)));
    }

    #[test_case]
    fn test_ipv4_checksum() {
        let mut header = Ipv4Header::new();
        header.source_ip = [192, 168, 1, 100];
        header.dest_ip = [192, 168, 1, 1];
        header.protocol = protocol::TCP;
        header.ttl = 64;
        header.total_length = 20;
        header.identification = 0;
        header.flags_fragment = 0;
        header.tos = 0;

        let checksum = header.calculate_checksum();
        assert_ne!(checksum, 0);

        // Once stored in the header, the checksum must verify against the
        // serialized bytes. This is the same check `receive` performs.
        header.checksum = checksum;
        assert_eq!(checksum_from_bytes(&header.to_bytes()), checksum);
    }

    #[test_case]
    fn test_ipv4_checksum_known_vector() {
        let header = Ipv4Header {
            version_ihl: 0x45,
            tos: 0x00,
            total_length: 0x003C,
            identification: 0x1C46,
            flags_fragment: 0x4000,
            ttl: 0x40,
            protocol: 0x06,
            checksum: 0x0000,
            source_ip: [192, 168, 0, 1],
            dest_ip: [192, 168, 0, 199],
        };

        let checksum = header.calculate_checksum();
        assert_eq!(checksum, 0x9C5D);
    }

    #[test_case]
    fn test_protocol_constants() {
        assert_eq!(protocol::ICMP, 1);
        assert_eq!(protocol::TCP, 6);
        assert_eq!(protocol::UDP, 17);
        assert_eq!(protocol::IPV6, 41);
    }
}