1use alloc::collections::{BTreeMap, VecDeque};
7use alloc::string::String;
8use alloc::sync::{Arc, Weak};
9use alloc::vec::Vec;
10use core::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU64, AtomicUsize, Ordering};
11use spin::{Mutex, RwLock};
12
13use crate::network::ipv4::Ipv4Address;
14use crate::network::protocol_stack::get_network_manager;
15use crate::network::protocol_stack::{LayerContext, NetworkLayer, NetworkLayerStats};
16use crate::network::socket::SocketError;
17use crate::network::socket::{
18 Inet4SocketAddress, SocketAddress, SocketControl, SocketObject, SocketState,
19};
20
/// RFC 793 TCP connection states, as driven by `TcpSocket::process_segment`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpState {
    /// No connection exists.
    Closed,
    /// Passive open; waiting for an inbound SYN.
    Listen,
    /// Active open; SYN sent, awaiting SYN|ACK.
    SynSent,
    /// SYN received and SYN|ACK sent; awaiting the final handshake ACK.
    SynReceived,
    /// Three-way handshake complete; data may flow.
    Established,
    /// Our FIN sent; awaiting its acknowledgment.
    FinWait1,
    /// Our FIN acked; awaiting the peer's FIN.
    FinWait2,
    /// Peer's FIN received; waiting for the local side to close.
    CloseWait,
    /// Simultaneous close in progress.
    Closing,
    /// Waiting for the ACK of our FIN after a passive close.
    LastAck,
    /// Draining delayed duplicates before releasing the connection.
    TimeWait,
}
36
/// TCP header control-flag bits (the low 6 bits of `data_offset_flags`).
pub mod tcp_flags {
    /// Sender has no more data (connection close).
    pub const FIN: u8 = 0x01;
    /// Synchronize sequence numbers (connection open).
    pub const SYN: u8 = 0x02;
    /// Abort the connection.
    pub const RST: u8 = 0x04;
    /// Push buffered data to the application promptly.
    pub const PSH: u8 = 0x08;
    /// Acknowledgment field is valid.
    pub const ACK: u8 = 0x10;
    /// Urgent-pointer field is valid.
    pub const URG: u8 = 0x20;
}
46
/// Upper bound on bytes queued for transmission per socket.
const MAX_SEND_BUFFER_SIZE: usize = 65536;
/// Upper bound on in-order received bytes awaiting the application.
const MAX_RECV_BUFFER_SIZE: usize = 65536;
/// Upper bound on segments retained for retransmission.
const MAX_UNACKED_SEGMENTS: usize = 256;

/// Fixed 20-byte TCP header (RFC 793, no options). Multi-byte fields hold
/// host-order values; network byte order is applied only by
/// [`TcpHeader::to_bytes`] / [`TcpHeader::from_bytes`].
#[derive(Debug, Clone, Copy)]
#[repr(C, packed)]
pub struct TcpHeader {
    pub src_port: u16,
    pub dst_port: u16,
    pub seq_number: u32,
    pub ack_number: u32,
    /// Upper nibble: data offset in 32-bit words; low 6 bits: flags.
    pub data_offset_flags: u16,
    pub window_size: u16,
    pub checksum: u16,
    pub urgent_pointer: u16,
}

impl TcpHeader {
    /// Build a header with all flags clear, a 5-word (20-byte) data offset
    /// and a wide-open 65535-byte window.
    pub fn new(src_port: u16, dst_port: u16) -> Self {
        Self {
            src_port,
            dst_port,
            seq_number: 0,
            ack_number: 0,
            // 0x5 in the top nibble => 20-byte header, no flags set.
            data_offset_flags: 0x5000,
            window_size: 65535,
            checksum: 0,
            urgent_pointer: 0,
        }
    }

    /// Extract the six control flags (FIN..URG) from the packed field.
    pub fn flags(&self) -> u8 {
        let packed = self.data_offset_flags;
        (packed & 0x003F) as u8
    }

    /// Replace the flag bits while preserving the data-offset/reserved bits.
    pub fn set_flags(&mut self, flags: u8) {
        let keep = self.data_offset_flags & 0xFFC0;
        self.data_offset_flags = keep | u16::from(flags & 0x3F);
    }

    /// Header length in bytes, decoded from the data-offset nibble.
    pub fn data_offset(&self) -> usize {
        usize::from(self.data_offset_flags >> 12) * 4
    }

    /// Internet checksum over the IPv4 pseudo-header, this header (with the
    /// checksum field treated as zero) and `data`.
    pub fn calculate_checksum(&self, src_ip: [u8; 4], dst_ip: [u8; 4], data: &[u8]) -> u16 {
        let tcp_len = (self.data_offset() + data.len()) as u16;

        // Pseudo-header: src, dst, zero byte, protocol 6 (TCP), TCP length.
        let mut buf = Vec::with_capacity(12 + 20 + data.len());
        buf.extend_from_slice(&src_ip);
        buf.extend_from_slice(&dst_ip);
        buf.push(0);
        buf.push(6); // IPPROTO_TCP
        buf.extend_from_slice(&tcp_len.to_be_bytes());

        let mut zeroed = *self;
        zeroed.checksum = 0;
        buf.extend_from_slice(&zeroed.to_bytes());
        buf.extend_from_slice(data);

        // One's-complement sum of 16-bit big-endian words; a trailing odd
        // byte is padded with a zero low byte; carries folded as we go.
        let mut acc: u32 = 0;
        for pair in buf.chunks(2) {
            let word = if pair.len() == 2 {
                u16::from_be_bytes([pair[0], pair[1]]) as u32
            } else {
                (pair[0] as u32) << 8
            };
            acc = (acc & 0xFFFF) + (acc >> 16) + word;
        }
        while acc >> 16 != 0 {
            acc = (acc & 0xFFFF) + (acc >> 16);
        }
        !(acc as u16)
    }

    /// Serialize to the 20-byte network (big-endian) representation.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut wire = Vec::with_capacity(20);
        for word in [self.src_port, self.dst_port] {
            wire.extend_from_slice(&word.to_be_bytes());
        }
        for dword in [self.seq_number, self.ack_number] {
            wire.extend_from_slice(&dword.to_be_bytes());
        }
        for word in [
            self.data_offset_flags,
            self.window_size,
            self.checksum,
            self.urgent_pointer,
        ] {
            wire.extend_from_slice(&word.to_be_bytes());
        }
        wire
    }

    /// Parse the fixed header from the first 20 bytes of `bytes`;
    /// `None` if the slice is too short. Options are not parsed.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < 20 {
            return None;
        }

        let u16_at = |i: usize| u16::from_be_bytes([bytes[i], bytes[i + 1]]);
        let u32_at =
            |i: usize| u32::from_be_bytes([bytes[i], bytes[i + 1], bytes[i + 2], bytes[i + 3]]);

        Some(Self {
            src_port: u16_at(0),
            dst_port: u16_at(2),
            seq_number: u32_at(4),
            ack_number: u32_at(8),
            data_offset_flags: u16_at(12),
            window_size: u16_at(14),
            checksum: u16_at(16),
            urgent_pointer: u16_at(18),
        })
    }
}
168
/// A transmitted segment retained until the peer acknowledges it, so it
/// can be resent on RTO expiry or by fast retransmit.
#[derive(Clone)]
struct UnackedSegment {
    /// Starting sequence number of the segment.
    seq: u32,
    /// Payload bytes (empty for a bare SYN/FIN).
    data: Vec<u8>,
    /// TCP flags the segment was sent with.
    flags: u8,
    /// Number of times the segment has been transmitted.
    tx_count: u16,
    /// Tick of the most recent (re)transmission.
    last_tx_time: u64,
}
183
/// A segment received beyond RCV.NXT, buffered until the gap before it
/// is filled by in-order data.
#[derive(Clone)]
struct OutOfOrderSegment {
    /// Starting sequence number.
    seq: u32,
    /// Payload bytes.
    data: Vec<u8>,
}
192
// Out-of-order segments compare by starting sequence number only; the
// payload is deliberately ignored (two segments with the same `seq` are
// considered the same entry).
impl PartialEq for OutOfOrderSegment {
    fn eq(&self, other: &Self) -> bool {
        self.seq == other.seq
    }
}

impl Eq for OutOfOrderSegment {}
200
201impl PartialOrd for OutOfOrderSegment {
202 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
203 self.seq.partial_cmp(&other.seq)
204 }
205}
206
// Total order by starting sequence number (consistent with PartialEq).
impl Ord for OutOfOrderSegment {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.seq.cmp(&other.seq)
    }
}
212
/// Payload for a scheduled retransmission timer: a weak handle to the
/// socket (so a pending timer never keeps a dead socket alive) plus the
/// sequence number of the segment to retransmit.
struct RetransTimer {
    socket: Weak<TcpSocket>,
    seq: u32,
}

impl crate::timer::TimerHandler for RetransTimer {
    // Fires when the RTO expires. If the socket has already been dropped
    // the upgrade fails and the event is silently ignored.
    fn on_timer_expired(self: Arc<Self>, _context: usize) {
        if let Some(socket) = self.socket.upgrade() {
            socket.handle_retrans_timeout(self.seq);
        }
    }
}
226
/// A single TCP endpoint (simplified RFC 793): connection state, sequence
/// tracking, send/receive buffering, RTT/RTO estimation and retransmission
/// bookkeeping. Interior mutability (`Mutex`/atomics) throughout so one
/// `Arc<TcpSocket>` can be shared between the RX path, timers and syscalls.
pub struct TcpSocket {
    /// Current RFC 793 connection state.
    state: Mutex<TcpState>,

    /// Local IPv4 address; `None`/0.0.0.0 until filled by `ensure_local_ip`.
    local_ip: Mutex<Option<Ipv4Address>>,
    /// Local port (0 = unbound).
    pub(crate) local_port: AtomicU16,

    /// Remote peer address; `None` while unconnected.
    remote_ip: Mutex<Option<Ipv4Address>>,
    /// Remote peer port.
    remote_port: AtomicU16,

    /// Next sequence number to send (SND.NXT).
    send_seq: AtomicU32,
    /// Oldest unacknowledged sequence number (SND.UNA).
    send_unacked: AtomicU32,
    /// Next sequence number expected from the peer (RCV.NXT).
    recv_seq: AtomicU32,
    /// ACK number to advertise to the peer.
    recv_ack: AtomicU32,

    /// Peer's advertised receive window.
    send_window: AtomicU16,
    /// Our advertised receive window.
    recv_window: AtomicU16,

    /// Bytes queued for transmission (bounded by `MAX_SEND_BUFFER_SIZE`).
    send_buffer: Mutex<VecDeque<u8>>,
    /// In-order received bytes not yet read by the application.
    recv_buffer: Mutex<VecDeque<u8>>,

    /// Back-reference to the owning TCP layer (port registration/demux).
    tcp_layer: Weak<TcpLayer>,
    /// Weak self-handle so children and timers can reference this socket.
    self_weak: Weak<TcpSocket>,
    /// Handshaken child connections awaiting `accept()` (listeners only).
    pending_accept: Mutex<VecDeque<Arc<TcpSocket>>>,
    /// `listen()` backlog limit for `pending_accept`.
    max_backlog: AtomicUsize,

    /// Lifetime transmitted-bytes counter.
    bytes_sent: AtomicU64,
    /// Lifetime received-bytes counter.
    bytes_received: AtomicU64,

    /// Smoothed RTT in timer ticks, fixed-point scaled by 8 (see update_rto).
    srtt: AtomicU32,
    /// RTT variance in timer ticks, fixed-point scaled by 4.
    rttvar: AtomicU32,
    /// Retransmission timeout in timer ticks (initially 100).
    rto: AtomicU32,
    /// Consecutive retransmissions (drives exponential backoff).
    retrans_count: AtomicU16,
    /// Id of the currently scheduled retransmission timer, if any.
    retrans_timer_id: Mutex<Option<u64>>,
    /// Tick at which the currently timed segment was sent.
    last_send_time: AtomicU64,
    /// Non-zero while an RTT sample is in flight (one at a time, per Karn).
    timing_rtt: AtomicU16,
    /// Sequence number whose ACK completes the RTT sample.
    timed_seq: AtomicU32,

    /// Sent-but-unacked segments kept for retransmission.
    unacked_segments: Mutex<VecDeque<UnackedSegment>>,

    /// Segments received beyond RCV.NXT, keyed by start sequence number.
    out_of_order: Mutex<BTreeMap<u32, OutOfOrderSegment>>,

    /// Waker for tasks blocked in `accept_blocking`.
    accept_waker: Mutex<Option<Arc<crate::sync::Waker>>>,
    /// Waker for tasks blocked in `recv_blocking`.
    recv_waker: Mutex<Option<Arc<crate::sync::Waker>>>,
    /// Waker for tasks blocked in `send_blocking`.
    send_waker: Mutex<Option<Arc<crate::sync::Waker>>>,
    /// true = blocking socket (default); false = nonblocking semantics.
    blocking_mode: AtomicBool,

    /// Consecutive duplicate ACKs (third one triggers fast retransmit).
    dup_ack_count: AtomicU16,
    /// Most recent distinct ACK number (duplicate-detection baseline).
    last_ack_seq: AtomicU32,
}
307
308impl TcpSocket {
    /// Downcast a generic socket object to a `TcpSocket`, if it is one.
    pub fn from_socket_object(socket: &Arc<dyn SocketObject>) -> Option<&Self> {
        socket.as_any().downcast_ref::<TcpSocket>()
    }
316
    /// Block until a handshaken connection is available on this listening
    /// socket, then return it.
    ///
    /// Returns `NotListening` if the socket is not in `Listen` state, and
    /// `WouldBlock` immediately when the socket is non-blocking and the
    /// backlog queue is empty.
    pub fn accept_blocking(
        &self,
        task_id: usize,
        trapframe: &mut crate::arch::Trapframe,
    ) -> Result<Arc<dyn SocketObject>, SocketError> {
        if self.get_state() != TcpState::Listen {
            return Err(SocketError::NotListening);
        }

        let nonblocking = !self.blocking_mode.load(Ordering::SeqCst);

        loop {
            // Fast path: a completed child is already queued.
            {
                let mut pending = self.pending_accept.lock();
                if let Some(socket) = pending.pop_front() {
                    return Ok(socket as Arc<dyn SocketObject>);
                }
            }

            if nonblocking {
                return Err(SocketError::WouldBlock);
            }

            // Lazily create the accept waker and sleep; the RX path wakes
            // us when a child lands in `pending_accept`. The queue is
            // re-checked on every wakeup, so spurious wakeups are harmless.
            // NOTE(review): a child enqueued between the check above and
            // `wait` below relies on the waker handling that race — confirm
            // `Waker::wait` re-checks or queues the wake.
            let waker = {
                let mut waker_lock = self.accept_waker.lock();
                waker_lock
                    .get_or_insert_with(|| {
                        Arc::new(crate::sync::Waker::new_interruptible("tcp_accept"))
                    })
                    .clone()
            };
            waker.wait(task_id, trapframe);
        }
    }
352
    /// Allocate a fresh `Closed` socket wired to `tcp_layer`.
    /// `Arc::new_cyclic` lets the socket keep a weak handle to itself for
    /// port registration and timer callbacks.
    pub fn new(tcp_layer: Weak<TcpLayer>) -> Arc<Self> {
        Arc::new_cyclic(|weak| Self {
            state: Mutex::new(TcpState::Closed),
            local_ip: Mutex::new(None),
            local_port: AtomicU16::new(0),
            remote_ip: Mutex::new(None),
            remote_port: AtomicU16::new(0),
            send_seq: AtomicU32::new(0),
            send_unacked: AtomicU32::new(0),
            recv_seq: AtomicU32::new(0),
            recv_ack: AtomicU32::new(0),
            send_window: AtomicU16::new(65535),
            recv_window: AtomicU16::new(65535),
            send_buffer: Mutex::new(VecDeque::new()),
            recv_buffer: Mutex::new(VecDeque::new()),
            tcp_layer,
            self_weak: weak.clone(),
            pending_accept: Mutex::new(VecDeque::new()),
            max_backlog: AtomicUsize::new(0),
            bytes_sent: AtomicU64::new(0),
            bytes_received: AtomicU64::new(0),

            // RTT estimator starts empty; the RTO begins at 100 ticks.
            srtt: AtomicU32::new(0),
            rttvar: AtomicU32::new(0),
            rto: AtomicU32::new(100),
            retrans_count: AtomicU16::new(0),
            retrans_timer_id: Mutex::new(None),
            last_send_time: AtomicU64::new(0),
            timing_rtt: AtomicU16::new(0),
            timed_seq: AtomicU32::new(0),

            unacked_segments: Mutex::new(VecDeque::new()),

            out_of_order: Mutex::new(BTreeMap::new()),

            accept_waker: Mutex::new(None),
            recv_waker: Mutex::new(None),
            send_waker: Mutex::new(None),
            // Sockets are blocking by default.
            blocking_mode: AtomicBool::new(true),
            dup_ack_count: AtomicU16::new(0),
            last_ack_seq: AtomicU32::new(0),
        })
    }
404
405 fn matches_peer(&self, src_ip: Ipv4Address, src_port: u16) -> bool {
406 if self.get_state() == TcpState::Listen {
407 return false;
408 }
409
410 let remote_ip = self.remote_ip.lock().clone();
411 let remote_port = self.remote_port.load(Ordering::SeqCst);
412 match remote_ip {
413 Some(ip) => ip == src_ip && remote_port == src_port,
414 None => false,
415 }
416 }
417
    /// Fill in `local_ip` from the default interface's primary IPv4 address
    /// when it is unset or the 0.0.0.0 wildcard; no-op otherwise.
    /// Silently leaves the address unset if no default interface or IP
    /// layer is available.
    fn ensure_local_ip(&self) {
        let mut local_ip = self.local_ip.lock();
        let needs_update = match *local_ip {
            Some(ip) => ip.0 == [0, 0, 0, 0],
            None => true,
        };
        if !needs_update {
            return;
        }

        // Resolve via the network manager: default interface -> IP layer
        // -> primary address configured on that interface.
        let manager = get_network_manager();
        if let Some(default_iface) = manager.get_default_interface() {
            if let Some(ip_layer) = manager.get_layer("ip") {
                if let Some(ip) = ip_layer
                    .as_any()
                    .downcast_ref::<crate::network::ipv4::Ipv4Layer>()
                {
                    if let Some(addr) = ip.get_primary_ip(default_iface.name()) {
                        *local_ip = Some(addr);
                    }
                }
            }
        }
    }
442
    /// Register this socket in the TCP layer's port table so inbound
    /// segments for `port` are demultiplexed to it.
    ///
    /// Errors with `InvalidOperation` if the TCP layer has been dropped.
    fn register_local_port(&self, port: u16) -> Result<(), SocketError> {
        let tcp_layer = self
            .tcp_layer
            .upgrade()
            .ok_or(SocketError::InvalidOperation)?;

        tcp_layer.register_port(port, self.self_weak.clone());
        Ok(())
    }
452
453 fn allocate_ephemeral_port(&self) -> u16 {
454 static NEXT_EPHEMERAL_PORT: AtomicU16 = AtomicU16::new(49152);
455
456 let port = NEXT_EPHEMERAL_PORT.fetch_add(1, Ordering::SeqCst);
457 if port == u16::MAX {
458 NEXT_EPHEMERAL_PORT.store(49152, Ordering::SeqCst);
459 }
460 if port < 49152 { 49152 } else { port }
461 }
462
    /// Snapshot the current connection state.
    pub fn get_state(&self) -> TcpState {
        *self.state.lock()
    }

    /// Unconditionally replace the connection state.
    pub fn set_state(&self, new_state: TcpState) {
        *self.state.lock() = new_state;
    }
472
    /// Entry point for an inbound segment already matched to this socket.
    /// Dispatches on the current state (simplified RFC 793 event handling).
    ///
    /// NOTE(review): states past Established (FinWait*, Closing, LastAck,
    /// TimeWait) fall into the catch-all and drop segments entirely —
    /// confirm the close sequence is completed elsewhere.
    pub fn process_segment(&self, src_ip: Ipv4Address, header: TcpHeader, data: &[u8]) {
        let current_state = self.get_state();

        match current_state {
            TcpState::Listen => {
                // Passive open: each SYN spawns a child socket that runs the
                // rest of the handshake; the listener itself stays in Listen.
                if header.flags() & tcp_flags::SYN != 0 {
                    let tcp_layer = match self.tcp_layer.upgrade() {
                        Some(layer) => layer,
                        None => return,
                    };

                    let child = TcpSocket::new(Arc::downgrade(&tcp_layer));
                    let local_port = self.local_port.load(Ordering::SeqCst);
                    if local_port == 0 {
                        // Listener was never bound; nothing sensible to do.
                        return;
                    }

                    // Child inherits the listener's local address.
                    if let Some(local_ip) = self.local_ip.lock().clone() {
                        *child.local_ip.lock() = Some(local_ip);
                    } else {
                        child.ensure_local_ip();
                    }

                    child.local_port.store(local_port, Ordering::SeqCst);
                    // NOTE(review): the child is registered under the same
                    // port as the listener — confirm the layer demuxes by
                    // full 4-tuple, not port alone.
                    tcp_layer.register_port(local_port, child.self_weak.clone());
                    child.handle_syn_received(src_ip, header);

                    let max_backlog = self.max_backlog.load(Ordering::SeqCst);
                    {
                        let mut pending = self.pending_accept.lock();
                        if pending.len() < max_backlog {
                            pending.push_back(child);
                        } else {
                            // Backlog full: refuse the connection with RST
                            // (the child already sent its SYN|ACK above).
                            drop(pending);
                            let rst_port = self.local_port.load(Ordering::SeqCst);
                            let rst_seq = self.send_seq.load(Ordering::SeqCst);
                            let mut rst_header = TcpHeader::new(rst_port, header.src_port);
                            rst_header.seq_number = rst_seq;
                            rst_header.ack_number = child.recv_ack.load(Ordering::SeqCst);
                            rst_header.set_flags(tcp_flags::RST);
                            child.send_segment(src_ip, rst_header, &[], false, false);
                            return;
                        }
                    }

                    // Wake one task blocked in accept().
                    if let Some(waker) = self.accept_waker.lock().as_ref() {
                        waker.wake_one();
                    }
                }
            }
            TcpState::SynSent => {
                // Active open: expect SYN|ACK, or RST if the peer refuses.
                if header.flags() & (tcp_flags::SYN | tcp_flags::ACK)
                    == (tcp_flags::SYN | tcp_flags::ACK)
                {
                    self.handle_syn_ack_received(src_ip, header);
                } else if header.flags() & tcp_flags::RST != 0 {
                    self.handle_rst();
                }
            }
            TcpState::SynReceived => {
                if header.flags() & tcp_flags::RST != 0 {
                    self.handle_rst();
                    return;
                }

                if header.flags() & tcp_flags::ACK != 0 {
                    // NOTE(review): exact equality (rather than a wrap-aware
                    // "ack covers SND.NXT" test) silently drops otherwise
                    // valid ACKs — the handshake then stalls until resend.
                    let expected_ack = self.send_seq.load(Ordering::SeqCst);
                    if header.ack_number == expected_ack {
                        self.update_send_window(header.ack_number);
                        self.set_state(TcpState::Established);

                        // The handshake-completing ACK may carry data or FIN.
                        if !data.is_empty() {
                            self.handle_data_segment(src_ip, header, data);
                        } else if header.flags() & tcp_flags::FIN != 0 {
                            self.handle_fin(src_ip, header);
                        }
                    }
                }
            }
            TcpState::Established => {
                // Payload-less segments are pure control (ACK/FIN/RST).
                if data.is_empty() {
                    self.handle_control_segment(src_ip, header);
                } else {
                    self.handle_data_segment(src_ip, header, data);
                }
            }
            _ => {}
        }
    }
569
    /// Listener-child path: record the peer, answer its SYN with our
    /// SYN|ACK and enter SynReceived (second step of the passive handshake).
    fn handle_syn_received(&self, src_ip: Ipv4Address, header: TcpHeader) {
        *self.remote_ip.lock() = Some(src_ip);
        self.remote_port.store(header.src_port, Ordering::SeqCst);

        // SECURITY NOTE(review): a fixed ISN (1000) is trivially
        // predictable; switch to a randomized ISN once an entropy source
        // is available.
        let initial_seq = 1000;
        // The peer's SYN consumes one sequence number.
        let next_recv = header.seq_number.wrapping_add(1);
        self.send_seq.store(initial_seq, Ordering::SeqCst);
        self.recv_seq.store(next_recv, Ordering::SeqCst);
        self.recv_ack.store(next_recv, Ordering::SeqCst);

        let local_port = self.local_port.load(Ordering::SeqCst);
        let mut syn_ack = TcpHeader::new(local_port, header.src_port);
        syn_ack.seq_number = initial_seq;
        syn_ack.ack_number = next_recv;
        syn_ack.set_flags(tcp_flags::SYN | tcp_flags::ACK);
        self.send_segment(src_ip, syn_ack, &[], false, false);
        // Our SYN consumes one sequence number too.
        self.send_seq.fetch_add(1, Ordering::SeqCst);
        self.set_state(TcpState::SynReceived);
    }
593
    /// Active-open completion: the peer's SYN|ACK arrived while SynSent.
    /// Record the peer, synchronize sequence state, ACK it and go
    /// Established.
    fn handle_syn_ack_received(&self, src_ip: Ipv4Address, header: TcpHeader) {
        *self.remote_ip.lock() = Some(src_ip);
        self.remote_port.store(header.src_port, Ordering::SeqCst);

        // The peer's SYN consumes one sequence number.
        let next_recv = header.seq_number.wrapping_add(1);
        self.recv_seq.store(next_recv, Ordering::SeqCst);
        self.recv_ack.store(next_recv, Ordering::SeqCst);

        // Adopt the peer's ACK as our new SND.NXT / SND.UNA.
        // NOTE(review): the ACK number is not validated against our ISN.
        let acked = header.ack_number;
        self.send_seq.store(acked, Ordering::SeqCst);
        self.send_unacked.store(acked, Ordering::SeqCst);

        self.send_ack(src_ip, header.src_port, next_recv);

        self.set_state(TcpState::Established);
    }
612
    /// Tear the connection down in response to an RST: cancel timers, drop
    /// all buffered and tracked data, and reset sequence/address state so
    /// the socket is back to a pristine Closed state.
    fn handle_rst(&self) {
        self.cancel_retrans_timer();

        self.send_buffer.lock().clear();

        self.recv_buffer.lock().clear();

        self.unacked_segments.lock().clear();

        self.out_of_order.lock().clear();

        self.send_seq.store(0, Ordering::SeqCst);
        self.send_unacked.store(0, Ordering::SeqCst);
        self.recv_seq.store(0, Ordering::SeqCst);
        self.recv_ack.store(0, Ordering::SeqCst);

        self.send_window.store(65535, Ordering::SeqCst);
        self.recv_window.store(65535, Ordering::SeqCst);

        // Reset the RTT estimator to its initial RTO (100 ticks).
        self.srtt.store(0, Ordering::SeqCst);
        self.rttvar.store(0, Ordering::SeqCst);
        self.rto.store(100, Ordering::SeqCst);
        self.retrans_count.store(0, Ordering::SeqCst);
        self.timing_rtt.store(0, Ordering::SeqCst);

        *self.local_ip.lock() = None;
        *self.remote_ip.lock() = None;
        self.local_port.store(0, Ordering::SeqCst);
        self.remote_port.store(0, Ordering::SeqCst);

        // NOTE(review): tasks blocked in send/recv are NOT woken here, so
        // they may sleep until some other event fires — worth confirming.
        self.set_state(TcpState::Closed);
    }
656
657 fn handle_control_segment(&self, src_ip: Ipv4Address, header: TcpHeader) {
659 if header.flags() & tcp_flags::RST != 0 {
660 self.handle_rst();
661 return;
662 }
663
664 if header.flags() & tcp_flags::FIN != 0 {
665 self.handle_fin(src_ip, header);
666 }
667
668 if header.flags() & tcp_flags::ACK != 0 {
669 self.update_send_window(header.ack_number);
670 self.stop_rtt_measurement(header.ack_number);
671 self.remove_acked_segments(header.ack_number);
672 }
673 }
674
    /// Process a payload-carrying segment while Established: deliver
    /// in-order data, buffer out-of-order data, re-ACK duplicates, and
    /// consume any piggybacked ACK.
    ///
    /// NOTE(review): the `<=` / `>` tests below are plain unsigned
    /// comparisons, not wrap-aware sequence arithmetic, so behavior near a
    /// 2^32 sequence wrap is incorrect.
    fn handle_data_segment(&self, src_ip: Ipv4Address, header: TcpHeader, data: &[u8]) {
        if header.flags() & tcp_flags::RST != 0 {
            self.handle_rst();
            return;
        }

        let expected_seq = self.recv_seq.load(Ordering::SeqCst);
        let segment_seq = header.seq_number;
        let segment_end = segment_seq.wrapping_add(data.len() as u32);

        // Entirely old/duplicate data: re-ACK so the peer stops resending.
        if segment_end <= expected_seq {
            self.send_ack(src_ip, header.src_port, expected_seq);
            return;
        }

        // Gap before this segment: stash it (capped at 128 entries) and
        // emit a duplicate ACK for the sequence number we still expect.
        if segment_seq > expected_seq {
            if !data.is_empty() {
                let mut out_of_order = self.out_of_order.lock();
                if !out_of_order.contains_key(&segment_seq) {
                    if out_of_order.len() < 128 {
                        let ooo_seg = OutOfOrderSegment {
                            seq: segment_seq,
                            data: data.to_vec(),
                        };
                        out_of_order.insert(segment_seq, ooo_seg);
                    }
                }
                drop(out_of_order);

                self.send_ack(src_ip, header.src_port, expected_seq);
            }

            // Still honor any piggybacked acknowledgment information.
            if header.flags() & tcp_flags::ACK != 0 {
                self.update_send_window(header.ack_number);
                self.stop_rtt_measurement(header.ack_number);
                self.remove_acked_segments(header.ack_number);
            }
            return;
        }

        // In-order data (possibly overlapping RCV.NXT at the front).
        if !data.is_empty() {
            let mut recv_buf = self.recv_buffer.lock();

            // Receive buffer full: drop the payload and re-ACK; the peer
            // will retransmit once our window re-opens.
            if recv_buf.len() + data.len() > MAX_RECV_BUFFER_SIZE {
                drop(recv_buf);
                self.send_ack(src_ip, header.src_port, expected_seq);
                return;
            }

            // NOTE(review): when segment_seq < expected_seq the leading
            // overlap is NOT trimmed before appending — confirm duplicate
            // bytes cannot reach the application this way.
            recv_buf.extend(data);
            let mut next_seq = expected_seq.wrapping_add(data.len() as u32);

            // Splice in any buffered segments that are now contiguous.
            let mut out_of_order = self.out_of_order.lock();
            loop {
                if let Some((_seq, ooo_seg)) = out_of_order.first_key_value() {
                    if *_seq == next_seq {
                        let seq = *_seq;
                        let ooo_seg = out_of_order.remove(&seq).unwrap();

                        if recv_buf.len() + ooo_seg.data.len() <= MAX_RECV_BUFFER_SIZE {
                            recv_buf.extend(&ooo_seg.data);
                            next_seq = next_seq.wrapping_add(ooo_seg.data.len() as u32);
                        } else {
                            // NOTE(review): the removed segment is dropped
                            // here; the peer must retransmit it.
                            break;
                        }
                    } else {
                        break;
                    }
                } else {
                    break;
                }
            }
            drop(out_of_order);

            // Shrink the advertised window to the remaining buffer space.
            let available = MAX_RECV_BUFFER_SIZE.saturating_sub(recv_buf.len());
            self.recv_window
                .store(available.min(65535) as u16, Ordering::SeqCst);

            self.recv_seq.store(next_seq, Ordering::SeqCst);
            self.recv_ack.store(next_seq, Ordering::SeqCst);
            drop(recv_buf);

            // Wake one blocked reader, then ACK everything now in order.
            if let Some(waker) = self.recv_waker.lock().as_ref() {
                waker.wake_one();
            }

            self.send_ack(src_ip, header.src_port, next_seq);

            self.bytes_received
                .fetch_add(data.len() as u64, Ordering::SeqCst);
        }

        if header.flags() & tcp_flags::ACK != 0 {
            self.update_send_window(header.ack_number);
            self.stop_rtt_measurement(header.ack_number);
            self.remove_acked_segments(header.ack_number);
        }
    }
795
    /// Process an inbound FIN.
    ///
    /// NOTE(review): RCV.NXT/recv_ack are advanced before the state match,
    /// so states that ignore the FIN still consume its sequence number; and
    /// in FinWait1 a bare FIN (without ACK) is dropped, so a simultaneous
    /// close never reaches the Closing state.
    fn handle_fin(&self, src_ip: Ipv4Address, header: TcpHeader) {
        let current_state = self.get_state();
        // The FIN occupies one sequence number.
        let ack_seq = header.seq_number.wrapping_add(1);
        self.recv_seq.store(ack_seq, Ordering::SeqCst);
        self.recv_ack.store(ack_seq, Ordering::SeqCst);
        match current_state {
            TcpState::Established => {
                // Passive close: ACK the FIN and wait for a local close().
                self.send_ack(src_ip, header.src_port, ack_seq);
                self.set_state(TcpState::CloseWait);
                // Unblock all readers so they can observe EOF.
                if let Some(waker) = self.recv_waker.lock().as_ref() {
                    waker.wake_all();
                }
            }
            TcpState::FinWait1 => {
                // Peer acknowledged our FIN and sent its own: ACK and skip
                // straight to TimeWait (FinWait2 is bypassed).
                if header.flags() & (tcp_flags::FIN | tcp_flags::ACK)
                    == (tcp_flags::FIN | tcp_flags::ACK)
                {
                    self.send_ack(src_ip, header.src_port, ack_seq);
                    self.set_state(TcpState::TimeWait);
                }
            }
            _ => {}
        }
    }
821
    /// Process an incoming ACK number: advance SND.UNA, run duplicate-ACK
    /// detection (third duplicate triggers fast retransmit) and wake one
    /// blocked sender.
    ///
    /// NOTE(review): despite the name, the peer's advertised window is not
    /// read here — only the ACK number is consumed.
    fn update_send_window(&self, ack_number: u32) {
        self.send_unacked.store(ack_number, Ordering::SeqCst);

        let last_ack = self.last_ack_seq.load(Ordering::SeqCst);
        if ack_number == last_ack {
            // Same ACK again. fetch_add returns the previous value, so
            // `count >= 2` fires on the third duplicate.
            let count = self.dup_ack_count.fetch_add(1, Ordering::SeqCst);

            if count >= 2 {
                self.fast_retransmit();
                self.dup_ack_count.store(0, Ordering::SeqCst);
            }
        } else {
            // New data acknowledged: reset duplicate tracking.
            self.last_ack_seq.store(ack_number, Ordering::SeqCst);
            self.dup_ack_count.store(0, Ordering::SeqCst);
        }

        if let Some(waker) = self.send_waker.lock().as_ref() {
            waker.wake_one();
        }
    }
849
    /// Immediately retransmit the oldest unacked segment (triggered by
    /// three duplicate ACKs) without waiting for the RTO timer.
    ///
    /// NOTE(review): `unacked_segments` stays locked across the whole
    /// `send_segment` call — confirm no lock-ordering conflict with the
    /// IP-layer send path.
    fn fast_retransmit(&self) {
        let mut unacked = self.unacked_segments.lock();
        if let Some(first_seg) = unacked.front() {
            if let Some(dest_ip) = self.remote_ip.lock().clone() {
                let dest_port = self.remote_port.load(Ordering::SeqCst);
                let local_port = self.local_port.load(Ordering::SeqCst);

                // Rebuild the original segment with a fresh ACK number.
                let mut header = TcpHeader::new(local_port, dest_port);
                header.seq_number = first_seg.seq;
                header.ack_number = self.recv_ack.load(Ordering::SeqCst);
                header.set_flags(first_seg.flags);

                // is_retransmit = true: do not re-track the segment.
                self.send_segment(dest_ip, header, &first_seg.data, false, true);

                if let Some(seg) = unacked.front_mut() {
                    seg.tx_count += 1;
                    seg.last_tx_time = crate::timer::get_tick();
                }

                // Record that one retransmission cycle is in progress.
                self.retrans_count.store(1, Ordering::SeqCst);
            }
        }
    }
878
    /// Active open: emit the initial SYN and enter SynSent.
    fn send_syn(&self, dest_ip: Ipv4Address, dest_port: u16) {
        let local_port = self.local_port.load(Ordering::SeqCst);

        // SECURITY NOTE(review): fixed ISN (1000) is predictable; use a
        // randomized ISN once an entropy source is available.
        let initial_seq = 1000;
        self.send_seq.store(initial_seq, Ordering::SeqCst);

        let mut header = TcpHeader::new(local_port, dest_port);
        header.seq_number = initial_seq;
        header.set_flags(tcp_flags::SYN);

        self.send_segment(dest_ip, header, &[], false, false);
        // The SYN consumes one sequence number.
        self.send_seq.fetch_add(1, Ordering::SeqCst);
        self.set_state(TcpState::SynSent);
    }
894
    /// Emit a SYN|ACK and move to SynReceived.
    ///
    /// NOTE(review): `their_seq` is placed in OUR sequence-number field, so
    /// despite its name the caller must pass the LOCAL initial sequence
    /// number here; `ack_seq` acknowledges the peer's SYN.
    fn send_syn_ack(&self, dest_ip: Ipv4Address, dest_port: u16, their_seq: u32, ack_seq: u32) {
        let local_port = self.local_port.load(Ordering::SeqCst);

        let mut header = TcpHeader::new(local_port, dest_port);
        header.seq_number = their_seq;
        header.ack_number = ack_seq;
        header.set_flags(tcp_flags::SYN | tcp_flags::ACK);

        self.send_segment(dest_ip, header, &[], false, false);
        self.set_state(TcpState::SynReceived);
    }
907
908 fn send_ack(&self, dest_ip: Ipv4Address, dest_port: u16, ack_seq: u32) {
910 let local_port = self.local_port.load(Ordering::SeqCst);
911 let send_seq = self.send_seq.load(Ordering::SeqCst);
912
913 let mut header = TcpHeader::new(local_port, dest_port);
914 header.seq_number = send_seq;
915 header.ack_number = ack_seq;
916 header.set_flags(tcp_flags::ACK);
917
918 self.send_segment(dest_ip, header, &[], false, false);
919 }
920
921 fn send_fin(&self) {
923 let dest_ip = self.remote_ip.lock().clone().unwrap();
924 let dest_port = self.remote_port.load(Ordering::SeqCst);
925 let local_port = self.local_port.load(Ordering::SeqCst);
926
927 let send_seq = self.send_seq.load(Ordering::SeqCst);
928
929 let mut header = TcpHeader::new(local_port, dest_port);
930 header.seq_number = send_seq;
931 header.set_flags(tcp_flags::FIN);
932
933 self.send_segment(dest_ip, header, &[], true, false);
934 self.set_state(TcpState::FinWait1);
935 }
936
937 fn send_fin_ack(&self) {
939 let dest_ip = self.remote_ip.lock().clone().unwrap();
940 let dest_port = self.remote_port.load(Ordering::SeqCst);
941 let local_port = self.local_port.load(Ordering::SeqCst);
942 let send_seq = self.send_seq.load(Ordering::SeqCst);
943 let recv_seq = self.recv_seq.load(Ordering::SeqCst);
944
945 let mut header = TcpHeader::new(local_port, dest_port);
946 header.seq_number = send_seq;
947 header.ack_number = recv_seq;
948 header.set_flags(tcp_flags::FIN | tcp_flags::ACK);
949
950 self.send_segment(dest_ip, header, &[], true, false);
951 }
952
    /// Serialize and transmit one segment via the IP layer, then perform
    /// send-side bookkeeping on success.
    ///
    /// * `update_seq` — advance SND.NXT by the payload length, +1 for each
    ///   of SYN and FIN.
    /// * `is_retransmit` — skip tracking the segment as a new unacked entry.
    ///
    /// NOTE(review): the checksum/pseudo-header length uses `data_offset()`
    /// while `to_bytes()` always emits 20 bytes — consistent only while no
    /// TCP options are ever set.
    fn send_segment(
        &self,
        dest_ip: Ipv4Address,
        mut header: TcpHeader,
        data: &[u8],
        update_seq: bool,
        is_retransmit: bool,
    ) {
        self.ensure_local_ip();
        let local_ip = self
            .local_ip
            .lock()
            .clone()
            .unwrap_or(Ipv4Address::new(0, 0, 0, 0));

        let total_len = header.data_offset() + data.len();
        // Advertise our current receive window on every outgoing segment.
        header.window_size = self.recv_window.load(Ordering::SeqCst);

        header.checksum = header.calculate_checksum(local_ip.0, dest_ip.0, data);

        let header_bytes = header.to_bytes();

        let mut segment = Vec::with_capacity(total_len);
        segment.extend_from_slice(&header_bytes);
        segment.extend_from_slice(data);

        let mut ip_context = LayerContext::new();
        ip_context.set("ip_dst", &dest_ip.0);
        ip_context.set("ip_protocol", &[6]); // IPPROTO_TCP
        if let Some(ip_layer) = get_network_manager().get_layer("ip") {
            if let Ok(()) = ip_layer.send(&segment, &ip_context, &[]) {
                self.bytes_sent
                    .fetch_add(segment.len() as u64, Ordering::SeqCst);

                if update_seq {
                    // SYN and FIN each occupy one sequence number in
                    // addition to the payload bytes.
                    let mut advance = data.len() as u32;
                    let flags = header.flags();
                    if (flags & tcp_flags::SYN) != 0 {
                        advance = advance.wrapping_add(1);
                    }
                    if (flags & tcp_flags::FIN) != 0 {
                        advance = advance.wrapping_add(1);
                    }
                    if advance != 0 {
                        self.send_seq.fetch_add(advance, Ordering::SeqCst);
                    }
                }

                // Track segments that occupy sequence space so they can be
                // retransmitted; pure ACKs are not tracked.
                if !is_retransmit {
                    let flags = header.flags();
                    let has_data = !data.is_empty();
                    let is_syn = (flags & tcp_flags::SYN) != 0;
                    let is_fin = (flags & tcp_flags::FIN) != 0;

                    if has_data || is_syn || is_fin {
                        let seq = header.seq_number;
                        self.add_unacked_segment(seq, data.to_vec(), flags);
                    }
                }
            }
        }
    }
1023
1024 pub fn send_data(&self, data: &[u8]) -> Result<usize, SocketError> {
1026 if self.get_state() != TcpState::Established {
1027 return Err(SocketError::NotConnected);
1028 }
1029
1030 let dest_ip = self
1031 .remote_ip
1032 .lock()
1033 .clone()
1034 .ok_or(SocketError::NotConnected)?;
1035 let dest_port = self.remote_port.load(Ordering::SeqCst);
1036
1037 let mut send_buf = self.send_buffer.lock();
1039 if send_buf.len() + data.len() > MAX_SEND_BUFFER_SIZE {
1040 return Err(SocketError::WouldBlock);
1041 }
1042 send_buf.extend(data);
1043 drop(send_buf);
1044
1045 let local_port = self.local_port.load(Ordering::SeqCst);
1047 let local_ip = self
1048 .local_ip
1049 .lock()
1050 .clone()
1051 .unwrap_or(Ipv4Address::new(0, 0, 0, 0));
1052
1053 let send_seq = self.send_seq.load(Ordering::SeqCst);
1054
1055 let mut header = TcpHeader::new(local_port, dest_port);
1056 header.seq_number = send_seq;
1057 header.ack_number = self.recv_ack.load(Ordering::SeqCst);
1058 header.set_flags(tcp_flags::ACK | tcp_flags::PSH);
1059
1060 self.send_segment(dest_ip, header, data, true, false);
1061
1062 Ok(data.len())
1063 }
1064
1065 pub fn send_blocking(
1067 &self,
1068 data: &[u8],
1069 task_id: usize,
1070 trapframe: &mut crate::arch::Trapframe,
1071 ) -> Result<usize, SocketError> {
1072 if self.get_state() != TcpState::Established {
1073 return Err(SocketError::NotConnected);
1074 }
1075
1076 let dest_ip = self
1077 .remote_ip
1078 .lock()
1079 .clone()
1080 .ok_or(SocketError::NotConnected)?;
1081 let dest_port = self.remote_port.load(Ordering::SeqCst);
1082 let local_port = self.local_port.load(Ordering::SeqCst);
1083 let local_ip = self
1084 .local_ip
1085 .lock()
1086 .clone()
1087 .unwrap_or(Ipv4Address::new(0, 0, 0, 0));
1088
1089 let nonblocking = !self.blocking_mode.load(Ordering::SeqCst);
1090
1091 loop {
1092 {
1093 let mut send_buf = self.send_buffer.lock();
1094 if send_buf.len() + data.len() <= MAX_SEND_BUFFER_SIZE {
1095 send_buf.extend(data);
1096 drop(send_buf);
1097
1098 let send_seq = self.send_seq.load(Ordering::SeqCst);
1099
1100 let mut header = TcpHeader::new(local_port, dest_port);
1101 header.seq_number = send_seq;
1102 header.ack_number = self.recv_ack.load(Ordering::SeqCst);
1103 header.set_flags(tcp_flags::ACK | tcp_flags::PSH);
1104
1105 self.send_segment(dest_ip, header, data, true, false);
1106
1107 return Ok(data.len());
1108 }
1109 }
1110
1111 if nonblocking {
1112 return Err(SocketError::WouldBlock);
1113 }
1114
1115 {
1116 let waker = {
1117 let mut waker_lock = self.send_waker.lock();
1118 waker_lock
1119 .get_or_insert_with(|| {
1120 Arc::new(crate::sync::Waker::new_interruptible("tcp_send"))
1121 })
1122 .clone()
1123 };
1124 waker.wait(task_id, trapframe);
1125 }
1126 }
1127 }
1128
1129 pub fn recv_data(&self, buffer: &mut [u8]) -> Result<usize, SocketError> {
1131 match self.get_state() {
1132 TcpState::Established | TcpState::CloseWait => {}
1133 _ => return Err(SocketError::NotConnected),
1134 }
1135
1136 let mut recv_buf = self.recv_buffer.lock();
1137 let len = buffer.len().min(recv_buf.len());
1138
1139 for i in 0..len {
1140 buffer[i] = recv_buf.pop_front().unwrap();
1141 }
1142
1143 let available = MAX_RECV_BUFFER_SIZE.saturating_sub(recv_buf.len());
1145 self.recv_window
1146 .store(available.min(65535) as u16, Ordering::SeqCst);
1147
1148 Ok(len)
1149 }
1150
    /// Read into `buffer`, blocking until data arrives or EOF.
    ///
    /// Returns `Ok(0)` as EOF once the receive buffer is empty and the
    /// connection is Closed/TimeWait/CloseWait; `WouldBlock` when the
    /// socket is non-blocking and no data is ready; `NotConnected` if the
    /// initial state permits no reads.
    pub fn recv_blocking(
        &self,
        buffer: &mut [u8],
        task_id: usize,
        trapframe: &mut crate::arch::Trapframe,
    ) -> Result<usize, SocketError> {
        match self.get_state() {
            TcpState::Established | TcpState::CloseWait => {}
            _ => return Err(SocketError::NotConnected),
        }

        let nonblocking = !self.blocking_mode.load(Ordering::SeqCst);

        loop {
            {
                let mut recv_buf = self.recv_buffer.lock();
                let len = buffer.len().min(recv_buf.len());

                if len > 0 {
                    // Drain from the front of the in-order byte queue.
                    for i in 0..len {
                        buffer[i] = recv_buf.pop_front().unwrap();
                    }

                    // Re-open the advertised window by the space freed.
                    let available = MAX_RECV_BUFFER_SIZE.saturating_sub(recv_buf.len());
                    self.recv_window
                        .store(available.min(65535) as u16, Ordering::SeqCst);

                    return Ok(len);
                }

                // Empty buffer in a closing/closed state means EOF.
                let state = self.get_state();
                if state == TcpState::Closed
                    || state == TcpState::TimeWait
                    || state == TcpState::CloseWait
                {
                    return Ok(0);
                }
            }

            if nonblocking {
                return Err(SocketError::WouldBlock);
            }

            {
                let waker = {
                    let mut waker_lock = self.recv_waker.lock();
                    waker_lock
                        .get_or_insert_with(|| {
                            Arc::new(crate::sync::Waker::new_interruptible("tcp_recv"))
                        })
                        .clone()
                };
                // Re-check after creating the waker to narrow the window in
                // which a FIN could arrive right before we sleep.
                let state = self.get_state();
                if state == TcpState::Closed
                    || state == TcpState::TimeWait
                    || state == TcpState::CloseWait
                {
                    return Ok(0);
                }
                waker.wait(task_id, trapframe);
            }
        }
    }
1217
    /// Feed one RTT sample (in timer ticks) into the RFC 6298 estimator,
    /// implemented in integer fixed point: `srtt` is stored scaled by 8 and
    /// `rttvar` scaled by 4, so the 1/8 and 1/4 gains reduce to shifts.
    fn update_rto(&self, rtt_ticks: u32) {
        // Gains: alpha = 1/8, beta = 1/4, K = 4 (RFC 6298 §2).
        const ALPHA_SHIFT: u32 = 3;
        const BETA_SHIFT: u32 = 2;
        const K: u32 = 4;
        let srtt = self.srtt.load(Ordering::SeqCst);
        let rttvar = self.rttvar.load(Ordering::SeqCst);

        if srtt == 0 {
            // First sample: SRTT = R, RTTVAR = R/2 (stored pre-scaled).
            self.srtt.store(rtt_ticks << 3, Ordering::SeqCst);
            self.rttvar.store((rtt_ticks << 2) >> 1, Ordering::SeqCst);
        } else {
            let srtt_val = srtt >> 3; // descale for the |SRTT - R| term
            let diff = if srtt_val > rtt_ticks {
                srtt_val - rtt_ticks
            } else {
                rtt_ticks - srtt_val
            };

            // RTTVAR = 3/4 * RTTVAR + 1/4 * |SRTT - R|   (scaled by 4)
            let new_rttvar = ((rttvar * 3) >> BETA_SHIFT) + ((diff << 2) >> BETA_SHIFT);
            self.rttvar.store(new_rttvar, Ordering::SeqCst);

            // SRTT = 7/8 * SRTT + 1/8 * R                (scaled by 8)
            let new_srtt = ((srtt * 7) >> ALPHA_SHIFT) + (rtt_ticks << (3 - ALPHA_SHIFT));
            self.srtt.store(new_srtt, Ordering::SeqCst);
        }

        // RTO = SRTT + max(1, K * RTTVAR), clamped to [1, 12000] ticks.
        let srtt_val = self.srtt.load(Ordering::SeqCst) >> 3;
        let rttvar_val = self.rttvar.load(Ordering::SeqCst) >> 2;
        let mut rto = srtt_val + (K * rttvar_val).max(1);

        rto = rto.max(1).min(12000);

        self.rto.store(rto, Ordering::SeqCst);
    }
1276
    /// Current RTO in milliseconds (assumes a 10 ms timer tick — TODO
    /// confirm against the timer subsystem).
    fn get_rto_ms(&self) -> u32 {
        self.rto.load(Ordering::SeqCst) * 10
    }

    /// Current RTO in raw timer ticks.
    fn get_rto_ticks(&self) -> u32 {
        self.rto.load(Ordering::SeqCst)
    }
1287
    /// Begin timing one in-flight segment; at most one RTT sample is taken
    /// at a time (Karn's algorithm — retransmissions are never timed).
    ///
    /// NOTE(review): the load-then-store on `timing_rtt` is not atomic as a
    /// unit, so two concurrent senders could both start a sample. Harmless
    /// for the estimator, but a `compare_exchange` would close the race.
    fn start_rtt_measurement(&self, seq: u32) {
        if self.timing_rtt.load(Ordering::SeqCst) == 0 {
            self.timed_seq.store(seq, Ordering::SeqCst);
            self.last_send_time
                .store(crate::timer::get_tick(), Ordering::SeqCst);
            self.timing_rtt.store(1, Ordering::SeqCst);
        }
    }
1298
    /// Complete the pending RTT sample if `ack_seq` covers the timed
    /// sequence number, feeding the measured tick delta into the RTO
    /// estimator and resetting the backoff counter.
    fn stop_rtt_measurement(&self, ack_seq: u32) {
        if self.timing_rtt.load(Ordering::SeqCst) != 0 {
            let timed_seq = self.timed_seq.load(Ordering::SeqCst);
            // `is_seq_acknowledged` is defined elsewhere in this module;
            // presumably wrap-aware — confirm.
            if is_seq_acknowledged(timed_seq, ack_seq) {
                let send_time = self.last_send_time.load(Ordering::SeqCst);
                let now = crate::timer::get_tick();
                if now > send_time {
                    let rtt = (now - send_time) as u32;
                    self.update_rto(rtt);
                }
                self.timing_rtt.store(0, Ordering::SeqCst);
                self.retrans_count.store(0, Ordering::SeqCst);
            }
        }
    }
1319
1320 fn backoff_rto(&self) {
1322 let count = self.retrans_count.load(Ordering::SeqCst);
1323 if count < 6 {
1324 let backoff = 1u32 << count.min(6);
1326 let base_rto = self.rto.load(Ordering::SeqCst);
1327 let new_rto = (base_rto * backoff).min(12000); self.rto.store(new_rto, Ordering::SeqCst);
1329 self.retrans_count.store(count + 1, Ordering::SeqCst);
1330 }
1331 }
1332
    /// True once 12 or more consecutive backoff events have been recorded.
    ///
    /// NOTE(review): `retrans_count` is advanced by `backoff_rto` and reset by
    /// `stop_rtt_measurement`; confirm the counter can actually reach 12 with
    /// the current backoff capping logic.
    fn max_retransmissions_exceeded(&self) -> bool {
        self.retrans_count.load(Ordering::SeqCst) >= 12
    }
1337
1338 fn handle_retrans_timeout(&self, seq: u32) {
1340 let state = self.get_state();
1342 match state {
1343 TcpState::Closed | TcpState::Listen | TcpState::TimeWait => return,
1344 _ => {}
1345 }
1346
1347 let mut unacked = self.unacked_segments.lock();
1349 if let Some(pos) = unacked.iter().position(|seg| seg.seq == seq) {
1350 if let Some(mut seg) = unacked.get(pos).cloned() {
1351 if seg.tx_count >= 12 {
1353 self.set_state(TcpState::Closed);
1355 return;
1356 }
1357
1358 self.backoff_rto();
1360
1361 if let Some(dest_ip) = self.remote_ip.lock().clone() {
1363 let dest_port = self.remote_port.load(Ordering::SeqCst);
1364 let local_port = self.local_port.load(Ordering::SeqCst);
1365
1366 let mut header = TcpHeader::new(local_port, dest_port);
1367 header.seq_number = seg.seq;
1368 header.ack_number = self.recv_ack.load(Ordering::SeqCst);
1369 header.set_flags(seg.flags);
1370
1371 self.send_segment(dest_ip, header, &seg.data, false, true);
1373
1374 seg.tx_count += 1;
1376 seg.last_tx_time = crate::timer::get_tick();
1377
1378 if let Some(existing) = unacked.get_mut(pos) {
1380 *existing = seg;
1381 }
1382
1383 self.schedule_retrans_timer(seq);
1385 }
1386 }
1387 }
1388 }
1389
    /// Arm a one-shot retransmission timer for the segment starting at `seq`.
    fn schedule_retrans_timer(&self, seq: u32) {
        // Expire one RTO from now (RTO is kept in timer ticks).
        let rto_ticks = self.get_rto_ticks();
        let expires = crate::timer::get_tick() + rto_ticks as u64;

        // The handler holds only a weak socket reference so a pending timer
        // cannot keep a dropped socket alive.
        let timer: Arc<dyn crate::timer::TimerHandler> = Arc::new(RetransTimer {
            socket: self.self_weak.clone(),
            seq,
        });

        let timer_id = crate::timer::add_timer(expires, &timer, 0);

        // NOTE(review): only the most recently scheduled timer id is kept, so
        // cancel_retrans_timer() can cancel just the latest timer; earlier
        // per-segment timers keep running until they fire. Confirm intended.
        *self.retrans_timer_id.lock() = Some(timer_id);
    }
1405
1406 fn cancel_retrans_timer(&self) {
1408 let timer_id = {
1409 let mut timer_lock = self.retrans_timer_id.lock();
1410 timer_lock.take()
1411 };
1412
1413 if let Some(timer_id) = timer_id {
1414 crate::timer::cancel_timer(timer_id);
1415 }
1416 }
1417
1418 fn add_unacked_segment(&self, seq: u32, data: Vec<u8>, flags: u8) {
1420 let mut unacked = self.unacked_segments.lock();
1422 if unacked.len() >= MAX_UNACKED_SEGMENTS {
1423 if let Some(old) = unacked.pop_front() {
1425 }
1427 }
1428
1429 let segment = UnackedSegment {
1430 seq,
1431 data,
1432 flags,
1433 tx_count: 1,
1434 last_tx_time: crate::timer::get_tick(),
1435 };
1436
1437 unacked.push_back(segment);
1438 drop(unacked);
1439
1440 self.schedule_retrans_timer(seq);
1442
1443 self.start_rtt_measurement(seq);
1445 }
1446
1447 fn remove_acked_segments(&self, ack_seq: u32) {
1449 let mut unacked = self.unacked_segments.lock();
1450 unacked.retain(|seg| {
1452 let seg_end = seg.seq.wrapping_add(seg.data.len() as u32);
1453 !is_seq_acknowledged(seg_end, ack_seq)
1455 });
1456
1457 if unacked.is_empty() {
1459 drop(unacked);
1460 self.cancel_retrans_timer();
1461 }
1462 }
1463}
1464
/// Serial-number comparison over the 32-bit TCP sequence space: true when
/// `seq` lies strictly before `ack` modulo 2^32, i.e. the byte at `seq` is
/// covered by cumulative acknowledgement `ack`.
fn is_seq_acknowledged(seq: u32, ack: u32) -> bool {
    // `seq < ack` (mod 2^32) exactly when the wrapped distance from `ack`
    // back to `seq` exceeds half the sequence space.
    const HALF_SPACE: u32 = 1u32 << 31;
    seq.wrapping_sub(ack) > HALF_SPACE
}
1472
1473impl SocketObject for TcpSocket {
1474 fn socket_type(&self) -> crate::network::socket::SocketType {
1475 crate::network::socket::SocketType::Stream
1476 }
1477
1478 fn socket_domain(&self) -> crate::network::socket::SocketDomain {
1479 crate::network::socket::SocketDomain::Inet4
1480 }
1481
1482 fn socket_protocol(&self) -> crate::network::socket::SocketProtocol {
1483 crate::network::socket::SocketProtocol::Tcp
1484 }
1485
1486 fn as_any(&self) -> &dyn core::any::Any {
1487 self
1488 }
1489
1490 fn as_selectable(&self) -> Option<&dyn crate::object::capability::Selectable> {
1491 Some(self)
1492 }
1493
1494 fn as_control_ops(&self) -> Option<&dyn crate::object::capability::ControlOps> {
1495 Some(self)
1496 }
1497
1498 fn sendto(
1499 &self,
1500 data: &[u8],
1501 address: &SocketAddress,
1502 flags: u32,
1503 ) -> Result<usize, SocketError> {
1504 let _ = flags;
1505
1506 match address {
1507 SocketAddress::Inet(inet) => {
1508 let addr = Ipv4Address::from_bytes(inet.addr);
1509 let port = inet.port;
1510 *self.remote_ip.lock() = Some(addr);
1512 self.remote_port.store(port, Ordering::SeqCst);
1513 self.send_data(data)
1514 }
1515 _ => Err(SocketError::InvalidAddress),
1516 }
1517 }
1518
1519 fn recvfrom(
1520 &self,
1521 buffer: &mut [u8],
1522 flags: u32,
1523 ) -> Result<(usize, SocketAddress), SocketError> {
1524 let _ = flags;
1525
1526 let len = self.recv_data(buffer)?;
1527 let remote_ip = self
1528 .remote_ip
1529 .lock()
1530 .clone()
1531 .unwrap_or(Ipv4Address::new(0, 0, 0, 0));
1532 let addr = SocketAddress::Inet(Inet4SocketAddress::new(
1533 remote_ip.0,
1534 self.remote_port.load(Ordering::SeqCst),
1535 ));
1536
1537 Ok((len, addr))
1538 }
1539}
1540
1541impl crate::object::capability::ControlOps for TcpSocket {
1542 fn control(&self, command: u32, arg: usize) -> Result<i32, &'static str> {
1543 match command {
1544 crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK => {
1545 crate::object::capability::selectable::Selectable::set_nonblocking(self, arg != 0);
1546 Ok(0)
1547 }
1548 crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK => Ok(
1549 if crate::object::capability::selectable::Selectable::is_nonblocking(self) {
1550 1
1551 } else {
1552 0
1553 },
1554 ),
1555 _ => Err("Unsupported socket control command"),
1556 }
1557 }
1558
1559 fn supported_control_commands(&self) -> alloc::vec::Vec<(u32, &'static str)> {
1560 alloc::vec![
1561 (
1562 crate::network::socket::socket_ctl::SCTL_SOCKET_SET_NONBLOCK,
1563 "Set non-blocking mode",
1564 ),
1565 (
1566 crate::network::socket::socket_ctl::SCTL_SOCKET_GET_NONBLOCK,
1567 "Get non-blocking mode",
1568 ),
1569 ]
1570 }
1571}
1572
1573impl SocketControl for TcpSocket {
1574 fn bind(&self, address: &SocketAddress) -> Result<(), SocketError> {
1575 match address {
1576 SocketAddress::Inet(inet) => {
1577 if inet.port == 0 {
1578 return Err(SocketError::InvalidAddress);
1579 }
1580
1581 self.register_local_port(inet.port)?;
1582 *self.local_ip.lock() = Some(Ipv4Address::from_bytes(inet.addr));
1583 self.local_port.store(inet.port, Ordering::SeqCst);
1584 Ok(())
1585 }
1586 _ => Err(SocketError::InvalidAddress),
1587 }
1588 }
1589
1590 fn listen(&self, backlog: usize) -> Result<(), SocketError> {
1591 if self.local_port.load(Ordering::SeqCst) == 0 {
1592 return Err(SocketError::InvalidOperation);
1593 }
1594
1595 let max_backlog = backlog.max(1).min(128);
1597 self.max_backlog.store(max_backlog, Ordering::SeqCst);
1598
1599 self.set_state(TcpState::Listen);
1600 Ok(())
1601 }
1602
1603 fn connect(&self, address: &SocketAddress) -> Result<(), SocketError> {
1604 match address {
1605 SocketAddress::Inet(inet) => {
1606 let addr = Ipv4Address::from_bytes(inet.addr);
1607 let port = inet.port;
1608 *self.remote_ip.lock() = Some(addr);
1609 self.remote_port.store(port, Ordering::SeqCst);
1610
1611 let local_port = self.local_port.load(Ordering::SeqCst);
1612 if local_port == 0 {
1613 let port = self.allocate_ephemeral_port();
1614 self.register_local_port(port)?;
1615 self.local_port.store(port, Ordering::SeqCst);
1616 }
1617
1618 self.ensure_local_ip();
1619
1620 self.send_syn(addr, port);
1622 Ok(())
1623 }
1624 _ => Err(SocketError::InvalidAddress),
1625 }
1626 }
1627
1628 fn accept(&self) -> Result<Arc<dyn SocketObject>, SocketError> {
1629 if self.get_state() != TcpState::Listen {
1630 return Err(SocketError::NotListening);
1631 }
1632
1633 let mut pending = self.pending_accept.lock();
1634 pending
1635 .pop_front()
1636 .map(|socket| socket as Arc<dyn SocketObject>)
1637 .ok_or(SocketError::WouldBlock)
1638 }
1639
1640 fn getpeername(&self) -> Result<SocketAddress, SocketError> {
1641 let ip = self
1642 .remote_ip
1643 .lock()
1644 .clone()
1645 .ok_or(SocketError::NotConnected)?;
1646 let port = self.remote_port.load(Ordering::SeqCst);
1647 Ok(SocketAddress::Inet(Inet4SocketAddress::new(ip.0, port)))
1648 }
1649
1650 fn getsockname(&self) -> Result<SocketAddress, SocketError> {
1651 let ip = self
1652 .local_ip
1653 .lock()
1654 .clone()
1655 .ok_or(SocketError::InvalidAddress)?;
1656 let port = self.local_port.load(Ordering::SeqCst);
1657 Ok(SocketAddress::Inet(Inet4SocketAddress::new(ip.0, port)))
1658 }
1659
1660 fn shutdown(&self, how: crate::network::socket::ShutdownHow) -> Result<(), SocketError> {
1661 match how {
1662 crate::network::socket::ShutdownHow::Write
1663 | crate::network::socket::ShutdownHow::Both => {
1664 self.send_fin();
1665 }
1666 _ => {}
1667 }
1668 Ok(())
1669 }
1670
1671 fn is_connected(&self) -> bool {
1672 self.get_state() == TcpState::Established
1673 }
1674
1675 fn state(&self) -> SocketState {
1676 match self.get_state() {
1677 TcpState::Closed => SocketState::Unconnected,
1678 TcpState::Listen => SocketState::Listening,
1679 TcpState::Established => SocketState::Connected,
1680 _ => SocketState::Unconnected,
1681 }
1682 }
1683}
1684
1685impl crate::ipc::StreamIpcOps for TcpSocket {
1686 fn is_connected(&self) -> bool {
1687 SocketControl::is_connected(self)
1688 }
1689
1690 fn peer_count(&self) -> usize {
1691 if SocketControl::is_connected(self) {
1692 1
1693 } else {
1694 0
1695 }
1696 }
1697
1698 fn description(&self) -> String {
1699 alloc::format!("TCP socket")
1700 }
1701}
1702
1703impl crate::object::capability::StreamOps for TcpSocket {
1704 fn read(&self, buffer: &mut [u8]) -> Result<usize, crate::object::capability::StreamError> {
1705 use crate::object::capability::selectable::Selectable;
1706
1707 if Selectable::is_nonblocking(self) {
1708 return self.recv_data(buffer).map_err(|err| match err {
1709 SocketError::WouldBlock => crate::object::capability::StreamError::WouldBlock,
1710 SocketError::NotConnected => crate::object::capability::StreamError::BrokenPipe,
1711 _ => crate::object::capability::StreamError::Other("tcp recv error".into()),
1712 });
1713 }
1714
1715 let task = match crate::task::mytask() {
1716 Some(task) => task,
1717 None => {
1718 return Err(crate::object::capability::StreamError::Other(
1719 "tcp recv: no task".into(),
1720 ));
1721 }
1722 };
1723
1724 self.recv_blocking(buffer, task.get_id(), task.get_trapframe())
1725 .map_err(|err| match err {
1726 SocketError::WouldBlock => crate::object::capability::StreamError::WouldBlock,
1727 SocketError::NotConnected => crate::object::capability::StreamError::BrokenPipe,
1728 _ => crate::object::capability::StreamError::Other("tcp recv error".into()),
1729 })
1730 }
1731
1732 fn write(&self, data: &[u8]) -> Result<usize, crate::object::capability::StreamError> {
1733 use crate::object::capability::selectable::Selectable;
1734
1735 if Selectable::is_nonblocking(self) {
1736 return self.send_data(data).map_err(|_| {
1737 crate::object::capability::StreamError::Other("tcp send error".into())
1738 });
1739 }
1740
1741 let task = match crate::task::mytask() {
1742 Some(task) => task,
1743 None => {
1744 return Err(crate::object::capability::StreamError::Other(
1745 "tcp send: no task".into(),
1746 ));
1747 }
1748 };
1749
1750 self.send_blocking(data, task.get_id(), task.get_trapframe())
1751 .map_err(|_| crate::object::capability::StreamError::Other("tcp send error".into()))
1752 }
1753}
1754
1755impl crate::object::capability::Selectable for TcpSocket {
1756 fn current_ready(
1757 &self,
1758 interest: crate::object::capability::selectable::ReadyInterest,
1759 ) -> crate::object::capability::selectable::ReadySet {
1760 let mut ready = crate::object::capability::selectable::ReadySet::none();
1761
1762 if interest.read {
1763 let recv_buf = self.recv_buffer.lock();
1764 let has_data = !recv_buf.is_empty();
1765 drop(recv_buf);
1766 let state = self.get_state();
1767 ready.read = has_data
1768 || state == TcpState::Closed
1769 || state == TcpState::TimeWait
1770 || state == TcpState::CloseWait;
1771 }
1772
1773 if interest.write {
1774 let send_buf = self.send_buffer.lock();
1775 ready.write = send_buf.len() < MAX_SEND_BUFFER_SIZE;
1776 }
1777
1778 ready
1779 }
1780
1781 fn wait_until_ready(
1782 &self,
1783 interest: crate::object::capability::selectable::ReadyInterest,
1784 trapframe: &mut crate::arch::Trapframe,
1785 timeout_ticks: Option<u64>,
1786 ) -> crate::object::capability::selectable::SelectWaitOutcome {
1787 let current = self.current_ready(interest);
1788 if (interest.read && current.read) || (interest.write && current.write) {
1789 return crate::object::capability::selectable::SelectWaitOutcome::Ready;
1790 }
1791
1792 let task_id = {
1793 use crate::arch::get_cpu;
1794 use crate::sched::scheduler::get_scheduler;
1795 let cpu_id = get_cpu().get_cpuid();
1796 get_scheduler().get_current_task_id(cpu_id).unwrap_or(0)
1797 };
1798
1799 let woke = if interest.read {
1800 let waker = {
1801 let mut waker_lock = self.recv_waker.lock();
1802 waker_lock
1803 .get_or_insert_with(|| {
1804 Arc::new(crate::sync::Waker::new_interruptible("tcp_recv"))
1805 })
1806 .clone()
1807 };
1808 waker.wait_with_timeout(task_id, trapframe, timeout_ticks)
1809 } else if interest.write {
1810 let waker = {
1811 let mut waker_lock = self.send_waker.lock();
1812 waker_lock
1813 .get_or_insert_with(|| {
1814 Arc::new(crate::sync::Waker::new_interruptible("tcp_send"))
1815 })
1816 .clone()
1817 };
1818 waker.wait_with_timeout(task_id, trapframe, timeout_ticks)
1819 } else {
1820 true
1821 };
1822
1823 if timeout_ticks.is_some() && !woke {
1824 let after = self.current_ready(interest);
1825 if (interest.read && !after.read) && (interest.write && !after.write) {
1826 return crate::object::capability::selectable::SelectWaitOutcome::TimedOut;
1827 }
1828 }
1829
1830 crate::object::capability::selectable::SelectWaitOutcome::Ready
1831 }
1832
1833 fn set_nonblocking(&self, enabled: bool) {
1834 self.blocking_mode.store(!enabled, Ordering::SeqCst);
1835 }
1836
1837 fn is_nonblocking(&self) -> bool {
1838 !self.blocking_mode.load(Ordering::SeqCst)
1839 }
1840}
1841
impl crate::object::capability::CloneOps for TcpSocket {
    /// Cloning yields a brand-new, unconnected socket attached to the same
    /// TCP layer — connection state, buffers, and the bound port are NOT
    /// shared or copied.
    ///
    /// NOTE(review): confirm callers expect fresh-socket semantics here
    /// rather than a second handle to the same connection.
    fn custom_clone(&self) -> crate::object::KernelObject {
        crate::object::KernelObject::Socket(TcpSocket::new(self.tcp_layer.clone()))
    }
}
1847
1848impl Drop for TcpSocket {
1849 fn drop(&mut self) {
1850 self.cancel_retrans_timer();
1852
1853 let state = self.get_state();
1855 match state {
1856 TcpState::Established | TcpState::SynReceived | TcpState::SynSent => {
1857 let _ = self.remote_ip.lock().clone().map(|dest_ip| {
1859 let dest_port = self.remote_port.load(Ordering::SeqCst);
1860 let local_port = self.local_port.load(Ordering::SeqCst);
1861 let send_seq = self.send_seq.load(Ordering::SeqCst);
1862
1863 let mut header = TcpHeader::new(local_port, dest_port);
1864 header.seq_number = send_seq;
1865 header.set_flags(tcp_flags::FIN);
1866 self.send_segment(dest_ip, header, &[], true, false);
1867 });
1868 }
1869 _ => {}
1870 }
1871
1872 if let Some(layer) = self.tcp_layer.upgrade() {
1874 let port = self.local_port.load(Ordering::SeqCst);
1875 if port != 0 {
1876 layer.unregister_socket(port, &self.self_weak);
1877 }
1878 }
1879 }
1880}
1881
/// The TCP protocol layer: demultiplexes incoming segments to sockets by
/// local port and tracks per-layer statistics.
pub struct TcpLayer {
    /// Local port -> sockets registered on that port (listeners included).
    port_map: RwLock<BTreeMap<u16, Vec<Weak<TcpSocket>>>>,
    /// Packet/byte counters for this layer.
    stats: RwLock<NetworkLayerStats>,
    /// Back-reference handed to sockets created through this layer.
    self_weak: Weak<TcpLayer>,
}
1892
1893impl TcpLayer {
1894 pub fn new() -> Arc<Self> {
1896 Arc::new_cyclic(|weak| Self {
1897 port_map: RwLock::new(BTreeMap::new()),
1898 stats: RwLock::new(NetworkLayerStats::default()),
1899 self_weak: weak.clone(),
1900 })
1901 }
1902
1903 pub fn init(network_manager: &crate::network::NetworkManager) {
1912 let layer = Self::new();
1913 network_manager.register_layer("tcp", layer.clone());
1914
1915 let ipv4 = network_manager
1917 .get_layer("ip")
1918 .expect("Ipv4Layer must be initialized before TcpLayer");
1919 ipv4.register_protocol(crate::network::ipv4::protocol::TCP as u16, layer);
1920 }
1921
1922 pub fn create_socket(&self) -> Arc<TcpSocket> {
1923 TcpSocket::new(self.self_weak.clone())
1924 }
1925
1926 pub fn register_port(&self, port: u16, socket: Weak<TcpSocket>) {
1928 let mut map = self.port_map.write();
1929 let entry = map.entry(port).or_default();
1930 if entry.iter().any(|existing| existing.ptr_eq(&socket)) {
1931 return;
1932 }
1933 if entry.iter().any(|existing| {
1934 existing
1935 .upgrade()
1936 .map(|sock| sock.get_state() == TcpState::Listen)
1937 .unwrap_or(false)
1938 && socket
1939 .upgrade()
1940 .map(|sock| sock.get_state() == TcpState::Listen)
1941 .unwrap_or(false)
1942 }) {
1943 return;
1944 }
1945 entry.push(socket);
1946 }
1947
1948 pub fn unregister_socket(&self, port: u16, socket: &Weak<TcpSocket>) {
1953 let mut map = self.port_map.write();
1954 if let Some(sockets) = map.get_mut(&port) {
1955 sockets.retain(|existing| !existing.ptr_eq(socket));
1956 if sockets.is_empty() {
1957 map.remove(&port);
1958 }
1959 }
1960 }
1961
1962 pub fn find_socket(
1964 &self,
1965 port: u16,
1966 src_ip: Ipv4Address,
1967 src_port: u16,
1968 ) -> Option<Arc<TcpSocket>> {
1969 let map = self.port_map.read();
1970 let sockets = map.get(&port)?;
1971 let mut listening = None;
1972 for weak in sockets {
1973 if let Some(socket) = weak.upgrade() {
1974 if socket.matches_peer(src_ip, src_port) {
1975 return Some(socket);
1976 }
1977 if socket.get_state() == TcpState::Listen {
1978 listening = Some(socket);
1979 }
1980 }
1981 }
1982 listening
1983 }
1984
1985 pub fn find_listening_socket(&self, port: u16) -> Option<Arc<TcpSocket>> {
1986 let map = self.port_map.read();
1987 let sockets = map.get(&port)?;
1988 for weak in sockets {
1989 if let Some(socket) = weak.upgrade() {
1990 if socket.get_state() == TcpState::Listen {
1991 return Some(socket);
1992 }
1993 }
1994 }
1995 None
1996 }
1997
1998 pub fn receive_segment(&self, src_ip: Ipv4Address, header: TcpHeader, data: &[u8]) {
2000 let mut stats = self.stats.write();
2001 stats.packets_received += 1;
2002 stats.bytes_received += (header.data_offset() + data.len()) as u64;
2003
2004 let src_port = unsafe { core::ptr::addr_of!(header.src_port).read_unaligned() };
2005 let dst_port = unsafe { core::ptr::addr_of!(header.dst_port).read_unaligned() };
2006
2007 if let Some(socket) = self.find_socket(dst_port, src_ip, src_port) {
2008 socket.process_segment(src_ip, header, data);
2009 }
2010 }
2011}
2012
2013impl NetworkLayer for TcpLayer {
2014 fn register_protocol(&self, _proto_num: u16, _handler: Arc<dyn NetworkLayer>) {
2015 }
2017
2018 fn send(
2019 &self,
2020 _packet: &[u8],
2021 _context: &LayerContext,
2022 _next_layers: &[Arc<dyn NetworkLayer>],
2023 ) -> Result<(), SocketError> {
2024 Ok(())
2026 }
2027
2028 fn receive(&self, packet: &[u8], context: Option<&LayerContext>) -> Result<(), SocketError> {
2029 let mut src_ip = Ipv4Address::new(0, 0, 0, 0);
2030 let mut dst_ip = Ipv4Address::new(0, 0, 0, 0);
2031 if let Some(ctx) = context {
2032 if let Some(raw) = ctx.get("ip_src") {
2033 if raw.len() >= 4 {
2034 src_ip = Ipv4Address::new(raw[0], raw[1], raw[2], raw[3]);
2035 }
2036 }
2037 if let Some(raw) = ctx.get("ip_dst") {
2038 if raw.len() >= 4 {
2039 dst_ip = Ipv4Address::new(raw[0], raw[1], raw[2], raw[3]);
2040 }
2041 }
2042 }
2043 self.receive_packet(src_ip, dst_ip, packet)
2044 }
2045
2046 fn name(&self) -> &'static str {
2047 "TCP"
2048 }
2049
2050 fn stats(&self) -> NetworkLayerStats {
2051 self.stats.read().clone()
2052 }
2053
2054 fn as_any(&self) -> &dyn core::any::Any {
2055 self
2056 }
2057}
2058
2059impl TcpLayer {
2060 pub fn receive_packet(
2062 &self,
2063 src_ip: Ipv4Address,
2064 _dst_ip: Ipv4Address,
2065 packet: &[u8],
2066 ) -> Result<(), SocketError> {
2067 if packet.len() < 20 {
2068 return Err(SocketError::InvalidPacket);
2069 }
2070
2071 let header = TcpHeader::from_bytes(&packet[..20]).ok_or(SocketError::InvalidPacket)?;
2072
2073 let data_offset = header.data_offset();
2074 if data_offset < 20 || data_offset > packet.len() {
2075 return Err(SocketError::InvalidPacket);
2076 }
2077
2078 let data = &packet[data_offset..];
2079
2080 self.receive_segment(src_ip, header, data);
2081
2082 Ok(())
2083 }
2084}
2085
2086#[cfg(test)]
2087mod tests {
2088 use super::*;
2089
2090 #[test_case]
2091 fn test_tcp_header_creation() {
2092 let header = TcpHeader::new(8080, 80);
2093
2094 let src_port = unsafe { core::ptr::addr_of!(header.src_port).read_unaligned() };
2095 let dst_port = unsafe { core::ptr::addr_of!(header.dst_port).read_unaligned() };
2096 assert_eq!(src_port, 8080);
2097 assert_eq!(dst_port, 80);
2098 assert_eq!(header.flags(), 0);
2099 assert_eq!(header.data_offset(), 20);
2100 }
2101
2102 #[test_case]
2103 fn test_tcp_flags_constants() {
2104 assert_eq!(tcp_flags::FIN, 0x01);
2105 assert_eq!(tcp_flags::SYN, 0x02);
2106 assert_eq!(tcp_flags::RST, 0x04);
2107 assert_eq!(tcp_flags::ACK, 0x10);
2108 assert_eq!(tcp_flags::PSH, 0x08);
2109 }
2110
2111 #[test_case]
2112 fn test_tcp_state_transitions() {
2113 let tcp_layer = TcpLayer::new();
2114 let socket = TcpSocket::new(Arc::downgrade(&tcp_layer));
2115
2116 assert_eq!(socket.get_state(), TcpState::Closed);
2117
2118 socket.set_state(TcpState::Listen);
2119 assert_eq!(socket.get_state(), TcpState::Listen);
2120
2121 socket.set_state(TcpState::SynSent);
2122 assert_eq!(socket.get_state(), TcpState::SynSent);
2123
2124 socket.set_state(TcpState::Established);
2125 assert_eq!(socket.get_state(), TcpState::Established);
2126 }
2127
2128 #[test_case]
2129 fn test_tcp_checksum() {
2130 let local_ip = [192, 168, 1, 100];
2131 let dest_ip = [192, 168, 1, 1];
2132 let data = b"test";
2133
2134 let mut header = TcpHeader::new(1234, 5678);
2135 header.seq_number = 1000;
2136 header.ack_number = 2000;
2137
2138 let checksum = header.calculate_checksum(local_ip, dest_ip, data);
2139
2140 assert_ne!(checksum, 0);
2141 }
2142}