kernel/network/
tcp.rs

1//! TCP protocol layer (Complete implementation)
2//!
3//! This module provides a full TCP implementation with 3-way handshake,
4//! flow control, and retransmission.
5
6use alloc::collections::{BTreeMap, VecDeque};
7use alloc::string::String;
8use alloc::sync::{Arc, Weak};
9use alloc::vec::Vec;
10use core::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU64, AtomicUsize, Ordering};
11use spin::{Mutex, RwLock};
12
13use crate::network::ipv4::Ipv4Address;
14use crate::network::protocol_stack::get_network_manager;
15use crate::network::protocol_stack::{LayerContext, NetworkLayer, NetworkLayerStats};
16use crate::network::socket::SocketError;
17use crate::network::socket::{
18    Inet4SocketAddress, SocketAddress, SocketControl, SocketObject, SocketState,
19};
20
/// TCP connection states
///
/// One variant per state of the TCP connection state machine. A socket
/// created by `TcpSocket::new` starts in `Closed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpState {
    /// No connection; also the state a socket returns to after a reset.
    Closed,
    /// Passive open: an incoming SYN spawns a child socket.
    Listen,
    /// Active open: waiting for the peer's SYN-ACK.
    SynSent,
    /// SYN-ACK sent in reply to a SYN; waiting for the final ACK.
    SynReceived,
    /// Handshake complete; data segments are accepted.
    Established,
    /// First of the two FIN-WAIT states of connection teardown.
    FinWait1,
    /// Second of the two FIN-WAIT states of connection teardown.
    FinWait2,
    /// The peer's FIN has been received and acknowledged.
    CloseWait,
    /// CLOSING state of connection teardown.
    Closing,
    /// LAST-ACK state of connection teardown.
    LastAck,
    /// TIME-WAIT state of connection teardown.
    TimeWait,
}
36
/// TCP flags
///
/// Bit masks for the control bits returned by `TcpHeader::flags` and
/// accepted by `TcpHeader::set_flags`. Masks can be OR-ed together.
pub mod tcp_flags {
    /// FIN control bit.
    pub const FIN: u8 = 0x01;
    /// SYN control bit.
    pub const SYN: u8 = 0x02;
    /// RST control bit.
    pub const RST: u8 = 0x04;
    /// PSH control bit.
    pub const PSH: u8 = 0x08;
    /// ACK control bit.
    pub const ACK: u8 = 0x10;
    /// URG control bit.
    pub const URG: u8 = 0x20;
}
46
/// Buffer size limits (prevent memory exhaustion)
///
/// Intended upper bound, in bytes, on a socket's send buffer.
// NOTE(review): not referenced in the portion of the file reviewed here —
// confirm it is enforced by the send path.
const MAX_SEND_BUFFER_SIZE: usize = 65536; // 64KB
/// Upper bound, in bytes, on a socket's receive buffer; incoming data that
/// would exceed it is not accepted.
const MAX_RECV_BUFFER_SIZE: usize = 65536; // 64KB
/// Intended upper bound on the number of entries in a socket's
/// unacknowledged-segment list.
// NOTE(review): not referenced in the portion of the file reviewed here —
// confirm it is enforced where segments are queued for retransmission.
const MAX_UNACKED_SEGMENTS: usize = 256; // Limit unacked segment list
51
/// TCP header (fixed 20-byte part, without options)
///
/// Fields hold host-order values. [`TcpHeader::to_bytes`] and
/// [`TcpHeader::from_bytes`] convert to and from the big-endian wire format.
#[derive(Debug, Clone, Copy)]
#[repr(C, packed)]
pub struct TcpHeader {
    /// Source port
    pub src_port: u16,
    /// Destination port
    pub dst_port: u16,
    /// Sequence number
    pub seq_number: u32,
    /// Acknowledgment number
    pub ack_number: u32,
    /// Data offset (upper 4 bits, in 32-bit words) + reserved bits + flags
    /// (lower 6 bits, see `tcp_flags`)
    pub data_offset_flags: u16,
    /// Window size
    pub window_size: u16,
    /// Checksum
    pub checksum: u16,
    /// Urgent pointer
    pub urgent_pointer: u16,
}

impl TcpHeader {
    /// Create a new TCP header.
    ///
    /// Sequence/acknowledgment numbers, checksum and urgent pointer are zero,
    /// no flags are set, the data offset is 5 words (20 bytes) and the
    /// advertised window is 65535.
    pub fn new(src_port: u16, dst_port: u16) -> Self {
        Self {
            src_port,
            dst_port,
            seq_number: 0,
            ack_number: 0,
            data_offset_flags: 0x5000, // Data offset = 5 (20 bytes), no flags
            window_size: 65535,
            checksum: 0,
            urgent_pointer: 0,
        }
    }

    /// Get TCP flags (the low six control bits).
    pub fn flags(&self) -> u8 {
        (self.data_offset_flags & 0x3F) as u8
    }

    /// Set TCP flags, replacing the low six control bits and leaving the data
    /// offset and the remaining bits untouched.
    pub fn set_flags(&mut self, flags: u8) {
        self.data_offset_flags = (self.data_offset_flags & 0xFFC0) | (flags as u16 & 0x3F);
    }

    /// Get data offset in bytes (the 4-bit word count multiplied by 4).
    pub fn data_offset(&self) -> usize {
        ((self.data_offset_flags >> 12) as usize) * 4
    }

    /// Calculate TCP checksum.
    ///
    /// Computes the one's-complement checksum over the IPv4 pseudo-header
    /// (`src_ip`, `dst_ip`, protocol 6, TCP length), this header with its
    /// checksum field treated as zero, and `data`. The sum is computed in
    /// place; nothing is allocated or copied.
    ///
    /// The TCP length is `data_offset() + data.len()` truncated to 16 bits,
    /// and only the fixed 20 header bytes are summed, so headers carrying
    /// options are not supported by this routine.
    pub fn calculate_checksum(&self, src_ip: [u8; 4], dst_ip: [u8; 4], data: &[u8]) -> u16 {
        // Adds `bytes` to a running one's-complement sum as big-endian 16-bit
        // words. A trailing odd byte is padded with a zero low byte. The sum
        // is folded after every addition, so it never exceeds 0xFFFF on
        // return and cannot overflow the u32.
        fn add_words(mut sum: u32, bytes: &[u8]) -> u32 {
            let mut pairs = bytes.chunks_exact(2);
            for pair in &mut pairs {
                sum += u32::from(u16::from_be_bytes([pair[0], pair[1]]));
                sum = (sum & 0xFFFF) + (sum >> 16);
            }
            if let [last] = pairs.remainder() {
                sum += u32::from(*last) << 8;
                sum = (sum & 0xFFFF) + (sum >> 16);
            }
            sum
        }

        let tcp_len = (self.data_offset() + data.len()) as u16;

        // Pseudo-header: source, destination, zero, protocol, TCP length.
        let mut pseudo = [0u8; 12];
        pseudo[..4].copy_from_slice(&src_ip);
        pseudo[4..8].copy_from_slice(&dst_ip);
        pseudo[9] = 6; // TCP protocol number
        pseudo[10..].copy_from_slice(&tcp_len.to_be_bytes());

        let mut header = *self;
        header.checksum = 0;

        // Both fixed parts have even length, so word pairing across the three
        // parts is the same as for one contiguous buffer.
        let mut sum = add_words(0, &pseudo);
        sum = add_words(sum, &header.wire_bytes());
        sum = add_words(sum, data);

        while sum >> 16 != 0 {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }

        !sum as u16
    }

    /// Serialize header to bytes (20 bytes, big-endian fields in wire order).
    pub fn to_bytes(&self) -> Vec<u8> {
        self.wire_bytes().to_vec()
    }

    /// Encode the fixed header into a stack array, without allocating.
    fn wire_bytes(&self) -> [u8; 20] {
        let mut out = [0u8; 20];
        out[0..2].copy_from_slice(&self.src_port.to_be_bytes());
        out[2..4].copy_from_slice(&self.dst_port.to_be_bytes());
        out[4..8].copy_from_slice(&self.seq_number.to_be_bytes());
        out[8..12].copy_from_slice(&self.ack_number.to_be_bytes());
        out[12..14].copy_from_slice(&self.data_offset_flags.to_be_bytes());
        out[14..16].copy_from_slice(&self.window_size.to_be_bytes());
        out[16..18].copy_from_slice(&self.checksum.to_be_bytes());
        out[18..20].copy_from_slice(&self.urgent_pointer.to_be_bytes());
        out
    }

    /// Parse header from bytes.
    ///
    /// Returns `None` when fewer than 20 bytes are supplied. Bytes beyond the
    /// first 20 are ignored, and the data-offset field is not validated.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < 20 {
            return None;
        }

        Some(Self {
            src_port: u16::from_be_bytes([bytes[0], bytes[1]]),
            dst_port: u16::from_be_bytes([bytes[2], bytes[3]]),
            seq_number: u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
            ack_number: u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]),
            data_offset_flags: u16::from_be_bytes([bytes[12], bytes[13]]),
            window_size: u16::from_be_bytes([bytes[14], bytes[15]]),
            checksum: u16::from_be_bytes([bytes[16], bytes[17]]),
            urgent_pointer: u16::from_be_bytes([bytes[18], bytes[19]]),
        })
    }
}
168
/// Unacknowledged TCP segment for retransmission tracking
///
/// Entries are kept in `TcpSocket::unacked_segments`; the front entry is the
/// one re-sent by fast retransmit.
#[derive(Clone)]
struct UnackedSegment {
    /// Sequence number of first byte
    seq: u32,
    /// Data to retransmit
    data: Vec<u8>,
    /// Flags (SYN, FIN, PSH, etc.) to put on the retransmitted segment
    flags: u8,
    /// Transmission count (incremented on each retransmission)
    tx_count: u16,
    /// Last transmission timestamp (ticks, as returned by `crate::timer::get_tick`)
    last_tx_time: u64,
}
183
/// Out-of-order TCP segment for reassembly
///
/// Equality and ordering consider only `seq`; the payload is ignored. The
/// ordering is the plain numeric order of `seq` and is not aware of
/// sequence-number wraparound.
#[derive(Clone)]
struct OutOfOrderSegment {
    /// Sequence number of first byte
    seq: u32,
    /// Segment data
    data: Vec<u8>,
}

impl PartialEq for OutOfOrderSegment {
    fn eq(&self, other: &Self) -> bool {
        self.seq == other.seq
    }
}

impl Eq for OutOfOrderSegment {}

impl PartialOrd for OutOfOrderSegment {
    // Canonical form: delegate to `Ord` so the two orderings cannot diverge.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OutOfOrderSegment {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.seq.cmp(&other.seq)
    }
}
212
213/// Retransmission timer handler
214struct RetransTimer {
215    socket: Weak<TcpSocket>,
216    seq: u32,
217}
218
219impl crate::timer::TimerHandler for RetransTimer {
220    fn on_timer_expired(self: Arc<Self>, _context: usize) {
221        if let Some(socket) = self.socket.upgrade() {
222            socket.handle_retrans_timeout(self.seq);
223        }
224    }
225}
226
/// TCP socket (full implementation)
///
/// All state is behind atomics or spin locks, so the methods take `&self`.
/// Instances are created through `TcpSocket::new`, which returns an `Arc`.
pub struct TcpSocket {
    /// TCP connection state
    state: Mutex<TcpState>,

    /// Local IP address (`None` until bound or resolved)
    local_ip: Mutex<Option<Ipv4Address>>,
    /// Local port (0 means no port assigned)
    pub(crate) local_port: AtomicU16,

    /// Remote IP address (`None` until a peer is known)
    remote_ip: Mutex<Option<Ipv4Address>>,
    /// Remote port
    remote_port: AtomicU16,

    /// Sequence numbers
    /// Next sequence number to use for outgoing data
    send_seq: AtomicU32,
    /// Latest acknowledgment number received from the peer
    send_unacked: AtomicU32,
    /// Next sequence number expected from the peer
    recv_seq: AtomicU32,
    /// Acknowledgment number placed in outgoing segments
    recv_ack: AtomicU32,

    /// Window size
    send_window: AtomicU16,
    /// Receive window; updated from the free space in `recv_buffer`
    recv_window: AtomicU16,

    /// Data buffers
    send_buffer: Mutex<VecDeque<u8>>,
    /// In-order received bytes not yet read
    recv_buffer: Mutex<VecDeque<u8>>,

    /// Reference to TCP layer
    tcp_layer: Weak<TcpLayer>,
    /// Weak self reference for registration
    self_weak: Weak<TcpSocket>,
    /// Pending accepted connections (listener only)
    pending_accept: Mutex<VecDeque<Arc<TcpSocket>>>,
    /// Maximum backlog size (from listen())
    max_backlog: AtomicUsize,

    /// Statistics
    bytes_sent: AtomicU64,
    /// Count of received payload bytes
    bytes_received: AtomicU64,

    /// RTO (Retransmission Timeout) calculation - RFC 6298
    /// Smoothed RTT (8 * srtt for fixed-point arithmetic)
    srtt: AtomicU32,
    /// RTT variation (4 * rttvar for fixed-point arithmetic)
    rttvar: AtomicU32,
    /// Current RTO in ticks (initial: 1 second = 100 ticks @ 10ms)
    rto: AtomicU32,
    /// Retransmission count for exponential backoff
    retrans_count: AtomicU16,
    /// Timer ID for retransmission timer
    retrans_timer_id: Mutex<Option<u64>>,
    /// Timestamp of last segment transmission (for RTT measurement)
    last_send_time: AtomicU64,
    /// Whether we're timing an RTT measurement (Karn's algorithm);
    /// initialised and reset to 0
    timing_rtt: AtomicU16,
    /// Sequence number being timed
    timed_seq: AtomicU32,

    /// List of unacknowledged segments for retransmission
    unacked_segments: Mutex<VecDeque<UnackedSegment>>,

    /// Out-of-order segments for reassembly (sorted by sequence number),
    /// keyed by the sequence number of their first byte
    out_of_order: Mutex<BTreeMap<u32, OutOfOrderSegment>>,

    /// Waker for blocking accept() operations
    accept_waker: Mutex<Option<Arc<crate::sync::Waker>>>,
    /// Waker for blocking recv() operations
    recv_waker: Mutex<Option<Arc<crate::sync::Waker>>>,
    /// Waker for blocking send() operations
    send_waker: Mutex<Option<Arc<crate::sync::Waker>>>,
    /// Block mode: true for blocking, false for non-blocking
    blocking_mode: AtomicBool,

    /// Duplicate ACK count for Fast Retransmit
    dup_ack_count: AtomicU16,
    /// Last ACK sequence number for detecting duplicates
    last_ack_seq: AtomicU32,
}
307
308impl TcpSocket {
309    /// Safely downcast a SocketObject to TcpSocket using Any trait
310    ///
311    /// Returns None if socket is not a TcpSocket.
312    /// This is completely safe and does not use any unsafe code.
313    pub fn from_socket_object(socket: &Arc<dyn SocketObject>) -> Option<&Self> {
314        socket.as_any().downcast_ref::<TcpSocket>()
315    }
316
317    /// Blocking accept - waits for a connection
318    pub fn accept_blocking(
319        &self,
320        task_id: usize,
321        trapframe: &mut crate::arch::Trapframe,
322    ) -> Result<Arc<dyn SocketObject>, SocketError> {
323        if self.get_state() != TcpState::Listen {
324            return Err(SocketError::NotListening);
325        }
326
327        let nonblocking = !self.blocking_mode.load(Ordering::SeqCst);
328
329        loop {
330            {
331                let mut pending = self.pending_accept.lock();
332                if let Some(socket) = pending.pop_front() {
333                    return Ok(socket as Arc<dyn SocketObject>);
334                }
335            }
336
337            if nonblocking {
338                return Err(SocketError::WouldBlock);
339            }
340
341            let waker = {
342                let mut waker_lock = self.accept_waker.lock();
343                waker_lock
344                    .get_or_insert_with(|| {
345                        Arc::new(crate::sync::Waker::new_interruptible("tcp_accept"))
346                    })
347                    .clone()
348            };
349            waker.wait(task_id, trapframe);
350        }
351    }
352
353    /// Create a new TCP socket
354    pub fn new(tcp_layer: Weak<TcpLayer>) -> Arc<Self> {
355        Arc::new_cyclic(|weak| Self {
356            state: Mutex::new(TcpState::Closed),
357            local_ip: Mutex::new(None),
358            local_port: AtomicU16::new(0),
359            remote_ip: Mutex::new(None),
360            remote_port: AtomicU16::new(0),
361            send_seq: AtomicU32::new(0),
362            send_unacked: AtomicU32::new(0),
363            recv_seq: AtomicU32::new(0),
364            recv_ack: AtomicU32::new(0),
365            send_window: AtomicU16::new(65535),
366            recv_window: AtomicU16::new(65535),
367            send_buffer: Mutex::new(VecDeque::new()),
368            recv_buffer: Mutex::new(VecDeque::new()),
369            tcp_layer,
370            self_weak: weak.clone(),
371            pending_accept: Mutex::new(VecDeque::new()),
372            max_backlog: AtomicUsize::new(0),
373            bytes_sent: AtomicU64::new(0),
374            bytes_received: AtomicU64::new(0),
375
376            // RTO initialization - RFC 6298
377            // Initial RTO = 1 second = 100 ticks (10ms/tick)
378            srtt: AtomicU32::new(0),
379            rttvar: AtomicU32::new(0),
380            rto: AtomicU32::new(100), // 1 second in ticks
381            retrans_count: AtomicU16::new(0),
382            retrans_timer_id: Mutex::new(None),
383            last_send_time: AtomicU64::new(0),
384            timing_rtt: AtomicU16::new(0),
385            timed_seq: AtomicU32::new(0),
386
387            // Unacked segments list
388            unacked_segments: Mutex::new(VecDeque::new()),
389
390            // Out-of-order segments map
391            out_of_order: Mutex::new(BTreeMap::new()),
392
393            // Blocking support
394            accept_waker: Mutex::new(None),
395            recv_waker: Mutex::new(None),
396            send_waker: Mutex::new(None),
397            blocking_mode: AtomicBool::new(true), // Default to blocking mode
398
399            // Fast Retransmit - duplicate ACK tracking
400            dup_ack_count: AtomicU16::new(0),
401            last_ack_seq: AtomicU32::new(0),
402        })
403    }
404
405    fn matches_peer(&self, src_ip: Ipv4Address, src_port: u16) -> bool {
406        if self.get_state() == TcpState::Listen {
407            return false;
408        }
409
410        let remote_ip = self.remote_ip.lock().clone();
411        let remote_port = self.remote_port.load(Ordering::SeqCst);
412        match remote_ip {
413            Some(ip) => ip == src_ip && remote_port == src_port,
414            None => false,
415        }
416    }
417
418    fn ensure_local_ip(&self) {
419        let mut local_ip = self.local_ip.lock();
420        let needs_update = match *local_ip {
421            Some(ip) => ip.0 == [0, 0, 0, 0],
422            None => true,
423        };
424        if !needs_update {
425            return;
426        }
427
428        let manager = get_network_manager();
429        if let Some(default_iface) = manager.get_default_interface() {
430            if let Some(ip_layer) = manager.get_layer("ip") {
431                if let Some(ip) = ip_layer
432                    .as_any()
433                    .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
434                {
435                    if let Some(addr) = ip.get_primary_ip(default_iface.name()) {
436                        *local_ip = Some(addr);
437                    }
438                }
439            }
440        }
441    }
442
443    fn register_local_port(&self, port: u16) -> Result<(), SocketError> {
444        let tcp_layer = self
445            .tcp_layer
446            .upgrade()
447            .ok_or(SocketError::InvalidOperation)?;
448
449        tcp_layer.register_port(port, self.self_weak.clone());
450        Ok(())
451    }
452
453    fn allocate_ephemeral_port(&self) -> u16 {
454        static NEXT_EPHEMERAL_PORT: AtomicU16 = AtomicU16::new(49152);
455
456        let port = NEXT_EPHEMERAL_PORT.fetch_add(1, Ordering::SeqCst);
457        if port == u16::MAX {
458            NEXT_EPHEMERAL_PORT.store(49152, Ordering::SeqCst);
459        }
460        if port < 49152 { 49152 } else { port }
461    }
462
463    /// Get current TCP state
464    pub fn get_state(&self) -> TcpState {
465        *self.state.lock()
466    }
467
468    /// Set TCP state
469    pub fn set_state(&self, new_state: TcpState) {
470        *self.state.lock() = new_state;
471    }
472
473    /// Process incoming TCP segment
474    pub fn process_segment(&self, src_ip: Ipv4Address, header: TcpHeader, data: &[u8]) {
475        let current_state = self.get_state();
476
477        match current_state {
478            TcpState::Listen => {
479                if header.flags() & tcp_flags::SYN != 0 {
480                    let tcp_layer = match self.tcp_layer.upgrade() {
481                        Some(layer) => layer,
482                        None => return,
483                    };
484
485                    let child = TcpSocket::new(Arc::downgrade(&tcp_layer));
486                    let local_port = self.local_port.load(Ordering::SeqCst);
487                    if local_port == 0 {
488                        return;
489                    }
490
491                    if let Some(local_ip) = self.local_ip.lock().clone() {
492                        *child.local_ip.lock() = Some(local_ip);
493                    } else {
494                        child.ensure_local_ip();
495                    }
496
497                    child.local_port.store(local_port, Ordering::SeqCst);
498                    tcp_layer.register_port(local_port, child.self_weak.clone());
499                    child.handle_syn_received(src_ip, header);
500
501                    // Check backlog limit and enqueue atomically
502                    let max_backlog = self.max_backlog.load(Ordering::SeqCst);
503                    {
504                        let mut pending = self.pending_accept.lock();
505                        if pending.len() < max_backlog {
506                            pending.push_back(child);
507                        } else {
508                            // Must drop lock before sending RST (child borrows network)
509                            drop(pending);
510                            // Backlog full - send RST to reject connection
511                            let rst_port = self.local_port.load(Ordering::SeqCst);
512                            let rst_seq = self.send_seq.load(Ordering::SeqCst);
513                            let mut rst_header = TcpHeader::new(rst_port, header.src_port);
514                            rst_header.seq_number = rst_seq;
515                            rst_header.ack_number = child.recv_ack.load(Ordering::SeqCst);
516                            rst_header.set_flags(tcp_flags::RST);
517                            child.send_segment(src_ip, rst_header, &[], false, false);
518                            return;
519                        }
520                    }
521
522                    // Wake up any blocking accept() calls (lock released above)
523                    if let Some(waker) = self.accept_waker.lock().as_ref() {
524                        waker.wake_one();
525                    }
526                }
527            }
528            TcpState::SynSent => {
529                if header.flags() & (tcp_flags::SYN | tcp_flags::ACK)
530                    == (tcp_flags::SYN | tcp_flags::ACK)
531                {
532                    // Received SYN-ACK, move to ESTABLISHED
533                    self.handle_syn_ack_received(src_ip, header);
534                } else if header.flags() & tcp_flags::RST != 0 {
535                    // Received RST, abort connection
536                    self.handle_rst();
537                }
538            }
539            TcpState::SynReceived => {
540                if header.flags() & tcp_flags::RST != 0 {
541                    self.handle_rst();
542                    return;
543                }
544
545                if header.flags() & tcp_flags::ACK != 0 {
546                    let expected_ack = self.send_seq.load(Ordering::SeqCst);
547                    if header.ack_number == expected_ack {
548                        self.update_send_window(header.ack_number);
549                        self.set_state(TcpState::Established);
550
551                        if !data.is_empty() {
552                            self.handle_data_segment(src_ip, header, data);
553                        } else if header.flags() & tcp_flags::FIN != 0 {
554                            self.handle_fin(src_ip, header);
555                        }
556                    }
557                }
558            }
559            TcpState::Established => {
560                if data.is_empty() {
561                    self.handle_control_segment(src_ip, header);
562                } else {
563                    self.handle_data_segment(src_ip, header, data);
564                }
565            }
566            _ => {}
567        }
568    }
569
570    /// Handle incoming SYN (SYN-RECEIVED state)
571    fn handle_syn_received(&self, src_ip: Ipv4Address, header: TcpHeader) {
572        // Store remote address
573        *self.remote_ip.lock() = Some(src_ip);
574        self.remote_port.store(header.src_port, Ordering::SeqCst);
575
576        // Track peer sequence and our initial sequence
577        let initial_seq = 1000;
578        let next_recv = header.seq_number.wrapping_add(1);
579        self.send_seq.store(initial_seq, Ordering::SeqCst);
580        self.recv_seq.store(next_recv, Ordering::SeqCst);
581        self.recv_ack.store(next_recv, Ordering::SeqCst);
582
583        // Send SYN-ACK
584        let local_port = self.local_port.load(Ordering::SeqCst);
585        let mut syn_ack = TcpHeader::new(local_port, header.src_port);
586        syn_ack.seq_number = initial_seq;
587        syn_ack.ack_number = next_recv;
588        syn_ack.set_flags(tcp_flags::SYN | tcp_flags::ACK);
589        self.send_segment(src_ip, syn_ack, &[], false, false);
590        self.send_seq.fetch_add(1, Ordering::SeqCst);
591        self.set_state(TcpState::SynReceived);
592    }
593
594    /// Handle received SYN-ACK (move to ESTABLISHED)
595    fn handle_syn_ack_received(&self, src_ip: Ipv4Address, header: TcpHeader) {
596        *self.remote_ip.lock() = Some(src_ip);
597        self.remote_port.store(header.src_port, Ordering::SeqCst);
598
599        let next_recv = header.seq_number.wrapping_add(1);
600        self.recv_seq.store(next_recv, Ordering::SeqCst);
601        self.recv_ack.store(next_recv, Ordering::SeqCst);
602
603        // Advance our sequence number past the SYN we sent
604        let acked = header.ack_number;
605        self.send_seq.store(acked, Ordering::SeqCst);
606        self.send_unacked.store(acked, Ordering::SeqCst);
607
608        self.send_ack(src_ip, header.src_port, next_recv);
609
610        self.set_state(TcpState::Established);
611    }
612
613    /// Handle RST (Reset) - properly cleanup connection
614    fn handle_rst(&self) {
615        // Cancel retransmission timers
616        self.cancel_retrans_timer();
617
618        // Clear send buffer
619        self.send_buffer.lock().clear();
620
621        // Clear receive buffer
622        self.recv_buffer.lock().clear();
623
624        // Clear unacked segments
625        self.unacked_segments.lock().clear();
626
627        // Clear out-of-order segments
628        self.out_of_order.lock().clear();
629
630        // Reset sequence numbers
631        self.send_seq.store(0, Ordering::SeqCst);
632        self.send_unacked.store(0, Ordering::SeqCst);
633        self.recv_seq.store(0, Ordering::SeqCst);
634        self.recv_ack.store(0, Ordering::SeqCst);
635
636        // Reset window sizes
637        self.send_window.store(65535, Ordering::SeqCst);
638        self.recv_window.store(65535, Ordering::SeqCst);
639
640        // Reset RTO state
641        self.srtt.store(0, Ordering::SeqCst);
642        self.rttvar.store(0, Ordering::SeqCst);
643        self.rto.store(100, Ordering::SeqCst); // Reset to initial value
644        self.retrans_count.store(0, Ordering::SeqCst);
645        self.timing_rtt.store(0, Ordering::SeqCst);
646
647        // Clear addresses
648        *self.local_ip.lock() = None;
649        *self.remote_ip.lock() = None;
650        self.local_port.store(0, Ordering::SeqCst);
651        self.remote_port.store(0, Ordering::SeqCst);
652
653        // Set state to Closed
654        self.set_state(TcpState::Closed);
655    }
656
657    /// Handle control segment (ACK, FIN, RST)
658    fn handle_control_segment(&self, src_ip: Ipv4Address, header: TcpHeader) {
659        if header.flags() & tcp_flags::RST != 0 {
660            self.handle_rst();
661            return;
662        }
663
664        if header.flags() & tcp_flags::FIN != 0 {
665            self.handle_fin(src_ip, header);
666        }
667
668        if header.flags() & tcp_flags::ACK != 0 {
669            self.update_send_window(header.ack_number);
670            self.stop_rtt_measurement(header.ack_number);
671            self.remove_acked_segments(header.ack_number);
672        }
673    }
674
675    /// Handle data segment
676    fn handle_data_segment(&self, src_ip: Ipv4Address, header: TcpHeader, data: &[u8]) {
677        if header.flags() & tcp_flags::RST != 0 {
678            self.handle_rst();
679            return;
680        }
681
682        // Check sequence number
683        let expected_seq = self.recv_seq.load(Ordering::SeqCst);
684        let segment_seq = header.seq_number;
685        let segment_end = segment_seq.wrapping_add(data.len() as u32);
686
687        // Old segment (duplicate) - send ACK
688        if segment_end <= expected_seq {
689            self.send_ack(src_ip, header.src_port, expected_seq);
690            return;
691        }
692
693        // Out-of-order segment - buffer it
694        if segment_seq > expected_seq {
695            if !data.is_empty() {
696                let mut out_of_order = self.out_of_order.lock();
697                // Check if segment already buffered
698                if !out_of_order.contains_key(&segment_seq) {
699                    // Check out-of-order buffer limit
700                    if out_of_order.len() < 128 {
701                        let ooo_seg = OutOfOrderSegment {
702                            seq: segment_seq,
703                            data: data.to_vec(),
704                        };
705                        out_of_order.insert(segment_seq, ooo_seg);
706                    }
707                }
708                drop(out_of_order);
709
710                // Send ACK for expected sequence
711                self.send_ack(src_ip, header.src_port, expected_seq);
712            }
713
714            // Process ACK if present
715            if header.flags() & tcp_flags::ACK != 0 {
716                self.update_send_window(header.ack_number);
717                self.stop_rtt_measurement(header.ack_number);
718                self.remove_acked_segments(header.ack_number);
719            }
720            return;
721        }
722
723        // In-order segment (segment_seq == expected_seq)
724        if !data.is_empty() {
725            let mut recv_buf = self.recv_buffer.lock();
726
727            // Check receive buffer limit - drop data if full (should update window to 0)
728            if recv_buf.len() + data.len() > MAX_RECV_BUFFER_SIZE {
729                // Buffer full - send ACK with window=0 to stop sender
730                drop(recv_buf);
731                self.send_ack(src_ip, header.src_port, expected_seq);
732                return;
733            }
734
735            recv_buf.extend(data);
736            let mut next_seq = expected_seq.wrapping_add(data.len() as u32);
737
738            // Check if we can reassemble from out-of-order buffer
739            let mut out_of_order = self.out_of_order.lock();
740            loop {
741                // Remove and process next consecutive segment from out-of-order buffer
742                if let Some((_seq, ooo_seg)) = out_of_order.first_key_value() {
743                    if *_seq == next_seq {
744                        let seq = *_seq;
745                        let ooo_seg = out_of_order.remove(&seq).unwrap();
746
747                        // Check buffer limit before adding out-of-order data
748                        if recv_buf.len() + ooo_seg.data.len() <= MAX_RECV_BUFFER_SIZE {
749                            recv_buf.extend(&ooo_seg.data);
750                            next_seq = next_seq.wrapping_add(ooo_seg.data.len() as u32);
751                        } else {
752                            // Buffer full, stop reassembly
753                            break;
754                        }
755                    } else {
756                        // Gap found, stop reassembly
757                        break;
758                    }
759                } else {
760                    // No more out-of-order segments
761                    break;
762                }
763            }
764            drop(out_of_order);
765
766            // Update receive window based on remaining buffer space
767            let available = MAX_RECV_BUFFER_SIZE.saturating_sub(recv_buf.len());
768            self.recv_window
769                .store(available.min(65535) as u16, Ordering::SeqCst);
770
771            self.recv_seq.store(next_seq, Ordering::SeqCst);
772            self.recv_ack.store(next_seq, Ordering::SeqCst);
773            drop(recv_buf);
774
775            // Wake up any blocking recv() calls
776            if let Some(waker) = self.recv_waker.lock().as_ref() {
777                waker.wake_one();
778            }
779
780            // Send ACK for the new next_seq (may acknowledge multiple segments)
781            self.send_ack(src_ip, header.src_port, next_seq);
782
783            // Update received bytes
784            self.bytes_received
785                .fetch_add(data.len() as u64, Ordering::SeqCst);
786        }
787
788        // Process ACK if present
789        if header.flags() & tcp_flags::ACK != 0 {
790            self.update_send_window(header.ack_number);
791            self.stop_rtt_measurement(header.ack_number);
792            self.remove_acked_segments(header.ack_number);
793        }
794    }
795
796    /// Handle FIN segment
797    fn handle_fin(&self, src_ip: Ipv4Address, header: TcpHeader) {
798        let current_state = self.get_state();
799        let ack_seq = header.seq_number.wrapping_add(1);
800        self.recv_seq.store(ack_seq, Ordering::SeqCst);
801        self.recv_ack.store(ack_seq, Ordering::SeqCst);
802        match current_state {
803            TcpState::Established => {
804                self.send_ack(src_ip, header.src_port, ack_seq);
805                self.set_state(TcpState::CloseWait);
806                if let Some(waker) = self.recv_waker.lock().as_ref() {
807                    waker.wake_all();
808                }
809            }
810            TcpState::FinWait1 => {
811                if header.flags() & (tcp_flags::FIN | tcp_flags::ACK)
812                    == (tcp_flags::FIN | tcp_flags::ACK)
813                {
814                    self.send_ack(src_ip, header.src_port, ack_seq);
815                    self.set_state(TcpState::TimeWait);
816                }
817            }
818            _ => {}
819        }
820    }
821
822    /// Update send window based on acknowledgment
823    fn update_send_window(&self, ack_number: u32) {
824        // Track the latest acknowledged sequence number
825        self.send_unacked.store(ack_number, Ordering::SeqCst);
826
827        // Fast Retransmit - duplicate ACK detection
828        let last_ack = self.last_ack_seq.load(Ordering::SeqCst);
829        if ack_number == last_ack {
830            // Duplicate ACK - increment counter
831            let count = self.dup_ack_count.fetch_add(1, Ordering::SeqCst);
832
833            // If we've received 3 duplicate ACKs, trigger fast retransmit
834            if count >= 2 {
835                self.fast_retransmit();
836                self.dup_ack_count.store(0, Ordering::SeqCst);
837            }
838        } else {
839            // New ACK - reset duplicate counter
840            self.last_ack_seq.store(ack_number, Ordering::SeqCst);
841            self.dup_ack_count.store(0, Ordering::SeqCst);
842        }
843
844        // Wake up any blocking send() calls (buffer may have space now)
845        if let Some(waker) = self.send_waker.lock().as_ref() {
846            waker.wake_one();
847        }
848    }
849
850    /// Fast Retransmit - immediately retransmit unacknowledged segments
851    fn fast_retransmit(&self) {
852        let mut unacked = self.unacked_segments.lock();
853        if let Some(first_seg) = unacked.front() {
854            // Retransmit the oldest unacknowledged segment
855            if let Some(dest_ip) = self.remote_ip.lock().clone() {
856                let dest_port = self.remote_port.load(Ordering::SeqCst);
857                let local_port = self.local_port.load(Ordering::SeqCst);
858
859                let mut header = TcpHeader::new(local_port, dest_port);
860                header.seq_number = first_seg.seq;
861                header.ack_number = self.recv_ack.load(Ordering::SeqCst);
862                header.set_flags(first_seg.flags);
863
864                // Retransmit (don't update sequence number, mark as retransmission)
865                self.send_segment(dest_ip, header, &first_seg.data, false, true);
866
867                // Update transmission count
868                if let Some(seg) = unacked.front_mut() {
869                    seg.tx_count += 1;
870                    seg.last_tx_time = crate::timer::get_tick();
871                }
872
873                // Reset RTO for next retransmission
874                self.retrans_count.store(1, Ordering::SeqCst);
875            }
876        }
877    }
878
879    /// Send SYN packet
880    fn send_syn(&self, dest_ip: Ipv4Address, dest_port: u16) {
881        let local_port = self.local_port.load(Ordering::SeqCst);
882
883        let initial_seq = 1000;
884        self.send_seq.store(initial_seq, Ordering::SeqCst);
885
886        let mut header = TcpHeader::new(local_port, dest_port);
887        header.seq_number = initial_seq;
888        header.set_flags(tcp_flags::SYN);
889
890        self.send_segment(dest_ip, header, &[], false, false);
891        self.send_seq.fetch_add(1, Ordering::SeqCst);
892        self.set_state(TcpState::SynSent);
893    }
894
895    /// Send SYN-ACK packet
896    fn send_syn_ack(&self, dest_ip: Ipv4Address, dest_port: u16, their_seq: u32, ack_seq: u32) {
897        let local_port = self.local_port.load(Ordering::SeqCst);
898
899        let mut header = TcpHeader::new(local_port, dest_port);
900        header.seq_number = their_seq;
901        header.ack_number = ack_seq;
902        header.set_flags(tcp_flags::SYN | tcp_flags::ACK);
903
904        self.send_segment(dest_ip, header, &[], false, false);
905        self.set_state(TcpState::SynReceived);
906    }
907
908    /// Send ACK packet
909    fn send_ack(&self, dest_ip: Ipv4Address, dest_port: u16, ack_seq: u32) {
910        let local_port = self.local_port.load(Ordering::SeqCst);
911        let send_seq = self.send_seq.load(Ordering::SeqCst);
912
913        let mut header = TcpHeader::new(local_port, dest_port);
914        header.seq_number = send_seq;
915        header.ack_number = ack_seq;
916        header.set_flags(tcp_flags::ACK);
917
918        self.send_segment(dest_ip, header, &[], false, false);
919    }
920
921    /// Send FIN packet
922    fn send_fin(&self) {
923        let dest_ip = self.remote_ip.lock().clone().unwrap();
924        let dest_port = self.remote_port.load(Ordering::SeqCst);
925        let local_port = self.local_port.load(Ordering::SeqCst);
926
927        let send_seq = self.send_seq.load(Ordering::SeqCst);
928
929        let mut header = TcpHeader::new(local_port, dest_port);
930        header.seq_number = send_seq;
931        header.set_flags(tcp_flags::FIN);
932
933        self.send_segment(dest_ip, header, &[], true, false);
934        self.set_state(TcpState::FinWait1);
935    }
936
937    /// Send FIN-ACK packet
938    fn send_fin_ack(&self) {
939        let dest_ip = self.remote_ip.lock().clone().unwrap();
940        let dest_port = self.remote_port.load(Ordering::SeqCst);
941        let local_port = self.local_port.load(Ordering::SeqCst);
942        let send_seq = self.send_seq.load(Ordering::SeqCst);
943        let recv_seq = self.recv_seq.load(Ordering::SeqCst);
944
945        let mut header = TcpHeader::new(local_port, dest_port);
946        header.seq_number = send_seq;
947        header.ack_number = recv_seq;
948        header.set_flags(tcp_flags::FIN | tcp_flags::ACK);
949
950        self.send_segment(dest_ip, header, &[], true, false);
951    }
952
953    /// Send TCP segment through IP layer
954    fn send_segment(
955        &self,
956        dest_ip: Ipv4Address,
957        mut header: TcpHeader,
958        data: &[u8],
959        update_seq: bool,
960        is_retransmit: bool,
961    ) {
962        self.ensure_local_ip();
963        let local_ip = self
964            .local_ip
965            .lock()
966            .clone()
967            .unwrap_or(Ipv4Address::new(0, 0, 0, 0));
968
969        let total_len = header.data_offset() + data.len();
970        header.window_size = self.recv_window.load(Ordering::SeqCst);
971
972        // Calculate checksum
973        header.checksum = header.calculate_checksum(local_ip.0, dest_ip.0, data);
974
975        // Serialize header
976        let header_bytes = header.to_bytes();
977
978        // Combine header and data
979        let mut segment = Vec::with_capacity(total_len);
980        segment.extend_from_slice(&header_bytes);
981        segment.extend_from_slice(data);
982
983        // Create IP context
984        let mut ip_context = LayerContext::new();
985        ip_context.set("ip_dst", &dest_ip.0);
986        ip_context.set("ip_protocol", &[6]); // TCP protocol
987
988        // Send through IP layer
989        if let Some(ip_layer) = get_network_manager().get_layer("ip") {
990            if let Ok(()) = ip_layer.send(&segment, &ip_context, &[]) {
991                self.bytes_sent
992                    .fetch_add(segment.len() as u64, Ordering::SeqCst);
993
994                if update_seq {
995                    let mut advance = data.len() as u32;
996                    let flags = header.flags();
997                    if (flags & tcp_flags::SYN) != 0 {
998                        advance = advance.wrapping_add(1);
999                    }
1000                    if (flags & tcp_flags::FIN) != 0 {
1001                        advance = advance.wrapping_add(1);
1002                    }
1003                    if advance != 0 {
1004                        self.send_seq.fetch_add(advance, Ordering::SeqCst);
1005                    }
1006                }
1007
1008                // Track segment for retransmission (only for new transmissions with data or SYN/FIN)
1009                if !is_retransmit {
1010                    let flags = header.flags();
1011                    let has_data = !data.is_empty();
1012                    let is_syn = (flags & tcp_flags::SYN) != 0;
1013                    let is_fin = (flags & tcp_flags::FIN) != 0;
1014
1015                    if has_data || is_syn || is_fin {
1016                        let seq = header.seq_number;
1017                        self.add_unacked_segment(seq, data.to_vec(), flags);
1018                    }
1019                }
1020            }
1021        }
1022    }
1023
1024    /// Send data through socket
1025    pub fn send_data(&self, data: &[u8]) -> Result<usize, SocketError> {
1026        if self.get_state() != TcpState::Established {
1027            return Err(SocketError::NotConnected);
1028        }
1029
1030        let dest_ip = self
1031            .remote_ip
1032            .lock()
1033            .clone()
1034            .ok_or(SocketError::NotConnected)?;
1035        let dest_port = self.remote_port.load(Ordering::SeqCst);
1036
1037        // Check buffer size limit
1038        let mut send_buf = self.send_buffer.lock();
1039        if send_buf.len() + data.len() > MAX_SEND_BUFFER_SIZE {
1040            return Err(SocketError::WouldBlock);
1041        }
1042        send_buf.extend(data);
1043        drop(send_buf);
1044
1045        // Create TCP header
1046        let local_port = self.local_port.load(Ordering::SeqCst);
1047        let local_ip = self
1048            .local_ip
1049            .lock()
1050            .clone()
1051            .unwrap_or(Ipv4Address::new(0, 0, 0, 0));
1052
1053        let send_seq = self.send_seq.load(Ordering::SeqCst);
1054
1055        let mut header = TcpHeader::new(local_port, dest_port);
1056        header.seq_number = send_seq;
1057        header.ack_number = self.recv_ack.load(Ordering::SeqCst);
1058        header.set_flags(tcp_flags::ACK | tcp_flags::PSH);
1059
1060        self.send_segment(dest_ip, header, data, true, false);
1061
1062        Ok(data.len())
1063    }
1064
1065    /// Blocking send - waits for buffer space
1066    pub fn send_blocking(
1067        &self,
1068        data: &[u8],
1069        task_id: usize,
1070        trapframe: &mut crate::arch::Trapframe,
1071    ) -> Result<usize, SocketError> {
1072        if self.get_state() != TcpState::Established {
1073            return Err(SocketError::NotConnected);
1074        }
1075
1076        let dest_ip = self
1077            .remote_ip
1078            .lock()
1079            .clone()
1080            .ok_or(SocketError::NotConnected)?;
1081        let dest_port = self.remote_port.load(Ordering::SeqCst);
1082        let local_port = self.local_port.load(Ordering::SeqCst);
1083        let local_ip = self
1084            .local_ip
1085            .lock()
1086            .clone()
1087            .unwrap_or(Ipv4Address::new(0, 0, 0, 0));
1088
1089        let nonblocking = !self.blocking_mode.load(Ordering::SeqCst);
1090
1091        loop {
1092            {
1093                let mut send_buf = self.send_buffer.lock();
1094                if send_buf.len() + data.len() <= MAX_SEND_BUFFER_SIZE {
1095                    send_buf.extend(data);
1096                    drop(send_buf);
1097
1098                    let send_seq = self.send_seq.load(Ordering::SeqCst);
1099
1100                    let mut header = TcpHeader::new(local_port, dest_port);
1101                    header.seq_number = send_seq;
1102                    header.ack_number = self.recv_ack.load(Ordering::SeqCst);
1103                    header.set_flags(tcp_flags::ACK | tcp_flags::PSH);
1104
1105                    self.send_segment(dest_ip, header, data, true, false);
1106
1107                    return Ok(data.len());
1108                }
1109            }
1110
1111            if nonblocking {
1112                return Err(SocketError::WouldBlock);
1113            }
1114
1115            {
1116                let waker = {
1117                    let mut waker_lock = self.send_waker.lock();
1118                    waker_lock
1119                        .get_or_insert_with(|| {
1120                            Arc::new(crate::sync::Waker::new_interruptible("tcp_send"))
1121                        })
1122                        .clone()
1123                };
1124                waker.wait(task_id, trapframe);
1125            }
1126        }
1127    }
1128
1129    /// Receive data from socket
1130    pub fn recv_data(&self, buffer: &mut [u8]) -> Result<usize, SocketError> {
1131        match self.get_state() {
1132            TcpState::Established | TcpState::CloseWait => {}
1133            _ => return Err(SocketError::NotConnected),
1134        }
1135
1136        let mut recv_buf = self.recv_buffer.lock();
1137        let len = buffer.len().min(recv_buf.len());
1138
1139        for i in 0..len {
1140            buffer[i] = recv_buf.pop_front().unwrap();
1141        }
1142
1143        // Update receive window after reading data
1144        let available = MAX_RECV_BUFFER_SIZE.saturating_sub(recv_buf.len());
1145        self.recv_window
1146            .store(available.min(65535) as u16, Ordering::SeqCst);
1147
1148        Ok(len)
1149    }
1150
1151    /// Blocking receive - waits for data to be available
1152    pub fn recv_blocking(
1153        &self,
1154        buffer: &mut [u8],
1155        task_id: usize,
1156        trapframe: &mut crate::arch::Trapframe,
1157    ) -> Result<usize, SocketError> {
1158        match self.get_state() {
1159            TcpState::Established | TcpState::CloseWait => {}
1160            _ => return Err(SocketError::NotConnected),
1161        }
1162
1163        let nonblocking = !self.blocking_mode.load(Ordering::SeqCst);
1164
1165        loop {
1166            {
1167                let mut recv_buf = self.recv_buffer.lock();
1168                let len = buffer.len().min(recv_buf.len());
1169
1170                if len > 0 {
1171                    for i in 0..len {
1172                        buffer[i] = recv_buf.pop_front().unwrap();
1173                    }
1174
1175                    // Update receive window after reading data
1176                    let available = MAX_RECV_BUFFER_SIZE.saturating_sub(recv_buf.len());
1177                    self.recv_window
1178                        .store(available.min(65535) as u16, Ordering::SeqCst);
1179
1180                    return Ok(len);
1181                }
1182
1183                // Check if connection is closed
1184                let state = self.get_state();
1185                if state == TcpState::Closed
1186                    || state == TcpState::TimeWait
1187                    || state == TcpState::CloseWait
1188                {
1189                    return Ok(0);
1190                }
1191            }
1192
1193            if nonblocking {
1194                return Err(SocketError::WouldBlock);
1195            }
1196
1197            {
1198                let waker = {
1199                    let mut waker_lock = self.recv_waker.lock();
1200                    waker_lock
1201                        .get_or_insert_with(|| {
1202                            Arc::new(crate::sync::Waker::new_interruptible("tcp_recv"))
1203                        })
1204                        .clone()
1205                };
1206                let state = self.get_state();
1207                if state == TcpState::Closed
1208                    || state == TcpState::TimeWait
1209                    || state == TcpState::CloseWait
1210                {
1211                    return Ok(0);
1212                }
1213                waker.wait(task_id, trapframe);
1214            }
1215        }
1216    }
1217
1218    // ===================================================================
1219    // RTO (Retransmission Timeout) - RFC 6298
1220    // ===================================================================
1221
1222    /// Update RTO based on RTT measurement (Jacobson/Karels algorithm)
1223    /// Uses fixed-point arithmetic for better precision in no_std
1224    fn update_rto(&self, rtt_ticks: u32) {
1225        // RFC 6298: RTO calculation
1226        // SRTT = (1 - alpha) * SRTT + alpha * RTT
1227        // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - RTT|
1228        // RTO = SRTT + max(G, K * RTTVAR)
1229        // where alpha = 1/8, beta = 1/4, K = 4, G = clock granularity
1230
1231        const ALPHA_SHIFT: u32 = 3; // alpha = 1/8
1232        const BETA_SHIFT: u32 = 2; // beta = 1/4
1233        const K: u32 = 4; // multiplier for RTTVAR
1234
1235        let srtt = self.srtt.load(Ordering::SeqCst);
1236        let rttvar = self.rttvar.load(Ordering::SeqCst);
1237
1238        if srtt == 0 {
1239            // First RTT measurement
1240            // SRTT = RTT
1241            // RTTVAR = RTT / 2
1242            self.srtt.store(rtt_ticks << 3, Ordering::SeqCst); // 8 * RTT
1243            self.rttvar.store((rtt_ticks << 2) >> 1, Ordering::SeqCst); // 4 * RTT / 2
1244        } else {
1245            // Subsequent measurements
1246            // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - RTT|
1247            // SRTT = (1 - alpha) * SRTT + alpha * RTT
1248            let srtt_val = srtt >> 3; // Divide by 8
1249            let diff = if srtt_val > rtt_ticks {
1250                srtt_val - rtt_ticks
1251            } else {
1252                rtt_ticks - srtt_val
1253            };
1254
1255            // RTTVAR = (3/4) * RTTVAR + (1/4) * |diff|
1256            let new_rttvar = ((rttvar * 3) >> BETA_SHIFT) + ((diff << 2) >> BETA_SHIFT);
1257            self.rttvar.store(new_rttvar, Ordering::SeqCst);
1258
1259            // SRTT = (7/8) * SRTT + (1/8) * RTT
1260            let new_srtt = ((srtt * 7) >> ALPHA_SHIFT) + (rtt_ticks << (3 - ALPHA_SHIFT));
1261            self.srtt.store(new_srtt, Ordering::SeqCst);
1262        }
1263
1264        // RTO = SRTT + max(G, K * RTTVAR)
1265        // G = 1 tick (10ms), K = 4
1266        let srtt_val = self.srtt.load(Ordering::SeqCst) >> 3;
1267        let rttvar_val = self.rttvar.load(Ordering::SeqCst) >> 2;
1268        let mut rto = srtt_val + (K * rttvar_val).max(1);
1269
1270        // Clamp RTO to bounds
1271        // Min: 1 tick (10ms), Max: 12000 ticks (120 seconds)
1272        rto = rto.max(1).min(12000);
1273
1274        self.rto.store(rto, Ordering::SeqCst);
1275    }
1276
1277    /// Get current RTO in milliseconds
1278    fn get_rto_ms(&self) -> u32 {
1279        // Convert ticks to milliseconds (10ms per tick)
1280        self.rto.load(Ordering::SeqCst) * 10
1281    }
1282
1283    /// Get current RTO in ticks
1284    fn get_rto_ticks(&self) -> u32 {
1285        self.rto.load(Ordering::SeqCst)
1286    }
1287
1288    /// Start RTT measurement for a sequence number
1289    fn start_rtt_measurement(&self, seq: u32) {
1290        // Only start timing if not already timing
1291        if self.timing_rtt.load(Ordering::SeqCst) == 0 {
1292            self.timed_seq.store(seq, Ordering::SeqCst);
1293            self.last_send_time
1294                .store(crate::timer::get_tick(), Ordering::SeqCst);
1295            self.timing_rtt.store(1, Ordering::SeqCst);
1296        }
1297    }
1298
1299    /// Stop RTT measurement when ACK is received
1300    fn stop_rtt_measurement(&self, ack_seq: u32) {
1301        // Check if we're timing and if this ACK covers the timed sequence
1302        if self.timing_rtt.load(Ordering::SeqCst) != 0 {
1303            let timed_seq = self.timed_seq.load(Ordering::SeqCst);
1304            // Check if ACK acknowledges the segment we were timing
1305            // Note: Sequence number comparison needs to handle wraparound
1306            if is_seq_acknowledged(timed_seq, ack_seq) {
1307                let send_time = self.last_send_time.load(Ordering::SeqCst);
1308                let now = crate::timer::get_tick();
1309                if now > send_time {
1310                    let rtt = (now - send_time) as u32;
1311                    self.update_rto(rtt);
1312                }
1313                self.timing_rtt.store(0, Ordering::SeqCst);
1314                // Reset retransmission count on successful ACK
1315                self.retrans_count.store(0, Ordering::SeqCst);
1316            }
1317        }
1318    }
1319
1320    /// Exponential backoff for retransmission
1321    fn backoff_rto(&self) {
1322        let count = self.retrans_count.load(Ordering::SeqCst);
1323        if count < 6 {
1324            // Double RTO (exponential backoff), max 64x
1325            let backoff = 1u32 << count.min(6);
1326            let base_rto = self.rto.load(Ordering::SeqCst);
1327            let new_rto = (base_rto * backoff).min(12000); // Max 120 seconds
1328            self.rto.store(new_rto, Ordering::SeqCst);
1329            self.retrans_count.store(count + 1, Ordering::SeqCst);
1330        }
1331    }
1332
1333    /// Check if maximum retransmissions exceeded
1334    fn max_retransmissions_exceeded(&self) -> bool {
1335        self.retrans_count.load(Ordering::SeqCst) >= 12 // Max 12 retransmissions
1336    }
1337
1338    /// Handle retransmission timeout
1339    fn handle_retrans_timeout(&self, seq: u32) {
1340        // Check if socket is still in a valid state for retransmission
1341        let state = self.get_state();
1342        match state {
1343            TcpState::Closed | TcpState::Listen | TcpState::TimeWait => return,
1344            _ => {}
1345        }
1346
1347        // Find the segment to retransmit
1348        let mut unacked = self.unacked_segments.lock();
1349        if let Some(pos) = unacked.iter().position(|seg| seg.seq == seq) {
1350            if let Some(mut seg) = unacked.get(pos).cloned() {
1351                // Check max retransmissions
1352                if seg.tx_count >= 12 {
1353                    // Too many retransmissions, close connection
1354                    self.set_state(TcpState::Closed);
1355                    return;
1356                }
1357
1358                // Exponential backoff
1359                self.backoff_rto();
1360
1361                // Retransmit the segment
1362                if let Some(dest_ip) = self.remote_ip.lock().clone() {
1363                    let dest_port = self.remote_port.load(Ordering::SeqCst);
1364                    let local_port = self.local_port.load(Ordering::SeqCst);
1365
1366                    let mut header = TcpHeader::new(local_port, dest_port);
1367                    header.seq_number = seg.seq;
1368                    header.ack_number = self.recv_ack.load(Ordering::SeqCst);
1369                    header.set_flags(seg.flags);
1370
1371                    // Retransmit (don't update sequence number, mark as retransmission)
1372                    self.send_segment(dest_ip, header, &seg.data, false, true);
1373
1374                    // Update segment info
1375                    seg.tx_count += 1;
1376                    seg.last_tx_time = crate::timer::get_tick();
1377
1378                    // Update in queue
1379                    if let Some(existing) = unacked.get_mut(pos) {
1380                        *existing = seg;
1381                    }
1382
1383                    // Schedule next retransmission timer
1384                    self.schedule_retrans_timer(seq);
1385                }
1386            }
1387        }
1388    }
1389
1390    /// Schedule retransmission timer for a segment
1391    fn schedule_retrans_timer(&self, seq: u32) {
1392        let rto_ticks = self.get_rto_ticks();
1393        let expires = crate::timer::get_tick() + rto_ticks as u64;
1394
1395        let timer: Arc<dyn crate::timer::TimerHandler> = Arc::new(RetransTimer {
1396            socket: self.self_weak.clone(),
1397            seq,
1398        });
1399
1400        let timer_id = crate::timer::add_timer(expires, &timer, 0);
1401
1402        // Store timer ID
1403        *self.retrans_timer_id.lock() = Some(timer_id);
1404    }
1405
1406    /// Cancel retransmission timer
1407    fn cancel_retrans_timer(&self) {
1408        let timer_id = {
1409            let mut timer_lock = self.retrans_timer_id.lock();
1410            timer_lock.take()
1411        };
1412
1413        if let Some(timer_id) = timer_id {
1414            crate::timer::cancel_timer(timer_id);
1415        }
1416    }
1417
1418    /// Add segment to unacked list and schedule retransmission
1419    fn add_unacked_segment(&self, seq: u32, data: Vec<u8>, flags: u8) {
1420        // Check unacked segment limit to prevent memory exhaustion
1421        let mut unacked = self.unacked_segments.lock();
1422        if unacked.len() >= MAX_UNACKED_SEGMENTS {
1423            // Remove oldest segment if limit reached
1424            if let Some(old) = unacked.pop_front() {
1425                // Cancel its timer (timer will be skipped when it fires)
1426            }
1427        }
1428
1429        let segment = UnackedSegment {
1430            seq,
1431            data,
1432            flags,
1433            tx_count: 1,
1434            last_tx_time: crate::timer::get_tick(),
1435        };
1436
1437        unacked.push_back(segment);
1438        drop(unacked);
1439
1440        // Schedule retransmission timer
1441        self.schedule_retrans_timer(seq);
1442
1443        // Start RTT measurement if not already timing
1444        self.start_rtt_measurement(seq);
1445    }
1446
1447    /// Remove acknowledged segments from unacked list
1448    fn remove_acked_segments(&self, ack_seq: u32) {
1449        let mut unacked = self.unacked_segments.lock();
1450        // Remove all segments that are fully acknowledged
1451        unacked.retain(|seg| {
1452            let seg_end = seg.seq.wrapping_add(seg.data.len() as u32);
1453            // Keep if not fully acknowledged
1454            !is_seq_acknowledged(seg_end, ack_seq)
1455        });
1456
1457        // If all segments are acknowledged, cancel timer
1458        if unacked.is_empty() {
1459            drop(unacked);
1460            self.cancel_retrans_timer();
1461        }
1462    }
1463}
1464
/// Report whether acknowledgment number `ack` lies strictly past sequence
/// number `seq`.
///
/// The comparison is made modulo 2^32: `seq - ack` (wrapping) lands in the
/// upper half of the number space exactly when `seq` is behind `ack`.
/// Equal values, and values exactly 2^31 apart, yield `false`.
fn is_seq_acknowledged(seq: u32, ack: u32) -> bool {
    const HALF_SEQ_SPACE: u32 = 0x8000_0000;

    let wrapped_distance = seq.wrapping_sub(ack);
    wrapped_distance > HALF_SEQ_SPACE
}
1472
1473impl SocketObject for TcpSocket {
1474    fn socket_type(&self) -> crate::network::socket::SocketType {
1475        crate::network::socket::SocketType::Stream
1476    }
1477
1478    fn socket_domain(&self) -> crate::network::socket::SocketDomain {
1479        crate::network::socket::SocketDomain::Inet4
1480    }
1481
1482    fn socket_protocol(&self) -> crate::network::socket::SocketProtocol {
1483        crate::network::socket::SocketProtocol::Tcp
1484    }
1485
1486    fn as_any(&self) -> &dyn core::any::Any {
1487        self
1488    }
1489
1490    fn as_selectable(&self) -> Option<&dyn crate::object::capability::Selectable> {
1491        Some(self)
1492    }
1493
1494    fn as_control_ops(&self) -> Option<&dyn crate::object::capability::ControlOps> {
1495        Some(self)
1496    }
1497
1498    fn sendto(
1499        &self,
1500        data: &[u8],
1501        address: &SocketAddress,
1502        flags: u32,
1503    ) -> Result<usize, SocketError> {
1504        let _ = flags;
1505
1506        match address {
1507            SocketAddress::Inet(inet) => {
1508                let addr = Ipv4Address::from_bytes(inet.addr);
1509                let port = inet.port;
1510                // Update remote address
1511                *self.remote_ip.lock() = Some(addr);
1512                self.remote_port.store(port, Ordering::SeqCst);
1513                self.send_data(data)
1514            }
1515            _ => Err(SocketError::InvalidAddress),
1516        }
1517    }
1518
1519    fn recvfrom(
1520        &self,
1521        buffer: &mut [u8],
1522        flags: u32,
1523    ) -> Result<(usize, SocketAddress), SocketError> {
1524        let _ = flags;
1525
1526        let len = self.recv_data(buffer)?;
1527        let remote_ip = self
1528            .remote_ip
1529            .lock()
1530            .clone()
1531            .unwrap_or(Ipv4Address::new(0, 0, 0, 0));
1532        let addr = SocketAddress::Inet(Inet4SocketAddress::new(
1533            remote_ip.0,
1534            self.remote_port.load(Ordering::SeqCst),
1535        ));
1536
1537        Ok((len, addr))
1538    }
1539}
1540
1541impl crate::object::capability::ControlOps for TcpSocket {
1542    fn control(&self, command: u32, arg: usize) -> Result<i32, &'static str> {
1543        match command {
1544            crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK => {
1545                crate::object::capability::selectable::Selectable::set_nonblocking(self, arg != 0);
1546                Ok(0)
1547            }
1548            crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK => Ok(
1549                if crate::object::capability::selectable::Selectable::is_nonblocking(self) {
1550                    1
1551                } else {
1552                    0
1553                },
1554            ),
1555            _ => Err("Unsupported socket control command"),
1556        }
1557    }
1558
1559    fn supported_control_commands(&self) -> alloc::vec::Vec<(u32, &'static str)> {
1560        alloc::vec![
1561            (
1562                crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK,
1563                "Set non-blocking mode",
1564            ),
1565            (
1566                crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK,
1567                "Get non-blocking mode",
1568            ),
1569        ]
1570    }
1571}
1572
1573impl SocketControl for TcpSocket {
1574    fn bind(&self, address: &SocketAddress) -> Result<(), SocketError> {
1575        match address {
1576            SocketAddress::Inet(inet) => {
1577                if inet.port == 0 {
1578                    return Err(SocketError::InvalidAddress);
1579                }
1580
1581                self.register_local_port(inet.port)?;
1582                *self.local_ip.lock() = Some(Ipv4Address::from_bytes(inet.addr));
1583                self.local_port.store(inet.port, Ordering::SeqCst);
1584                Ok(())
1585            }
1586            _ => Err(SocketError::InvalidAddress),
1587        }
1588    }
1589
1590    fn listen(&self, backlog: usize) -> Result<(), SocketError> {
1591        if self.local_port.load(Ordering::SeqCst) == 0 {
1592            return Err(SocketError::InvalidOperation);
1593        }
1594
1595        // Set backlog limit (clamp to reasonable range)
1596        let max_backlog = backlog.max(1).min(128);
1597        self.max_backlog.store(max_backlog, Ordering::SeqCst);
1598
1599        self.set_state(TcpState::Listen);
1600        Ok(())
1601    }
1602
1603    fn connect(&self, address: &SocketAddress) -> Result<(), SocketError> {
1604        match address {
1605            SocketAddress::Inet(inet) => {
1606                let addr = Ipv4Address::from_bytes(inet.addr);
1607                let port = inet.port;
1608                *self.remote_ip.lock() = Some(addr);
1609                self.remote_port.store(port, Ordering::SeqCst);
1610
1611                let local_port = self.local_port.load(Ordering::SeqCst);
1612                if local_port == 0 {
1613                    let port = self.allocate_ephemeral_port();
1614                    self.register_local_port(port)?;
1615                    self.local_port.store(port, Ordering::SeqCst);
1616                }
1617
1618                self.ensure_local_ip();
1619
1620                // Start 3-way handshake
1621                self.send_syn(addr, port);
1622                Ok(())
1623            }
1624            _ => Err(SocketError::InvalidAddress),
1625        }
1626    }
1627
1628    fn accept(&self) -> Result<Arc<dyn SocketObject>, SocketError> {
1629        if self.get_state() != TcpState::Listen {
1630            return Err(SocketError::NotListening);
1631        }
1632
1633        let mut pending = self.pending_accept.lock();
1634        pending
1635            .pop_front()
1636            .map(|socket| socket as Arc<dyn SocketObject>)
1637            .ok_or(SocketError::WouldBlock)
1638    }
1639
1640    fn getpeername(&self) -> Result<SocketAddress, SocketError> {
1641        let ip = self
1642            .remote_ip
1643            .lock()
1644            .clone()
1645            .ok_or(SocketError::NotConnected)?;
1646        let port = self.remote_port.load(Ordering::SeqCst);
1647        Ok(SocketAddress::Inet(Inet4SocketAddress::new(ip.0, port)))
1648    }
1649
1650    fn getsockname(&self) -> Result<SocketAddress, SocketError> {
1651        let ip = self
1652            .local_ip
1653            .lock()
1654            .clone()
1655            .ok_or(SocketError::InvalidAddress)?;
1656        let port = self.local_port.load(Ordering::SeqCst);
1657        Ok(SocketAddress::Inet(Inet4SocketAddress::new(ip.0, port)))
1658    }
1659
1660    fn shutdown(&self, how: crate::network::socket::ShutdownHow) -> Result<(), SocketError> {
1661        match how {
1662            crate::network::socket::ShutdownHow::Write
1663            | crate::network::socket::ShutdownHow::Both => {
1664                self.send_fin();
1665            }
1666            _ => {}
1667        }
1668        Ok(())
1669    }
1670
1671    fn is_connected(&self) -> bool {
1672        self.get_state() == TcpState::Established
1673    }
1674
1675    fn state(&self) -> SocketState {
1676        match self.get_state() {
1677            TcpState::Closed => SocketState::Unconnected,
1678            TcpState::Listen => SocketState::Listening,
1679            TcpState::Established => SocketState::Connected,
1680            _ => SocketState::Unconnected,
1681        }
1682    }
1683}
1684
1685impl crate::ipc::StreamIpcOps for TcpSocket {
1686    fn is_connected(&self) -> bool {
1687        SocketControl::is_connected(self)
1688    }
1689
1690    fn peer_count(&self) -> usize {
1691        if SocketControl::is_connected(self) {
1692            1
1693        } else {
1694            0
1695        }
1696    }
1697
1698    fn description(&self) -> String {
1699        alloc::format!("TCP socket")
1700    }
1701}
1702
1703impl crate::object::capability::StreamOps for TcpSocket {
1704    fn read(&self, buffer: &mut [u8]) -> Result<usize, crate::object::capability::StreamError> {
1705        use crate::object::capability::selectable::Selectable;
1706
1707        if Selectable::is_nonblocking(self) {
1708            return self.recv_data(buffer).map_err(|err| match err {
1709                SocketError::WouldBlock => crate::object::capability::StreamError::WouldBlock,
1710                SocketError::NotConnected => crate::object::capability::StreamError::BrokenPipe,
1711                _ => crate::object::capability::StreamError::Other("tcp recv error".into()),
1712            });
1713        }
1714
1715        let task = match crate::task::mytask() {
1716            Some(task) => task,
1717            None => {
1718                return Err(crate::object::capability::StreamError::Other(
1719                    "tcp recv: no task".into(),
1720                ));
1721            }
1722        };
1723
1724        self.recv_blocking(buffer, task.get_id(), task.get_trapframe())
1725            .map_err(|err| match err {
1726                SocketError::WouldBlock => crate::object::capability::StreamError::WouldBlock,
1727                SocketError::NotConnected => crate::object::capability::StreamError::BrokenPipe,
1728                _ => crate::object::capability::StreamError::Other("tcp recv error".into()),
1729            })
1730    }
1731
1732    fn write(&self, data: &[u8]) -> Result<usize, crate::object::capability::StreamError> {
1733        use crate::object::capability::selectable::Selectable;
1734
1735        if Selectable::is_nonblocking(self) {
1736            return self.send_data(data).map_err(|_| {
1737                crate::object::capability::StreamError::Other("tcp send error".into())
1738            });
1739        }
1740
1741        let task = match crate::task::mytask() {
1742            Some(task) => task,
1743            None => {
1744                return Err(crate::object::capability::StreamError::Other(
1745                    "tcp send: no task".into(),
1746                ));
1747            }
1748        };
1749
1750        self.send_blocking(data, task.get_id(), task.get_trapframe())
1751            .map_err(|_| crate::object::capability::StreamError::Other("tcp send error".into()))
1752    }
1753}
1754
1755impl crate::object::capability::Selectable for TcpSocket {
1756    fn current_ready(
1757        &self,
1758        interest: crate::object::capability::selectable::ReadyInterest,
1759    ) -> crate::object::capability::selectable::ReadySet {
1760        let mut ready = crate::object::capability::selectable::ReadySet::none();
1761
1762        if interest.read {
1763            let recv_buf = self.recv_buffer.lock();
1764            let has_data = !recv_buf.is_empty();
1765            drop(recv_buf);
1766            let state = self.get_state();
1767            ready.read = has_data
1768                || state == TcpState::Closed
1769                || state == TcpState::TimeWait
1770                || state == TcpState::CloseWait;
1771        }
1772
1773        if interest.write {
1774            let send_buf = self.send_buffer.lock();
1775            ready.write = send_buf.len() < MAX_SEND_BUFFER_SIZE;
1776        }
1777
1778        ready
1779    }
1780
1781    fn wait_until_ready(
1782        &self,
1783        interest: crate::object::capability::selectable::ReadyInterest,
1784        trapframe: &mut crate::arch::Trapframe,
1785        timeout_ticks: Option<u64>,
1786    ) -> crate::object::capability::selectable::SelectWaitOutcome {
1787        let current = self.current_ready(interest);
1788        if (interest.read && current.read) || (interest.write && current.write) {
1789            return crate::object::capability::selectable::SelectWaitOutcome::Ready;
1790        }
1791
1792        let task_id = {
1793            use crate::arch::get_cpu;
1794            use crate::sched::scheduler::get_scheduler;
1795            let cpu_id = get_cpu().get_cpuid();
1796            get_scheduler().get_current_task_id(cpu_id).unwrap_or(0)
1797        };
1798
1799        let woke = if interest.read {
1800            let waker = {
1801                let mut waker_lock = self.recv_waker.lock();
1802                waker_lock
1803                    .get_or_insert_with(|| {
1804                        Arc::new(crate::sync::Waker::new_interruptible("tcp_recv"))
1805                    })
1806                    .clone()
1807            };
1808            waker.wait_with_timeout(task_id, trapframe, timeout_ticks)
1809        } else if interest.write {
1810            let waker = {
1811                let mut waker_lock = self.send_waker.lock();
1812                waker_lock
1813                    .get_or_insert_with(|| {
1814                        Arc::new(crate::sync::Waker::new_interruptible("tcp_send"))
1815                    })
1816                    .clone()
1817            };
1818            waker.wait_with_timeout(task_id, trapframe, timeout_ticks)
1819        } else {
1820            true
1821        };
1822
1823        if timeout_ticks.is_some() && !woke {
1824            let after = self.current_ready(interest);
1825            if (interest.read && !after.read) && (interest.write && !after.write) {
1826                return crate::object::capability::selectable::SelectWaitOutcome::TimedOut;
1827            }
1828        }
1829
1830        crate::object::capability::selectable::SelectWaitOutcome::Ready
1831    }
1832
1833    fn set_nonblocking(&self, enabled: bool) {
1834        self.blocking_mode.store(!enabled, Ordering::SeqCst);
1835    }
1836
1837    fn is_nonblocking(&self) -> bool {
1838        !self.blocking_mode.load(Ordering::SeqCst)
1839    }
1840}
1841
1842impl crate::object::capability::CloneOps for TcpSocket {
1843    fn custom_clone(&self) -> crate::object::KernelObject {
1844        crate::object::KernelObject::Socket(TcpSocket::new(self.tcp_layer.clone()))
1845    }
1846}
1847
1848impl Drop for TcpSocket {
1849    fn drop(&mut self) {
1850        // Cancel any pending retransmission timers
1851        self.cancel_retrans_timer();
1852
1853        // Send FIN if connection is still open
1854        let state = self.get_state();
1855        match state {
1856            TcpState::Established | TcpState::SynReceived | TcpState::SynSent => {
1857                // Send FIN to close connection gracefully
1858                let _ = self.remote_ip.lock().clone().map(|dest_ip| {
1859                    let dest_port = self.remote_port.load(Ordering::SeqCst);
1860                    let local_port = self.local_port.load(Ordering::SeqCst);
1861                    let send_seq = self.send_seq.load(Ordering::SeqCst);
1862
1863                    let mut header = TcpHeader::new(local_port, dest_port);
1864                    header.seq_number = send_seq;
1865                    header.set_flags(tcp_flags::FIN);
1866                    self.send_segment(dest_ip, header, &[], true, false);
1867                });
1868            }
1869            _ => {}
1870        }
1871
1872        // Unregister this socket from TcpLayer (not all sockets on this port)
1873        if let Some(layer) = self.tcp_layer.upgrade() {
1874            let port = self.local_port.load(Ordering::SeqCst);
1875            if port != 0 {
1876                layer.unregister_socket(port, &self.self_weak);
1877            }
1878        }
1879    }
1880}
1881
1882/// TCP layer
1883///
1884/// Manages TCP port bindings and routes packets to sockets.
1885pub struct TcpLayer {
1886    /// Port-to-socket mapping for receiving packets
1887    port_map: RwLock<BTreeMap<u16, Vec<Weak<TcpSocket>>>>,
1888    /// Statistics
1889    stats: RwLock<NetworkLayerStats>,
1890    self_weak: Weak<TcpLayer>,
1891}
1892
1893impl TcpLayer {
1894    /// Create a new TCP layer
1895    pub fn new() -> Arc<Self> {
1896        Arc::new_cyclic(|weak| Self {
1897            port_map: RwLock::new(BTreeMap::new()),
1898            stats: RwLock::new(NetworkLayerStats::default()),
1899            self_weak: weak.clone(),
1900        })
1901    }
1902
1903    /// Initialize and register the TCP layer with NetworkManager
1904    ///
1905    /// Registers with NetworkManager and registers itself with Ipv4Layer
1906    /// for protocol number 6 (TCP).
1907    ///
1908    /// # Panics
1909    ///
1910    /// Panics if Ipv4Layer is not registered (must be initialized first).
1911    pub fn init(network_manager: &crate::network::NetworkManager) {
1912        let layer = Self::new();
1913        network_manager.register_layer("tcp", layer.clone());
1914
1915        // Register with IPv4 layer for TCP packets (protocol 6)
1916        let ipv4 = network_manager
1917            .get_layer("ip")
1918            .expect("Ipv4Layer must be initialized before TcpLayer");
1919        ipv4.register_protocol(crate::network::ipv4::protocol::TCP as u16, layer);
1920    }
1921
1922    pub fn create_socket(&self) -> Arc<TcpSocket> {
1923        TcpSocket::new(self.self_weak.clone())
1924    }
1925
1926    /// Register a socket for a specific port
1927    pub fn register_port(&self, port: u16, socket: Weak<TcpSocket>) {
1928        let mut map = self.port_map.write();
1929        let entry = map.entry(port).or_default();
1930        if entry.iter().any(|existing| existing.ptr_eq(&socket)) {
1931            return;
1932        }
1933        if entry.iter().any(|existing| {
1934            existing
1935                .upgrade()
1936                .map(|sock| sock.get_state() == TcpState::Listen)
1937                .unwrap_or(false)
1938                && socket
1939                    .upgrade()
1940                    .map(|sock| sock.get_state() == TcpState::Listen)
1941                    .unwrap_or(false)
1942        }) {
1943            return;
1944        }
1945        entry.push(socket);
1946    }
1947
1948    /// Unregister a specific socket from a port
1949    ///
1950    /// Only removes the given socket from the port's socket list.
1951    /// The port entry itself is removed only when no sockets remain.
1952    pub fn unregister_socket(&self, port: u16, socket: &Weak<TcpSocket>) {
1953        let mut map = self.port_map.write();
1954        if let Some(sockets) = map.get_mut(&port) {
1955            sockets.retain(|existing| !existing.ptr_eq(socket));
1956            if sockets.is_empty() {
1957                map.remove(&port);
1958            }
1959        }
1960    }
1961
1962    /// Find socket for a destination port
1963    pub fn find_socket(
1964        &self,
1965        port: u16,
1966        src_ip: Ipv4Address,
1967        src_port: u16,
1968    ) -> Option<Arc<TcpSocket>> {
1969        let map = self.port_map.read();
1970        let sockets = map.get(&port)?;
1971        let mut listening = None;
1972        for weak in sockets {
1973            if let Some(socket) = weak.upgrade() {
1974                if socket.matches_peer(src_ip, src_port) {
1975                    return Some(socket);
1976                }
1977                if socket.get_state() == TcpState::Listen {
1978                    listening = Some(socket);
1979                }
1980            }
1981        }
1982        listening
1983    }
1984
1985    pub fn find_listening_socket(&self, port: u16) -> Option<Arc<TcpSocket>> {
1986        let map = self.port_map.read();
1987        let sockets = map.get(&port)?;
1988        for weak in sockets {
1989            if let Some(socket) = weak.upgrade() {
1990                if socket.get_state() == TcpState::Listen {
1991                    return Some(socket);
1992                }
1993            }
1994        }
1995        None
1996    }
1997
1998    /// Process incoming TCP segment
1999    pub fn receive_segment(&self, src_ip: Ipv4Address, header: TcpHeader, data: &[u8]) {
2000        let mut stats = self.stats.write();
2001        stats.packets_received += 1;
2002        stats.bytes_received += (header.data_offset() + data.len()) as u64;
2003
2004        let src_port = unsafe { core::ptr::addr_of!(header.src_port).read_unaligned() };
2005        let dst_port = unsafe { core::ptr::addr_of!(header.dst_port).read_unaligned() };
2006
2007        if let Some(socket) = self.find_socket(dst_port, src_ip, src_port) {
2008            socket.process_segment(src_ip, header, data);
2009        }
2010    }
2011}
2012
2013impl NetworkLayer for TcpLayer {
2014    fn register_protocol(&self, _proto_num: u16, _handler: Arc<dyn NetworkLayer>) {
2015        // TCP is typically a leaf protocol
2016    }
2017
2018    fn send(
2019        &self,
2020        _packet: &[u8],
2021        _context: &LayerContext,
2022        _next_layers: &[Arc<dyn NetworkLayer>],
2023    ) -> Result<(), SocketError> {
2024        // TCP send is handled through sockets
2025        Ok(())
2026    }
2027
2028    fn receive(&self, packet: &[u8], context: Option<&LayerContext>) -> Result<(), SocketError> {
2029        let mut src_ip = Ipv4Address::new(0, 0, 0, 0);
2030        let mut dst_ip = Ipv4Address::new(0, 0, 0, 0);
2031        if let Some(ctx) = context {
2032            if let Some(raw) = ctx.get("ip_src") {
2033                if raw.len() >= 4 {
2034                    src_ip = Ipv4Address::new(raw[0], raw[1], raw[2], raw[3]);
2035                }
2036            }
2037            if let Some(raw) = ctx.get("ip_dst") {
2038                if raw.len() >= 4 {
2039                    dst_ip = Ipv4Address::new(raw[0], raw[1], raw[2], raw[3]);
2040                }
2041            }
2042        }
2043        self.receive_packet(src_ip, dst_ip, packet)
2044    }
2045
2046    fn name(&self) -> &'static str {
2047        "TCP"
2048    }
2049
2050    fn stats(&self) -> NetworkLayerStats {
2051        self.stats.read().clone()
2052    }
2053
2054    fn as_any(&self) -> &dyn core::any::Any {
2055        self
2056    }
2057}
2058
2059impl TcpLayer {
2060    /// Receive a TCP segment with IPv4 addressing information
2061    pub fn receive_packet(
2062        &self,
2063        src_ip: Ipv4Address,
2064        _dst_ip: Ipv4Address,
2065        packet: &[u8],
2066    ) -> Result<(), SocketError> {
2067        if packet.len() < 20 {
2068            return Err(SocketError::InvalidPacket);
2069        }
2070
2071        let header = TcpHeader::from_bytes(&packet[..20]).ok_or(SocketError::InvalidPacket)?;
2072
2073        let data_offset = header.data_offset();
2074        if data_offset < 20 || data_offset > packet.len() {
2075            return Err(SocketError::InvalidPacket);
2076        }
2077
2078        let data = &packet[data_offset..];
2079
2080        self.receive_segment(src_ip, header, data);
2081
2082        Ok(())
2083    }
2084}
2085
#[cfg(test)]
mod tests {
    use super::*;

    /// A new header carries the given ports, no flags and a 20-byte offset.
    #[test_case]
    fn test_tcp_header_creation() {
        let header = TcpHeader::new(8080, 80);

        // Packed fields are copied into locals first; `assert_eq!` would
        // otherwise take a reference to a potentially unaligned field.
        let (src_port, dst_port) = (header.src_port, header.dst_port);
        assert_eq!(src_port, 8080);
        assert_eq!(dst_port, 80);
        assert_eq!(header.flags(), 0);
        assert_eq!(header.data_offset(), 20);
    }

    /// The flag constants keep their on-wire bit values.
    #[test_case]
    fn test_tcp_flags_constants() {
        let expected: [(u8, u8); 5] = [
            (tcp_flags::FIN, 0x01),
            (tcp_flags::SYN, 0x02),
            (tcp_flags::RST, 0x04),
            (tcp_flags::ACK, 0x10),
            (tcp_flags::PSH, 0x08),
        ];
        for &(flag, value) in expected.iter() {
            assert_eq!(flag, value);
        }
    }

    /// `set_state` is observable through `get_state`, starting from `Closed`.
    #[test_case]
    fn test_tcp_state_transitions() {
        let tcp_layer = TcpLayer::new();
        let socket = TcpSocket::new(Arc::downgrade(&tcp_layer));

        assert_eq!(socket.get_state(), TcpState::Closed);

        let sequence = [TcpState::Listen, TcpState::SynSent, TcpState::Established];
        for &next in sequence.iter() {
            socket.set_state(next);
            assert_eq!(socket.get_state(), next);
        }
    }

    /// The checksum of a populated header with payload is non-zero.
    #[test_case]
    fn test_tcp_checksum() {
        let local_ip = [192, 168, 1, 100];
        let dest_ip = [192, 168, 1, 1];
        let data = b"test";

        let mut header = TcpHeader::new(1234, 5678);
        header.seq_number = 1000;
        header.ack_number = 2000;

        let checksum = header.calculate_checksum(local_ip, dest_ip, data);

        assert_ne!(checksum, 0);
    }
}