diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index 19543e4a..3654f780 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -387,8 +387,6 @@ mod tests { #[tokio::test] async fn test_send_delegates_to_session_handle() { - use crate::session::error::SendOutcome; - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let config = create_test_config("127.0.0.1", port); @@ -402,9 +400,13 @@ mod tests { .await .expect("initiator should connect"); - // Message should be received by session and persisted (seq 2 after Logon) + // Session is in AwaitingLogon (no logon response from counterparty), + // so send should be rejected — only Active sessions accept app messages let result = initiator.send(DummyMessage).await; - assert!(matches!(result, Ok(SendOutcome::Sent { .. }))); + assert!( + matches!(result, Err(crate::session::error::SendError::Disconnected)), + "expected Disconnected error, got: {result:?}" + ); } #[tokio::test] diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 3926bc74..3d1f5c57 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -1,7 +1,9 @@ pub(crate) mod admin_request; +mod ctx; pub mod error; pub(crate) mod event; mod info; +mod message_handling; mod session_handle; pub mod session_ref; mod state; @@ -14,28 +16,17 @@ use std::pin::Pin; use tokio::select; use tokio::sync::mpsc; use tokio::time::{Duration, Instant, Sleep, sleep, sleep_until}; -use tracing::{debug, enabled, error, info, warn}; +use tracing::{debug, error, info, warn}; use crate::Application; -use crate::application::{InboundDecision, OutboundDecision}; use crate::config::SessionConfig; -use crate::message::OutboundMessage; -use crate::message::business_reject::BusinessReject; -use crate::message::generate_message; -use crate::message::heartbeat::Heartbeat; use crate::message::logon::{Logon, ResetSeqNumConfig}; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; -use crate::message::reject::Reject; use crate::message::resend_request::ResendRequest; -use crate::message::sequence_reset::SequenceReset; -use crate::message::test_request::TestRequest; -use crate::message::verification::verify_message; -use crate::message::verification_error::{CompIdType, MessageVerificationError}; -use crate::message::{is_admin, prepare_message_for_resend}; use crate::session::admin_request::AdminRequest; use crate::session::error::SessionCreationError; -use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; +use crate::session::error::{InternalSendResultExt, SessionOperationError}; pub use crate::session::error::{SendError, SendOutcome}; pub use crate::session::info::{SessionInfo, Status}; pub use crate::session::session_handle::SessionHandle; @@ -45,16 +36,13 @@ pub(crate) use crate::session::session_ref::InternalSessionRef; pub use crate::session::session_ref::InternalSessionRef; use crate::session::session_ref::OutboundRequest; use crate::session::state::SessionState; -use crate::session::state::{AwaitingResendTransitionOutcome, TestRequestId}; +use crate::session::state::{SessionCtx, TransitionResult}; use crate::session_schedule::{SessionPeriodComparison, SessionSchedule}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; use event::SessionEvent; use hotfix_message::parsed_message::{InvalidReason, ParsedMessage}; -use hotfix_message::session_fields::{ - BEGIN_SEQ_NO, END_SEQ_NO, GAP_FILL_FLAG, MSG_SEQ_NUM, MSG_TYPE, NEW_SEQ_NO, - SessionRejectReason, TEST_REQ_ID, -}; +use hotfix_message::session_fields::MSG_SEQ_NUM; const SCHEDULE_CHECK_INTERVAL: u64 = 1; @@ -120,64 +108,26 @@ where raw_message: RawFixMessage, ) -> Result<(), SessionOperationError> { debug!("received message: {}", raw_message); - if !self.state.is_expecting_test_response() { - // if we are not awaiting a specific test response, any message can reset the timer - // otherwise only a heartbeat with the corresponding TestReqID can - self.reset_peer_timer(None); + + // Reset peer timer before dispatching (if not expecting test response) + if let SessionState::Active(active) = &mut self.state + && active.expected_test_response_id().is_none() + { + active.reset_peer_timer(self.config.heartbeat_interval, None); } match self.message_builder.build(raw_message.as_bytes()) { ParsedMessage::Valid(message) => { - self.process_message(message).await?; - self.check_end_of_resend().await?; + self.dispatch_valid_message(message).await?; } ParsedMessage::Garbled(r) => { - // garbled messages should be skipped and we should assume it was a transmission error let message = raw_message.to_string(); let reason = format!("{r:?}"); error!(message, reason, "received garbled message"); } - ParsedMessage::Invalid { message, reason } => match reason { - InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidTagNumber) - .text(&format!("invalid field {tag}")); - self.send_message(reject) - .await - .with_send_context("reject for invalid field")?; - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } - } - } - InvalidReason::InvalidComponent(_component_name) => { - // TODO: what's the correct way to handle this? - warn!("received invalid component"); - } - InvalidReason::InvalidMsgType(msg_type) => { - self.handle_invalid_msg_type(message, &msg_type).await; - } - InvalidReason::InvalidOrderInGroup { tag, .. } => { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason( - SessionRejectReason::RepeatingGroupFieldsOutOfOrder, - ) - .text(&format!("field appears in incorrect order:{tag}")); - self.send_message(reject) - .await - .with_send_context("reject for invalid group order")?; - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } - } - } - }, + ParsedMessage::Invalid { message, reason } => { + self.handle_invalid_parsed_message(message, reason).await?; + } ParsedMessage::UnexpectedError(err) => { error!("unexpected error: {:?}", err); } @@ -186,775 +136,141 @@ where Ok(()) } - async fn process_message(&mut self, message: Message) -> Result<(), SessionOperationError> { - let message_type: &str = message - .header() - .get(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; - - if let SessionState::AwaitingResend(state) = &mut self.state { - let seq_number = get_msg_seq_num(&message); - if seq_number > state.end_seq_number && message_type != ResendRequest::MSG_TYPE { - state.inbound_queue.push_back(message); - return Ok(()); - } - } - - if let SessionState::AwaitingLogon { .. } = &mut self.state { - // TODO: should this (and all inbound message processing) logic be pushed into the state? - if message_type != Logon::MSG_TYPE { - self.state.disconnect_writer().await; - return Ok(()); - } - } - - match message_type { - Heartbeat::MSG_TYPE => { - self.on_heartbeat(&message).await?; - } - TestRequest::MSG_TYPE => { - self.on_test_request(&message).await?; - } - ResendRequest::MSG_TYPE => { - self.on_resend_request(&message).await?; - } - Reject::MSG_TYPE => { - self.on_reject(&message).await?; - } - SequenceReset::MSG_TYPE => { - self.on_sequence_reset(&message).await?; - } - Logout::MSG_TYPE => { - self.on_logout(&message).await?; - } - Logon::MSG_TYPE => { - self.on_logon(&message).await?; - } - _ => self.process_app_message(&message).await?, - } - - Ok(()) - } - - async fn process_app_message( + async fn handle_invalid_parsed_message( &mut self, - message: &Message, + message: Message, + reason: InvalidReason, ) -> Result<(), SessionOperationError> { - match self.verify_message(message, true, true) { - Ok(_) => { - match self.application.on_inbound_message(message).await { - InboundDecision::Accept => {} - InboundDecision::Reject { reason, text } => { - let msg_type: &str = message - .header() - .get(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; - let mut reject = BusinessReject::new(msg_type, reason) - .ref_seq_num(get_msg_seq_num(message)); - if let Some(text) = text { - reject = reject.text(&text); - } - self.send_message(reject) - .await - .with_send_context("business message reject")?; - } - InboundDecision::TerminateSession => { - error!("failed to send inbound message to application"); - self.state.disconnect_writer().await; - } - } - self.store.increment_target_seq_number().await?; - } - Err(err) => self.handle_verification_error(err).await?, - } - - Ok(()) - } - - async fn check_end_of_resend(&mut self) -> Result<(), SessionOperationError> { - let ended_state = if let SessionState::AwaitingResend(state) = &mut self.state { - if self.store.next_target_seq_number() > state.end_seq_number { - let new_state = - SessionState::new_active(state.writer.clone(), self.config.heartbeat_interval); - Some(std::mem::replace(&mut self.state, new_state)) - } else { - None - } - } else { - None - }; - - if let Some(SessionState::AwaitingResend(mut state)) = ended_state { - // we have reached the end of the resend, - // process queued messages and resume normal operation - debug!("resend is done, processing backlog"); - while let Some(msg) = state.inbound_queue.pop_front() { - let seq_number: u64 = msg.get(MSG_SEQ_NUM).unwrap_or_else(|e| { - error!("failed to get seq number: {:?}", e); - 0 - }); - let msg_type: &str = msg.header().get(MSG_TYPE).unwrap_or(""); - debug!(seq_number, msg_type, "processing queued message"); - - if msg_type == ResendRequest::MSG_TYPE { - // ResendRequest was already processed when it arrived (it bypasses - // the queue in process_message). Just increment the target seq number - // for sequence accounting purposes. - self.store.increment_target_seq_number().await?; - } else { - self.process_message(msg).await?; - } - } - debug!("resend backlog is cleared, resuming normal operation"); - } - - Ok(()) - } - - fn verify_message( - &self, - message: &Message, - check_too_high: bool, - check_too_low: bool, - ) -> Result<(), MessageVerificationError> { - let expected_seq_number = if check_too_high || check_too_low { - Some(self.store.next_target_seq_number()) - } else { - None - }; - verify_message( - message, - &self.config, - expected_seq_number, - check_too_high, - check_too_low, - ) - } - - async fn on_connect(&mut self, writer: WriterRef) -> Result<(), SessionOperationError> { - self.state = SessionState::AwaitingLogon { - writer, - logon_sent: false, - logon_timeout: Instant::now() + Duration::from_secs(self.config.logon_timeout), - }; - self.reset_peer_timer(None); - self.send_logon().await?; - - Ok(()) - } - - async fn on_disconnect(&mut self, reason: String) { - match self.state { - SessionState::Active { .. } - | SessionState::AwaitingLogon { .. } - | SessionState::AwaitingResend(_) => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected(true, &reason); - } - SessionState::Disconnected { .. } => { - warn!("disconnect message was received, but the session is already disconnected") - } - SessionState::AwaitingLogout { reconnect, .. } => { - self.state = SessionState::new_disconnected(reconnect, &reason); - } - } - } - - async fn on_logon(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let SessionState::AwaitingLogon { writer, .. } = &self.state { - match self.verify_message(message, true, true) { - Ok(_) => { - // happy logon flow, the session is now active - self.state = - SessionState::new_active(writer.clone(), self.config.heartbeat_interval); - self.application.on_logon().await; - self.store.increment_target_seq_number().await?; - } - Err(err) => self.handle_verification_error(err).await?, - } - } else { - error!("received unexpected logon message"); - } - - Ok(()) - } - - async fn on_logout(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, false, false) { - self.handle_verification_error(err).await?; + let writer = self.state.get_writer(); + let Some(writer) = writer else { return Ok(()); - } - - if self.state.is_logged_on() { - self.send_logout("Logout acknowledged").await?; - } - - self.application.on_logout("peer has logged us out").await; - - match self.state { - // if the session is already disconnected, we have nothing else to do - SessionState::Disconnected(..) => {} - // if we initiated the logout, preserve the reconnect flag - SessionState::AwaitingLogout { reconnect, .. } => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected(reconnect, "logout completed"); - } - // otherwise assume it makes sense to try to reconnect - _ => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected(true, "peer has logged us out") - } - } - - self.store.increment_target_seq_number().await?; - Ok(()) - } - - async fn on_heartbeat(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, true, true) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - if let (Some(expected_req_id), Ok(message_req_id)) = ( - &self.state.expected_test_response_id(), - message.get::<&str>(TEST_REQ_ID), - ) && expected_req_id.as_str() == message_req_id - { - debug!("received response for TestRequest, resetting timer"); - self.reset_peer_timer(None); - } - - self.store.increment_target_seq_number().await?; - Ok(()) - } - - async fn on_test_request(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, true, true) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| { - // TODO: send reject? - todo!() - }); - - self.store.increment_target_seq_number().await?; - - self.send_message(Heartbeat::for_request(req_id.to_string())) - .await - .with_send_context("heartbeat response")?; - - Ok(()) - } - - async fn on_resend_request(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if !self.state.is_connected() { - warn!("received resend request while disconnected, ignoring"); - return Ok(()); - } - - // Verify with check_too_high=false so ResendRequest is never blocked by seq-too-high. - // This is the key part of the QFJ-673 deadlock fix: when both sides send ResendRequest - // simultaneously, each side's ResendRequest will have a seq number higher than expected. - // By not treating that as an error, we allow the ResendRequest to be processed. - match self.verify_message(message, false, true) { - Ok(_) => {} - Err(err) => { - self.handle_verification_error(err).await?; - return Ok(()); - } - } - - let msg_seq_num = get_msg_seq_num(message); - let expected = self.store.next_target_seq_number(); - - // If seq is too high and we're in AwaitingResend, queue it for seq accounting - // when the gap fill catches up, but still process the resend below. - if msg_seq_num > expected - && let SessionState::AwaitingResend(state) = &mut self.state - { - state.inbound_queue.push_back(message.clone()); - } - - let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { - Ok(seq_number) => seq_number, - Err(_) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("missing begin sequence number for resend request"); - self.send_message(reject) - .await - .with_send_context("reject for missing BEGIN_SEQ_NO")?; - return Ok(()); - } }; - - let end_seq_number: u64 = match message.get(END_SEQ_NO) { - Ok(seq_number) => { - let last_seq_number = self.store.next_sender_seq_number() - 1; - if seq_number == 0 { - last_seq_number - } else { - std::cmp::min(seq_number, last_seq_number) - } - } - Err(_) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("missing end sequence number for resend request"); - self.send_message(reject) - .await - .with_send_context("reject for missing END_SEQ_NO")?; - return Ok(()); - } + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, }; - - // Only increment target seq if seq matches expected - if msg_seq_num == expected { - self.store.increment_target_seq_number().await?; - } - - self.resend_messages(begin_seq_number, end_seq_number, message) - .await?; - - Ok(()) - } - - /// Handle Reject messages. - async fn on_reject(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, false, true) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - self.store.increment_target_seq_number().await?; - Ok(()) + message_handling::handle_invalid_parsed_message(&mut ctx, writer, &message, reason).await } - async fn on_sequence_reset(&mut self, message: &Message) -> Result<(), SessionOperationError> { - let msg_seq_num = get_msg_seq_num(message); - let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); - if let Err(err) = self.verify_message(message, is_gap_fill, is_gap_fill) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - let end: u64 = match message.get(NEW_SEQ_NO) { - Ok(new_seq_no) => new_seq_no, - Err(err) => { - error!( - "received sequence reset message without new sequence number: {:?}", - err - ); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("missing NewSeqNo tag in sequence reset message"); - self.send_message(reject) - .await - .with_send_context("reject for missing NEW_SEQ_NO")?; - - // note: we don't increment the target seq number here - // this is an ambiguous case in the specification, but leaving the - // sequence number as is feels the safest - return Ok(()); - } - }; - - // sequence resets cannot move the target seq number backwards - // regardless of whether the message is a gap fill or not - if end <= self.store.next_target_seq_number() { - error!( - "received sequence reset message which would move target seq number backwards: {end}", - ); - let text = - format!("attempt to lower sequence number, invalid value NewSeqNo(36)={end}"); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::ValueIsIncorrect) - .text(&text); - self.send_message(reject) - .await - .with_send_context("reject for invalid sequence reset")?; - return Ok(()); - } - - self.store.set_target_seq_number(end - 1).await?; - Ok(()) - } - - async fn handle_verification_error( + async fn dispatch_valid_message( &mut self, - error: MessageVerificationError, + message: Message, ) -> Result<(), SessionOperationError> { - match error { - MessageVerificationError::SeqNumberTooLow { - expected, - actual, - possible_duplicate, - } => { - self.handle_sequence_number_too_low(expected, actual, possible_duplicate) - .await; - } - MessageVerificationError::SeqNumberTooHigh { expected, actual } => { - self.handle_sequence_number_too_high(expected, actual) - .await?; - } - MessageVerificationError::IncorrectBeginString(begin_string) => { - self.handle_incorrect_begin_string(begin_string).await; - } - MessageVerificationError::IncorrectCompId { - comp_id, - comp_id_type, - msg_seq_num, - } => { - self.handle_incorrect_comp_id(comp_id, comp_id_type, msg_seq_num) - .await; - } - MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { - self.handle_sending_time_accuracy_problem(msg_seq_num, "unexpected sending time") - .await; - } - MessageVerificationError::SendingTimeMissing { msg_seq_num } => { - self.handle_sending_time_accuracy_problem(msg_seq_num, "sending time missing") - .await; - } - MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { - self.handle_original_sending_time_missing(msg_seq_num).await; - } - MessageVerificationError::OriginalSendingTimeAfterSendingTime { - msg_seq_num, .. - } => { - self.handle_sending_time_accuracy_problem( - msg_seq_num, - "original sending time is after sending time", - ) - .await; - } - } - - Ok(()) + let transition = self.dispatch_to_state(message).await?; + self.apply_transition(transition).await } - async fn handle_incorrect_begin_string(&mut self, received_begin_string: String) { - self.logout_and_terminate(&format!( - "beginString={received_begin_string} is not supported" - )) - .await; - } - - async fn handle_incorrect_comp_id( + async fn dispatch_to_state( &mut self, - received_comp_id: String, - comp_id_type: CompIdType, - msg_seq_num: u64, - ) { - error!( - "rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})" - ); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::ValueIsIncorrect) - .text(&format!("invalid comp ID {received_comp_id}")); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject message with invalid comp ID: {err}"); + message: Message, + ) -> Result { + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + ref mut application, + .. + } = *self; + + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, }; - self.logout_and_terminate("incorrect comp ID received") - .await; - } - - async fn handle_sequence_number_too_low( - &mut self, - expected: u64, - actual: u64, - possible_duplicate: bool, - ) { - if possible_duplicate { - warn!( - "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" - ); - return; - } - error!( - "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." - ); - let reason = format!("sequence number too low (actual {actual}, expected {expected})"); - self.logout_and_terminate(&reason).await; - self.state = SessionState::new_disconnected(false, &reason); - } - - async fn handle_sequence_number_too_high( - &mut self, - expected: u64, - actual: u64, - ) -> Result<(), SessionOperationError> { - match self - .state - .try_transition_to_awaiting_resend(expected, actual) - { - AwaitingResendTransitionOutcome::Success => { - debug!( - "we are behind target (ours: {expected}, theirs: {actual}), requesting resend." - ); - self.send_resend_request(expected, actual).await?; - } - AwaitingResendTransitionOutcome::InvalidState(reason) => { - error!("failed to request resend: {reason}"); + let transition = match state { + SessionState::Active(s) => s.on_fix_message(&mut ctx, application, message).await?, + SessionState::AwaitingLogon(s) => { + s.on_fix_message(&mut ctx, application, message).await? } - AwaitingResendTransitionOutcome::BeginSeqNumberTooLow => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected( - false, - "awaiting resend begin seq number unexpectedly lower than the previous resend request's", - ); - } - AwaitingResendTransitionOutcome::AttemptsExceeded => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected( - false, - "resend request attempts exceeded, manual intervention required", - ); - } - } - - Ok(()) - } - - async fn handle_invalid_msg_type(&mut self, message: Message, msg_type: &str) { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidMsgtype) - .text(&format!("invalid message type {msg_type}")); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject message for invalid msgtype: {err}"); - }; - - #[allow(clippy::collapsible_if)] - if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) - && self.store.next_target_seq_number() == seq_num - { - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - }; - } + SessionState::AwaitingResend(s) => { + s.on_fix_message(&mut ctx, application, message).await? } - Err(err) => { - error!("failed to get message seq num: {:?}", err); + SessionState::AwaitingLogout(s) => { + s.on_fix_message(&mut ctx, application, message).await? } - } - } - - async fn handle_sending_time_accuracy_problem(&mut self, msg_seq_num: u64, text: &str) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) - .text(text); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject for time accuracy problem: {err}"); - }; - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); + SessionState::Disconnected(_) => TransitionResult::Stay, }; - } - async fn handle_original_sending_time_missing(&mut self, msg_seq_num: u64) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("original sending time is required"); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject for time missing tag: {err}"); - }; - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - }; + Ok(transition) } - async fn resend_messages( + async fn apply_transition( &mut self, - begin: u64, - end: u64, - _message: &Message, + transition: TransitionResult, ) -> Result<(), SessionOperationError> { - info!(begin, end, "resending messages as requested"); - let messages = self.store.get_slice(begin as usize, end as usize).await?; - - let no = messages.len(); - debug!(number_of_messages = no, "number of messages"); - - let mut reset_start: Option = None; - let mut sequence_number = 0; - - for msg in messages { - let mut message = self - .message_builder - .build(msg.as_slice()) - .into_message() - .ok_or_else(|| { - SessionOperationError::StoredMessageParse(format!( - "failed to build message for raw message: {msg:?}" - )) - })?; - sequence_number = get_msg_seq_num(&message); - let message_type: String = message - .header() - .get::<&str>(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? - .to_string(); - - if is_admin(&message_type) { - if reset_start.is_none() { - reset_start = Some(sequence_number); + match transition { + TransitionResult::Stay => {} + TransitionResult::TransitionTo(new_state) => { + self.state = new_state; + } + TransitionResult::TransitionWithBacklog { + new_state, + mut backlog, + } => { + self.state = new_state; + debug!("resend is done, processing backlog"); + while let Some(msg) = backlog.pop_front() { + let seq_number: u64 = msg.get(MSG_SEQ_NUM).unwrap_or_else(|e| { + error!("failed to get seq number: {:?}", e); + 0 + }); + let msg_type: &str = msg + .header() + .get(hotfix_message::session_fields::MSG_TYPE) + .unwrap_or(""); + debug!(seq_number, msg_type, "processing queued message"); + + if msg_type == ResendRequest::MSG_TYPE { + // ResendRequest was already processed when it arrived (it bypasses + // the queue). Just increment the target seq number + // for sequence accounting purposes. + self.store.increment_target_seq_number().await?; + } else { + let inner_transition = self.dispatch_to_state(msg).await?; + // Backlog messages can't produce more backlogs (only AwaitingResend + // produces TransitionWithBacklog, and we've already transitioned to Active) + if let TransitionResult::TransitionTo(s) = inner_transition { + self.state = s; + } + } } - continue; + debug!("resend backlog is cleared, resuming normal operation"); } - - if let Some(begin) = reset_start { - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(begin, end).await?; - reset_start = None; - } - - if let Err(e) = prepare_message_for_resend(&mut message) { - error!( - error = e, - "failed to prepare message for resend, sending original" - ); - } - self.send_raw(&message_type, message.encode(&self.message_config)?) - .await; - - if enabled!(tracing::Level::DEBUG) - && let Ok(m) = String::from_utf8(msg.clone()) - { - debug!(sequence_number, message = m, "resent message"); - } - } - - if let Some(begin) = reset_start { - // the final reset if needed - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(begin, end).await?; } - Ok(()) } - fn log_skipped_admin_messages(begin: u64, end: u64) { - info!( - begin, - end, "skipped admin message(s) during resend, requesting reset for these" - ); - } - - fn reset_heartbeat_timer(&mut self) { - self.state - .reset_heartbeat_timer(self.config.heartbeat_interval); - } - - fn reset_peer_timer(&mut self, test_request_id: Option) { - self.state - .reset_peer_timer(self.config.heartbeat_interval, test_request_id); - } - - async fn send_app_message(&mut self, message: App::Outbound) -> Result { - if !self.state.is_connected() { - return Err(SendError::Disconnected); - } - - match self.application.on_outbound_message(&message).await { - OutboundDecision::Send => { - let sequence_number = self.send_message(message).await?; - Ok(SendOutcome::Sent { sequence_number }) - } - OutboundDecision::Drop => { - debug!("dropped outbound message as instructed by the application"); - Ok(SendOutcome::Dropped) - } - OutboundDecision::TerminateSession => { - warn!("the application indicated we should terminate the session"); - self.state.disconnect_writer().await; - Err(SendError::SessionTerminated) - } + async fn on_connect(&mut self, writer: WriterRef) -> Result<(), SessionOperationError> { + if let SessionState::Disconnected(s) = &self.state { + self.state = s.on_connect(writer, Duration::from_secs(self.config.logon_timeout)); } - } - - async fn send_message( - &mut self, - message: impl OutboundMessage, - ) -> Result { - let seq_num = self.store.next_sender_seq_number(); - let msg_type = message.message_type().to_string(); - let msg = generate_message( - &self.config.begin_string, - &self.config.sender_comp_id, - &self.config.target_comp_id, - seq_num, - message, - ) - .map_err(|e| { - InternalSendError::Persist(crate::store::StoreError::PersistMessage { - sequence_number: seq_num, - source: e.into(), - }) - })?; - - self.store - .increment_sender_seq_number() - .await - .map_err(InternalSendError::SequenceNumber)?; - - self.store - .add(seq_num, &msg) - .await - .map_err(InternalSendError::Persist)?; - - self.send_raw(&msg_type, msg).await; - - Ok(seq_num) - } - - async fn send_raw(&mut self, message_type: &str, data: Vec) { - self.state - .send_message(message_type, RawFixMessage::new(data)) - .await; - self.reset_heartbeat_timer(); - } - - async fn send_sequence_reset( - &mut self, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - let sequence_reset = SequenceReset { - gap_fill: true, - new_seq_no: end, - }; - let raw_message = generate_message( - &self.config.begin_string, - &self.config.sender_comp_id, - &self.config.target_comp_id, - begin, - sequence_reset, - )?; - - self.send_raw(SequenceReset::MSG_TYPE, raw_message).await; - debug!(begin, end, "sent reset sequence"); + // Reset peer timer on the new AwaitingLogon state — no-op since it uses logon_timeout + // Send logon + self.send_logon().await?; Ok(()) } - async fn send_resend_request( - &mut self, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - let request = ResendRequest::new(begin, end); - self.send_message(request) - .await - .with_send_context("resend request")?; - Ok(()) + async fn on_disconnect(&mut self, reason: String) { + let transition = match &self.state { + SessionState::Active(s) => Some(s.on_disconnect(&reason).await), + SessionState::AwaitingLogon(s) => Some(s.on_disconnect(&reason).await), + SessionState::AwaitingResend(s) => Some(s.on_disconnect(&reason).await), + SessionState::AwaitingLogout(s) => Some(s.on_disconnect(&reason).await), + SessionState::Disconnected(_) => { + warn!("disconnect message was received, but the session is already disconnected"); + None + } + }; + if let Some(new_state) = transition { + self.state = new_state; + } } async fn send_logon(&mut self) -> Result<(), SessionOperationError> { @@ -967,49 +283,61 @@ where self.reset_on_next_logon = false; let logon = Logon::new(self.config.heartbeat_interval, reset_config); - self.send_message(logon).await.with_send_context("logon")?; + let writer = self.state.get_writer(); + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + if let Some(writer) = writer { + ctx.send_message(writer, logon) + .await + .with_send_context("logon")?; + } Ok(()) } async fn send_logout(&mut self, reason: &str) -> Result<(), SessionOperationError> { let logout = Logout::with_reason(reason.to_string()); - self.send_message(logout) - .await - .with_send_context("logout")?; - Ok(()) - } - - /// Sends a logout message and immediately disconnects the counterparty. - /// - /// This should be used sparingly in scenarios where there is a major issue - /// requiring operational intervention, such as the sequence number being lower - /// than expected, or some other key header field containing an invalid value. - /// - /// In other scenarios, [`initiate_graceful_logout`] should be preferred. - async fn logout_and_terminate(&mut self, reason: &str) { - if let Err(err) = self.send_logout(reason).await { - warn!("failed to send logout during session termination: {}", err); + let writer = self.state.get_writer(); + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + if let Some(writer) = writer { + ctx.send_message(writer, logout) + .await + .with_send_context("logout")?; } - self.state.disconnect_writer().await; + Ok(()) } - /// Sends a logout message and puts the session state into an [`AwaitingLogout`] state. - /// - /// The session waits for a configurable timeout period for the counterparty to - /// respond with a `Logout` message. If no response is received within the timeout - /// period, it disconnects the counterparty. + /// Sends a logout message and puts the session state into an AwaitingLogout state. async fn initiate_graceful_logout( &mut self, reason: &str, reconnect: bool, ) -> Result<(), SessionOperationError> { - if self.state.try_transition_to_awaiting_logout( - Duration::from_secs(self.config.logout_timeout), - reconnect, - ) { - self.send_logout(reason).await?; + if matches!(self.state, SessionState::AwaitingLogout(_)) { + debug!("already in awaiting logout state"); + return Ok(()); } + let Some(writer) = self.state.get_writer().cloned() else { + error!("trying to transition to awaiting logout without an established connection"); + return Ok(()); + }; + + self.state = SessionState::AwaitingLogout(state::AwaitingLogoutState::new( + writer, + Instant::now() + Duration::from_secs(self.config.logout_timeout), + reconnect, + )); + self.send_logout(reason).await?; + Ok(()) } @@ -1021,7 +349,15 @@ where if let Err(err) = self.on_incoming(fix_message).await { let reason = err.to_string(); error!(reason, "fatal error in message processing"); - self.logout_and_terminate("internal error").await; + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "internal error") + .await; self.state = SessionState::new_disconnected(true, &reason); } } @@ -1047,10 +383,31 @@ where async fn handle_outbound_message(&mut self, request: OutboundRequest) { let OutboundRequest { message, confirm } = request; - let result = self.send_app_message(message).await; + + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + ref mut application, + .. + } = *self; + + let result = if let SessionState::Active(s) = state { + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + s.send_app_message(&mut ctx, application, message).await + } else { + Err(SendError::Disconnected) + }; + match confirm { Some(tx) => { - // Ignore send errors - receiver may have been dropped let _ = tx.send(result); } None => { @@ -1086,29 +443,55 @@ where } async fn handle_heartbeat_timeout(&mut self) { - if let Err(err) = self.send_message(Heartbeat::default()).await { - error!(err = ?err, "failed to send heartbeat message"); + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + .. + } = *self; + if let SessionState::Active(active) = state { + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + active.on_heartbeat_timeout(&mut ctx).await; } } async fn handle_peer_timeout(&mut self) { - if self.state.is_expecting_test_response() { - warn!("peer didn't respond, terminating.."); - self.logout_and_terminate("peer timeout").await; - } else if self.state.is_awaiting_logon() { - warn!("peer didn't respond to our Logon, disconnecting.."); - self.state.disconnect_writer().await; - } else if self.state.is_awaiting_logout() { - warn!("peer didn't respond to our Logout, disconnecting.."); - self.state.disconnect_writer().await; - } else { - let req_id = format!("TEST_{}", self.store.next_target_seq_number()); - info!("sending TestRequest due to peer timer expiring"); - let request = TestRequest::new(req_id.clone()); - if let Err(err) = self.send_message(request).await { - error!(err = ?err, "failed to send TestRequest"); + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + .. + } = *self; + let transition = match state { + SessionState::Active(active) => { + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + active.on_peer_timeout(&mut ctx).await } - self.reset_peer_timer(Some(req_id)); + SessionState::AwaitingLogon(awaiting_logon) => { + awaiting_logon.on_peer_timeout().await; + None + } + SessionState::AwaitingLogout(awaiting_logout) => { + Some(awaiting_logout.on_peer_timeout().await) + } + _ => None, + }; + if let Some(new_state) = transition { + self.state = new_state; } } @@ -1126,9 +509,15 @@ where // we are in the same period, nothing needs to be done } Ok(SessionPeriodComparison::DifferentPeriod) => { - // the message store is for a previous session, - // we need to terminate this session, reset the store, and reestablish the session - self.logout_and_terminate("session period changed").await; + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "session period changed") + .await; if let Err(err) = self.store.reset().await { error!("error resetting session store: {err:}"); self.state = @@ -1136,10 +525,15 @@ where } } Ok(SessionPeriodComparison::OutsideSessionTime { .. }) => { - // the creation_time was recorded outside the session schedule, - // treat this similarly to a different period - reset the store warn!("store creation time is outside session schedule, resetting store"); - self.logout_and_terminate("creation time outside schedule") + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "creation time outside schedule") .await; if let Err(err) = self.store.reset().await { error!("error resetting session store: {err:}"); @@ -1148,22 +542,26 @@ where } } Err(err) => { - // actual schedule calculation error (e.g., DST transition, date overflow) error!("error checking session period: {err:?}"); - self.logout_and_terminate("internal error").await; + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "internal error") + .await; } } - } else if self.state.is_connected() { - // we are currently outside scheduled session time - if let Err(err) = self + } else if self.state.is_connected() + && let Err(err) = self .initiate_graceful_logout("End of session time", true) .await - { - error!(err = ?err, "failed to initiate graceful logout"); - } + { + error!(err = ?err, "failed to initiate graceful logout"); } - // we always need to reschedule the check, otherwise we won't be able to resume an inactive session let deadline = Instant::now() + Duration::from_secs(SCHEDULE_CHECK_INTERVAL); self.schedule_check_timer.as_mut().reset(deadline); } @@ -1179,14 +577,11 @@ where /// Extracts MsgSeqNum from a message header. /// -/// To be removed once https://github.com/Validus-Risk-Management/hotfix/issues/301 -/// is implemented. -/// /// # Panics /// Panics if the message does not contain a valid MsgSeqNum field. /// This should never happen for messages that have passed validation. #[allow(clippy::expect_used)] -fn get_msg_seq_num(message: &Message) -> u64 { +pub(crate) fn get_msg_seq_num(message: &Message) -> u64 { message .header() .get(MSG_SEQ_NUM) @@ -1491,8 +886,6 @@ mod tests { session.handle_schedule_check().await; // Store reset should have been called (indicates DifferentPeriod branch was taken) - // Note: logout_and_terminate disconnects the writer but state transition to - // Disconnected happens asynchronously via event processing, not in this call assert!( session.store.was_reset_called(), "Store reset should be called for different period" @@ -1557,7 +950,6 @@ mod tests { let state = SessionState::new_active(writer, 30); // Creation time is today but at a time outside the schedule window - // Use a time that's definitely outside the window (6 hours from now) let outside_hour = (current_hour + 6) % 24; let creation_time = DateTime::from_naive_utc_and_offset( NaiveDate::from_ymd_opt(now.year(), now.month(), now.day()) @@ -1606,7 +998,7 @@ mod tests { // State should be AwaitingLogout (graceful logout initiated) assert!( - session.state.is_awaiting_logout(), + matches!(session.state, SessionState::AwaitingLogout(_)), "State should be AwaitingLogout when schedule is inactive and was connected" ); } diff --git a/crates/hotfix/src/session/ctx.rs b/crates/hotfix/src/session/ctx.rs new file mode 100644 index 00000000..e38c467a --- /dev/null +++ b/crates/hotfix/src/session/ctx.rs @@ -0,0 +1,89 @@ +use crate::config::SessionConfig; +use crate::message::parser::RawFixMessage; +use crate::message::{OutboundMessage, generate_message}; +use crate::session::error::InternalSendError; +use crate::session::state::SessionState; +use crate::store::StoreError; +use crate::transport::writer::WriterRef; +use hotfix_message::MessageBuilder; +use hotfix_message::message::{Config as MessageConfig, Message}; +use hotfix_store::MessageStore; +use std::collections::VecDeque; + +pub(crate) struct SessionCtx<'a, Store> { + pub config: &'a SessionConfig, + pub store: &'a mut Store, + pub message_builder: &'a MessageBuilder, + pub message_config: &'a MessageConfig, +} + +pub(crate) struct PreparedMessage { + pub seq_num: u64, + #[allow(dead_code)] + pub msg_type: String, + pub raw: RawFixMessage, +} + +pub(crate) enum TransitionResult { + Stay, + TransitionTo(SessionState), + TransitionWithBacklog { + new_state: SessionState, + backlog: VecDeque, + }, +} + +pub(crate) enum VerifyResult { + Passed, + SeqTooHigh { expected: u64, actual: u64 }, + Handled(TransitionResult), +} + +impl SessionCtx<'_, Store> { + pub async fn prepare_message( + &mut self, + message: impl OutboundMessage, + ) -> Result { + let seq_num = self.store.next_sender_seq_number(); + let msg_type = message.message_type().to_string(); + let msg = generate_message( + &self.config.begin_string, + &self.config.sender_comp_id, + &self.config.target_comp_id, + seq_num, + message, + ) + .map_err(|e| { + InternalSendError::Persist(StoreError::PersistMessage { + sequence_number: seq_num, + source: e.into(), + }) + })?; + + self.store + .increment_sender_seq_number() + .await + .map_err(InternalSendError::SequenceNumber)?; + self.store + .add(seq_num, &msg) + .await + .map_err(InternalSendError::Persist)?; + + Ok(PreparedMessage { + seq_num, + msg_type, + raw: RawFixMessage::new(msg), + }) + } + + /// Prepare, persist, and send a message via the given writer. + pub async fn send_message( + &mut self, + writer: &WriterRef, + message: impl OutboundMessage, + ) -> Result { + let prepared = self.prepare_message(message).await?; + writer.send_raw_message(prepared.raw).await; + Ok(prepared.seq_num) + } +} diff --git a/crates/hotfix/src/session/message_handling.rs b/crates/hotfix/src/session/message_handling.rs new file mode 100644 index 00000000..90785239 --- /dev/null +++ b/crates/hotfix/src/session/message_handling.rs @@ -0,0 +1,402 @@ +use crate::message::logout::Logout; +use crate::message::parser::RawFixMessage; +use crate::message::reject::Reject; +use crate::message::sequence_reset::SequenceReset; +use crate::message::verification::verify_message as verify_message_impl; +use crate::message::verification_error::{CompIdType, MessageVerificationError}; +use crate::message::{generate_message, is_admin, prepare_message_for_resend}; +use crate::session::ctx::{SessionCtx, TransitionResult, VerifyResult}; +use crate::session::error::{InternalSendResultExt, SessionOperationError}; +use crate::session::get_msg_seq_num; +use crate::session::state::SessionState; +use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::message::Message; +use hotfix_message::parsed_message::InvalidReason; +use hotfix_message::session_fields::{MSG_SEQ_NUM, MSG_TYPE, SessionRejectReason}; +use hotfix_store::MessageStore; +use tracing::{debug, enabled, error, info, warn}; + +fn verify_message( + ctx: &SessionCtx<'_, Store>, + message: &Message, + check_too_high: bool, + check_too_low: bool, +) -> Result<(), MessageVerificationError> { + let expected_seq_number = if check_too_high || check_too_low { + Some(ctx.store.next_target_seq_number()) + } else { + None + }; + verify_message_impl( + message, + ctx.config, + expected_seq_number, + check_too_high, + check_too_low, + ) +} + +/// Verify a message and handle the error if verification fails. +/// For SeqNumberTooHigh, returns `VerifyResult::SeqTooHigh` instead of handling it, +/// allowing the caller to handle the transition. +pub async fn verify_and_handle( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + message: &Message, + check_too_high: bool, + check_too_low: bool, +) -> Result { + match verify_message(ctx, message, check_too_high, check_too_low) { + Ok(()) => Ok(VerifyResult::Passed), + Err(MessageVerificationError::SeqNumberTooHigh { expected, actual }) => { + Ok(VerifyResult::SeqTooHigh { expected, actual }) + } + Err(err) => { + let transition = handle_verification_error(ctx, writer, err).await?; + Ok(VerifyResult::Handled(transition)) + } + } +} + +/// Handle a verification error (excluding SeqNumberTooHigh which is returned separately). +/// Returns the `TransitionResult` to use — either `Stay` (error was handled in-place) +/// or `TransitionTo` (a state change is needed). +async fn handle_verification_error( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + error: MessageVerificationError, +) -> Result { + match error { + MessageVerificationError::SeqNumberTooLow { + expected, + actual, + possible_duplicate, + } => Ok( + handle_sequence_number_too_low(ctx, writer, expected, actual, possible_duplicate).await, + ), + MessageVerificationError::SeqNumberTooHigh { expected, actual } => { + // This shouldn't be called for SeqTooHigh anymore (it's returned via VerifyResult), + // but handle gracefully if it is. + warn!( + "handle_verification_error called with SeqNumberTooHigh({expected}, {actual}) - caller should use verify_and_handle" + ); + Ok(TransitionResult::Stay) + } + MessageVerificationError::IncorrectBeginString(begin_string) => { + let new_state = handle_incorrect_begin_string(ctx, writer, begin_string).await; + Ok(TransitionResult::TransitionTo(new_state)) + } + MessageVerificationError::IncorrectCompId { + comp_id, + comp_id_type, + msg_seq_num, + } => { + let new_state = + handle_incorrect_comp_id(ctx, writer, comp_id, comp_id_type, msg_seq_num).await; + Ok(TransitionResult::TransitionTo(new_state)) + } + MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { + handle_sending_time_accuracy_problem( + ctx, + writer, + msg_seq_num, + "unexpected sending time", + ) + .await; + Ok(TransitionResult::Stay) + } + MessageVerificationError::SendingTimeMissing { msg_seq_num } => { + handle_sending_time_accuracy_problem(ctx, writer, msg_seq_num, "sending time missing") + .await; + Ok(TransitionResult::Stay) + } + MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { + handle_original_sending_time_missing(ctx, writer, msg_seq_num).await; + Ok(TransitionResult::Stay) + } + MessageVerificationError::OriginalSendingTimeAfterSendingTime { msg_seq_num, .. } => { + handle_sending_time_accuracy_problem( + ctx, + writer, + msg_seq_num, + "original sending time is after sending time", + ) + .await; + Ok(TransitionResult::Stay) + } + } +} + +async fn handle_incorrect_begin_string( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + received_begin_string: String, +) -> SessionState { + logout_and_terminate( + ctx, + writer, + &format!("beginString={received_begin_string} is not supported"), + ) + .await; + SessionState::new_disconnected(true, "incorrect begin string") +} + +async fn handle_incorrect_comp_id( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + received_comp_id: String, + comp_id_type: CompIdType, + msg_seq_num: u64, +) -> SessionState { + error!("rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})"); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&format!("invalid comp ID {received_comp_id}")); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject message with invalid comp ID: {err}"); + } + logout_and_terminate(ctx, writer, "incorrect comp ID received").await; + SessionState::new_disconnected(true, "incorrect comp ID") +} + +async fn handle_sequence_number_too_low( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + expected: u64, + actual: u64, + possible_duplicate: bool, +) -> TransitionResult { + if possible_duplicate { + warn!( + "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" + ); + return TransitionResult::Stay; + } + error!( + "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." + ); + let reason = format!("sequence number too low (actual {actual}, expected {expected})"); + logout_and_terminate(ctx, writer, &reason).await; + TransitionResult::TransitionTo(SessionState::new_disconnected(false, &reason)) +} + +async fn handle_sending_time_accuracy_problem( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + msg_seq_num: u64, + text: &str, +) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) + .text(text); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject for time accuracy problem: {err}"); + } + if let Err(err) = ctx.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } +} + +async fn handle_original_sending_time_missing( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + msg_seq_num: u64, +) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("original sending time is required"); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject for time missing tag: {err}"); + } + if let Err(err) = ctx.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } +} + +/// Send a logout message and immediately disconnect. +pub(crate) async fn logout_and_terminate( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + reason: &str, +) { + let logout = Logout::with_reason(reason.to_string()); + match ctx.prepare_message(logout).await { + Ok(prepared) => writer.send_raw_message(prepared.raw).await, + Err(err) => warn!("failed to send logout during session termination: {err}"), + } + writer.disconnect().await; +} + +pub async fn resend_messages( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + begin: u64, + end: u64, +) -> Result<(), SessionOperationError> { + info!(begin, end, "resending messages as requested"); + let messages = ctx.store.get_slice(begin as usize, end as usize).await?; + + let no = messages.len(); + debug!(number_of_messages = no, "number of messages"); + + let mut reset_start: Option = None; + let mut sequence_number = 0; + + for msg in messages { + let mut message = ctx + .message_builder + .build(msg.as_slice()) + .into_message() + .ok_or_else(|| { + SessionOperationError::StoredMessageParse(format!( + "failed to build message for raw message: {msg:?}" + )) + })?; + sequence_number = get_msg_seq_num(&message); + let message_type: String = message + .header() + .get::<&str>(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? + .to_string(); + + if is_admin(&message_type) { + if reset_start.is_none() { + reset_start = Some(sequence_number); + } + continue; + } + + if let Some(begin) = reset_start { + let end = sequence_number; + log_skipped_admin_messages(begin, end); + send_sequence_reset(ctx, writer, begin, end).await?; + reset_start = None; + } + + if let Err(e) = prepare_message_for_resend(&mut message) { + error!( + error = e, + "failed to prepare message for resend, sending original" + ); + } + writer + .send_raw_message(RawFixMessage::new(message.encode(ctx.message_config)?)) + .await; + + if enabled!(tracing::Level::DEBUG) + && let Ok(m) = String::from_utf8(msg.clone()) + { + debug!(sequence_number, message = m, "resent message"); + } + } + + if let Some(begin) = reset_start { + // the final reset if needed + let end = sequence_number; + log_skipped_admin_messages(begin, end); + send_sequence_reset(ctx, writer, begin, end).await?; + } + + Ok(()) +} + +async fn send_sequence_reset( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + begin: u64, + end: u64, +) -> Result<(), SessionOperationError> { + let sequence_reset = SequenceReset { + gap_fill: true, + new_seq_no: end, + }; + let raw_message = generate_message( + &ctx.config.begin_string, + &ctx.config.sender_comp_id, + &ctx.config.target_comp_id, + begin, + sequence_reset, + )?; + + writer + .send_raw_message(RawFixMessage::new(raw_message)) + .await; + debug!(begin, end, "sent reset sequence"); + + Ok(()) +} + +fn log_skipped_admin_messages(begin: u64, end: u64) { + info!( + begin, + end, "skipped admin message(s) during resend, requesting reset for these" + ); +} + +pub async fn handle_invalid_parsed_message( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + message: &Message, + reason: InvalidReason, +) -> Result<(), SessionOperationError> { + match reason { + InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidTagNumber) + .text(&format!("invalid field {tag}")); + ctx.send_message(writer, reject) + .await + .with_send_context("reject for invalid field")?; + } + } + InvalidReason::InvalidComponent(_component_name) => { + warn!("received invalid component"); + } + InvalidReason::InvalidMsgType(msg_type) => { + handle_invalid_msg_type(ctx, writer, message, &msg_type).await; + } + InvalidReason::InvalidOrderInGroup { tag, .. } => { + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RepeatingGroupFieldsOutOfOrder) + .text(&format!("field appears in incorrect order:{tag}")); + ctx.send_message(writer, reject) + .await + .with_send_context("reject for invalid group order")?; + } + } + } + Ok(()) +} + +async fn handle_invalid_msg_type( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + message: &Message, + msg_type: &str, +) { + match message.header().get(MSG_SEQ_NUM) { + Ok(msg_seq_num) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidMsgtype) + .text(&format!("invalid message type {msg_type}")); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject message for invalid msgtype: {err}"); + } + + #[allow(clippy::collapsible_if)] + if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) + && ctx.store.next_target_seq_number() == seq_num + { + if let Err(err) = ctx.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + } + Err(err) => { + error!("failed to get message seq num: {:?}", err); + } + } +} diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index fa84472d..0b2f76d6 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -1,36 +1,36 @@ -use crate::message::logon::Logon; -use crate::message::logout::Logout; -use crate::message::parser::RawFixMessage; +mod active; +mod awaiting_logon; +mod awaiting_logout; +mod awaiting_resend; +mod disconnected; + +pub(crate) use crate::session::ctx::{SessionCtx, TransitionResult, VerifyResult}; +pub(crate) use active::{ActiveState, calculate_peer_interval}; +pub(crate) use awaiting_logon::AwaitingLogonState; +pub(crate) use awaiting_logout::AwaitingLogoutState; +pub(crate) use awaiting_resend::AwaitingResendState; +pub(crate) use disconnected::DisconnectedState; + use crate::session::event::AwaitingActiveSessionResponse; use crate::session::info::Status as SessionInfoStatus; use crate::transport::writer::WriterRef; -use hotfix_message::message::Message; -use std::collections::VecDeque; +use hotfix_store::MessageStore; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; -use tracing::{debug, error}; +use tracing::error; -const TEST_REQUEST_THRESHOLD: f64 = 1.2; -const MAX_RESEND_ATTEMPTS: usize = 3; +pub(crate) const TEST_REQUEST_THRESHOLD: f64 = 1.2; pub(crate) type TestRequestId = String; pub enum SessionState { /// We have established a connection, sent a logon message and await a response. - AwaitingLogon { - writer: WriterRef, - logon_sent: bool, - logon_timeout: Instant, - }, + AwaitingLogon(AwaitingLogonState), /// We are awaiting the target to resend the gap we have. AwaitingResend(AwaitingResendState), /// We are in the process of gracefully logging out - AwaitingLogout { - writer: WriterRef, // we need the writer so we can disconnect it on successful logout - logout_timeout: Instant, - reconnect: bool, // we carry this forward for the subsequent disconnected state - }, + AwaitingLogout(AwaitingLogoutState), /// The session is active, we have connected and mutually logged on. Active(ActiveState), /// The TCP connection has been dropped. @@ -58,121 +58,32 @@ impl SessionState { pub fn should_reconnect(&self) -> bool { match self { - SessionState::Disconnected(DisconnectedState { reconnect, .. }) => *reconnect, + SessionState::Disconnected(state) => state.should_reconnect(), _ => true, } } - pub async fn send_message(&mut self, message_type: &str, message: RawFixMessage) { - match self { - Self::Active(ActiveState { writer, .. }) - | Self::AwaitingResend(AwaitingResendState { writer, .. }) => { - if message_type == Logon::MSG_TYPE { - error!("logon message is invalid for active sessions") - } else { - writer.send_raw_message(message).await - } - } - Self::AwaitingLogon { - writer, logon_sent, .. - } => match message_type { - Logon::MSG_TYPE => { - if *logon_sent { - error!("trying to send logon twice"); - } else { - writer.send_raw_message(message).await; - *logon_sent = true; - } - } - Logout::MSG_TYPE => { - writer.send_raw_message(message).await; - } - _ => error!("invalid outgoing message for AwaitingLogon state"), - }, - Self::AwaitingLogout { writer, .. } => { - // Logout messages are allowed because we first transition into AwaitingLogout - // and only then send the logout message - if message_type == Logout::MSG_TYPE { - writer.send_raw_message(message).await - } - } - _ => error!("trying to write without an established connection"), - } - } - - pub async fn disconnect_writer(&self) { - match self { - Self::Active(ActiveState { writer, .. }) - | Self::AwaitingLogon { writer, .. } - | Self::AwaitingLogout { writer, .. } - | Self::AwaitingResend(AwaitingResendState { writer, .. }) => writer.disconnect().await, - _ => debug!("disconnecting an already disconnected session"), - } + pub(crate) fn is_connected(&self) -> bool { + self.get_writer().is_some() } - fn get_writer(&self) -> Option<&WriterRef> { + pub(crate) fn get_writer(&self) -> Option<&WriterRef> { match self { Self::Active(ActiveState { writer, .. }) - | Self::AwaitingLogon { writer, .. } - | Self::AwaitingLogout { writer, .. } + | Self::AwaitingLogon(AwaitingLogonState { writer, .. }) + | Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) | Self::AwaitingResend(AwaitingResendState { writer, .. }) => Some(writer), _ => None, } } - pub fn try_transition_to_awaiting_logout( - &mut self, - logout_timeout: Duration, - reconnect: bool, - ) -> bool { - if matches!(self, SessionState::AwaitingLogout { .. }) { - debug!("already in awaiting logout state"); - return false; - } - - if let Some(writer) = self.get_writer() { - *self = SessionState::AwaitingLogout { - writer: writer.clone(), - logout_timeout: Instant::now() + logout_timeout, - reconnect, - }; - true - } else { - error!("trying to transition to awaiting logout without an established connection"); - false - } - } - - pub fn try_transition_to_awaiting_resend( - &mut self, - begin: u64, - end: u64, - ) -> AwaitingResendTransitionOutcome { - match self { - SessionState::AwaitingLogon { writer, .. } - | SessionState::Active(ActiveState { writer, .. }) => { - let awaiting_resend = AwaitingResendState::new(writer.to_owned(), begin, end); - *self = SessionState::AwaitingResend(awaiting_resend); - AwaitingResendTransitionOutcome::Success - } - SessionState::AwaitingResend(state) => state.update(begin, end), - SessionState::AwaitingLogout { .. } => AwaitingResendTransitionOutcome::InvalidState( - "trying to request a resend while we are already logging out".to_string(), - ), - SessionState::Disconnected(_) => AwaitingResendTransitionOutcome::InvalidState( - "trying to transition to awaiting resend without an established connection" - .to_string(), - ), - } - } - pub fn register_session_awaiter( &mut self, responder: oneshot::Sender, ) { match self { SessionState::Disconnected(state) => { - if state.has_session_awaiter() { + if let Err(responder) = state.register_session_awaiter(responder) { let reason = &state.reason; error!( "session awaiter already registered on state disconnected due to: {reason}" @@ -180,9 +91,6 @@ impl SessionState { if let Err(err) = responder.send(AwaitingActiveSessionResponse::Shutdown) { error!("failed to send session awaiter response: {err:?}"); } - } else { - state.set_session_awaiter(responder); - debug!("registered session awaiter"); } } _ => { @@ -195,95 +103,49 @@ impl SessionState { } pub fn notify_session_awaiter(&mut self) { - if let SessionState::Disconnected(state) = self - && let Some(awaiter) = state.take_session_awaiter() - { - if let Err(err) = awaiter.send(AwaitingActiveSessionResponse::Active) { - error!("failed to send session awaiter response: {err:?}"); - } else { - debug!("notified session awaiter"); - } + if let SessionState::Disconnected(state) = self { + state.notify_session_awaiter(); } } - pub fn heartbeat_deadline(&self) -> Option<&Instant> { - match self { - Self::Active(ActiveState { - heartbeat_deadline, .. - }) => Some(heartbeat_deadline), - _ => None, - } - } - - pub fn reset_heartbeat_timer(&mut self, heartbeat_interval: u64) { - if let Self::Active(ActiveState { - heartbeat_deadline, .. - }) = self - { - *heartbeat_deadline = Instant::now() + Duration::from_secs(heartbeat_interval); + /// Send a logout message and immediately disconnect, if connected. + pub(crate) async fn logout_and_terminate( + &self, + ctx: &mut SessionCtx<'_, Store>, + reason: &str, + ) { + if let Some(writer) = self.get_writer() { + super::message_handling::logout_and_terminate(ctx, writer, reason).await; } } - pub fn peer_deadline(&self) -> Option<&Instant> { + pub fn heartbeat_deadline(&self) -> Option<&Instant> { match self { - Self::Active(ActiveState { peer_deadline, .. }) => Some(peer_deadline), - Self::AwaitingLogon { logon_timeout, .. } => Some(logon_timeout), - Self::AwaitingLogout { logout_timeout, .. } => Some(logout_timeout), + Self::Active(state) => Some(state.heartbeat_deadline()), _ => None, } } - pub fn reset_peer_timer( - &mut self, - heartbeat_interval: u64, - test_request_id: Option, - ) { - if let Self::Active(ActiveState { - peer_deadline, - sent_test_request_id, - .. - }) = self - { - let interval = calculate_peer_interval(heartbeat_interval); - *peer_deadline = Instant::now() + Duration::from_secs(interval); - *sent_test_request_id = test_request_id; - } - } - - pub fn expected_test_response_id(&self) -> Option<&TestRequestId> { + pub fn peer_deadline(&self) -> Option<&Instant> { match self { - Self::Active(ActiveState { - sent_test_request_id: expected_test_response_id, - .. - }) => expected_test_response_id.as_ref(), + Self::Active(state) => Some(state.peer_deadline()), + Self::AwaitingLogon(AwaitingLogonState { logon_timeout, .. }) => Some(logon_timeout), + Self::AwaitingLogout(AwaitingLogoutState { logout_timeout, .. }) => { + Some(logout_timeout) + } _ => None, } } - pub fn is_connected(&self) -> bool { - self.get_writer().is_some() - } - + #[cfg(test)] pub fn is_logged_on(&self) -> bool { matches!(self, SessionState::Active(_)) || matches!(self, SessionState::AwaitingResend { .. }) } - pub fn is_expecting_test_response(&self) -> bool { - self.expected_test_response_id().is_some() - } - - pub fn is_awaiting_logon(&self) -> bool { - matches!(self, SessionState::AwaitingLogon { .. }) - } - - pub fn is_awaiting_logout(&self) -> bool { - matches!(self, SessionState::AwaitingLogout { .. }) - } - pub fn as_status(&self) -> SessionInfoStatus { match self { - SessionState::AwaitingLogon { .. } => SessionInfoStatus::AwaitingLogon, + SessionState::AwaitingLogon(_) => SessionInfoStatus::AwaitingLogon, SessionState::AwaitingResend(AwaitingResendState { begin_seq_number, end_seq_number, @@ -294,165 +156,9 @@ impl SessionState { end: *end_seq_number, attempts: *resend_attempts, }, - SessionState::AwaitingLogout { .. } => SessionInfoStatus::AwaitingLogout, + SessionState::AwaitingLogout(_) => SessionInfoStatus::AwaitingLogout, SessionState::Active(_) => SessionInfoStatus::Active, SessionState::Disconnected(_) => SessionInfoStatus::Disconnected, } } } - -#[inline] -fn calculate_peer_interval(heartbeat_interval: u64) -> u64 { - (heartbeat_interval as f64 * TEST_REQUEST_THRESHOLD).round() as u64 -} - -pub struct ActiveState { - /// The writer's reference to send messages to the counterparty - writer: WriterRef, - /// When we should send the next heartbeat message to the counterparty - heartbeat_deadline: Instant, - /// When the next message from the counterparty is expected at the latest - peer_deadline: Instant, - /// The ID of the test request we sent on peer timer expiry - sent_test_request_id: Option, -} - -/// Session state we're in while processing messages we requested to be resent. -pub struct AwaitingResendState { - /// The reference to the writer loop. - pub(crate) writer: WriterRef, - /// The beginning of the gap we're waiting for the target to resend. - pub(crate) begin_seq_number: u64, - /// The end of the gap we're waiting for the target to resend. - pub(crate) end_seq_number: u64, - /// Inbound messages we receive while processing the resend. - pub(crate) inbound_queue: VecDeque, - /// The number of times we've attempted to ask the counterparty to resend the gap. - pub(crate) resend_attempts: usize, -} - -impl AwaitingResendState { - fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { - Self { - writer, - begin_seq_number, - end_seq_number, - inbound_queue: Default::default(), - resend_attempts: 1, - } - } - - fn update( - &mut self, - begin_seq_number: u64, - end_seq_number: u64, - ) -> AwaitingResendTransitionOutcome { - let resend_attempts = if self.begin_seq_number == begin_seq_number { - if self.resend_attempts + 1 > MAX_RESEND_ATTEMPTS { - return AwaitingResendTransitionOutcome::AttemptsExceeded; - } - self.resend_attempts + 1 - } else if begin_seq_number < self.begin_seq_number { - return AwaitingResendTransitionOutcome::BeginSeqNumberTooLow; - } else { - 1 - }; - - self.resend_attempts = resend_attempts; - self.begin_seq_number = begin_seq_number; - self.end_seq_number = end_seq_number; - - AwaitingResendTransitionOutcome::Success - } -} - -pub struct DisconnectedState { - reconnect: bool, - session_awaiter: Option>, - reason: String, -} - -impl DisconnectedState { - fn new(reconnect: bool, reason: &str) -> Self { - Self { - reconnect, - session_awaiter: None, - reason: reason.to_string(), - } - } - - fn set_session_awaiter(&mut self, responder: oneshot::Sender) { - self.session_awaiter = Some(responder); - } - - fn has_session_awaiter(&self) -> bool { - self.session_awaiter.is_some() - } - - fn take_session_awaiter(&mut self) -> Option> { - self.session_awaiter.take() - } -} - -pub enum AwaitingResendTransitionOutcome { - Success, - InvalidState(String), - BeginSeqNumberTooLow, - AttemptsExceeded, -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::sync::mpsc; - - #[test] - fn test_awaiting_resend_transition_begin_seq_number_too_low() { - let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); - let result = state.try_transition_to_awaiting_resend(0, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::BeginSeqNumberTooLow - )); - } - - #[test] - fn test_awaiting_resend_transition_attempts_exceeded() { - let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); - - // we can transition twice more without hitting the limit - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - - // the fourth time we'd get into an AwaitingResendState with the same begin seq number, we get an error - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::AttemptsExceeded - )); - } - - #[test] - fn test_awaiting_resend_transition_when_awaiting_logout_is_prevented() { - let mut state = SessionState::AwaitingLogout { - writer: create_writer_ref(), - logout_timeout: Instant::now(), - reconnect: false, - }; - - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::InvalidState(_) - )); - } - - fn create_writer_ref() -> WriterRef { - let (sender, _) = mpsc::channel(10); - WriterRef::new(sender) - } -} diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs new file mode 100644 index 00000000..da3b86d3 --- /dev/null +++ b/crates/hotfix/src/session/state/active.rs @@ -0,0 +1,481 @@ +use crate::Application; +use crate::application::{InboundDecision, OutboundDecision}; +use crate::message::business_reject::BusinessReject; +use crate::message::heartbeat::Heartbeat; +use crate::message::logon::Logon; +use crate::message::logout::Logout; +use crate::message::reject::Reject; +use crate::message::resend_request::ResendRequest; +use crate::message::sequence_reset::SequenceReset; +use crate::message::test_request::TestRequest; +use crate::session::error::{InternalSendResultExt, SendError, SendOutcome, SessionOperationError}; +use crate::session::get_msg_seq_num; +use crate::session::message_handling; +use crate::session::state::{ + AwaitingResendState, SessionCtx, SessionState, TestRequestId, TransitionResult, VerifyResult, +}; +use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::session_fields::{ + BEGIN_SEQ_NO, END_SEQ_NO, GAP_FILL_FLAG, MSG_TYPE, NEW_SEQ_NO, SessionRejectReason, TEST_REQ_ID, +}; +use hotfix_store::MessageStore; +use std::time::Duration; +use tokio::time::Instant; +use tracing::{debug, error, info, warn}; + +pub(crate) struct ActiveState { + /// The writer's reference to send messages to the counterparty + pub(crate) writer: WriterRef, + /// When we should send the next heartbeat message to the counterparty + pub(crate) heartbeat_deadline: Instant, + /// When the next message from the counterparty is expected at the latest + pub(crate) peer_deadline: Instant, + /// The ID of the test request we sent on peer timer expiry + pub(crate) sent_test_request_id: Option, +} + +impl ActiveState { + pub(crate) fn heartbeat_deadline(&self) -> &Instant { + &self.heartbeat_deadline + } + + pub(crate) fn reset_heartbeat_timer(&mut self, heartbeat_interval: u64) { + self.heartbeat_deadline = Instant::now() + Duration::from_secs(heartbeat_interval); + } + + pub(crate) fn peer_deadline(&self) -> &Instant { + &self.peer_deadline + } + + pub(crate) fn reset_peer_timer( + &mut self, + heartbeat_interval: u64, + test_request_id: Option, + ) { + let interval = calculate_peer_interval(heartbeat_interval); + self.peer_deadline = Instant::now() + Duration::from_secs(interval); + self.sent_test_request_id = test_request_id; + } + + pub(crate) fn expected_test_response_id(&self) -> Option<&TestRequestId> { + self.sent_test_request_id.as_ref() + } + + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { + self.writer.disconnect().await; + SessionState::new_disconnected(true, reason) + } + + pub(crate) async fn on_peer_timeout( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + ) -> Option { + if self.sent_test_request_id.is_some() { + warn!("peer didn't respond, terminating.."); + let logout = Logout::with_reason("peer timeout".to_string()); + if let Ok(prepared) = ctx.prepare_message(logout).await { + self.writer.send_raw_message(prepared.raw).await; + } + self.writer.disconnect().await; + return Some(SessionState::new_disconnected(true, "peer timeout")); + } + + let req_id = format!("TEST_{}", ctx.store.next_target_seq_number()); + info!("sending TestRequest due to peer timer expiring"); + let request = TestRequest::new(req_id.clone()); + match ctx.prepare_message(request).await { + Ok(prepared) => { + self.writer.send_raw_message(prepared.raw).await; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } + Err(err) => { + error!(err = ?err, "failed to send TestRequest"); + } + } + self.reset_peer_timer(ctx.config.heartbeat_interval, Some(req_id)); + None + } + + pub(crate) async fn on_heartbeat_timeout( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + ) { + let prepared = match ctx.prepare_message(Heartbeat::default()).await { + Ok(prepared) => prepared, + Err(err) => { + error!(err = ?err, "failed to send heartbeat message"); + return; + } + }; + self.writer.send_raw_message(prepared.raw).await; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } + + pub(crate) async fn on_fix_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: hotfix_message::message::Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + match message_type { + Heartbeat::MSG_TYPE => self.on_heartbeat(ctx, &message).await, + TestRequest::MSG_TYPE => self.on_test_request(ctx, &message).await, + ResendRequest::MSG_TYPE => self.on_resend_request(ctx, &message).await, + Reject::MSG_TYPE => self.on_reject(ctx, &message).await, + SequenceReset::MSG_TYPE => self.on_sequence_reset(ctx, &message).await, + Logout::MSG_TYPE => self.on_logout(ctx, app, &message).await, + Logon::MSG_TYPE => { + error!("received unexpected logon message"); + Ok(TransitionResult::Stay) + } + _ => self.on_app_message(ctx, app, &message).await, + } + } + + async fn on_heartbeat( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + if let (Some(expected_req_id), Ok(message_req_id)) = ( + self.expected_test_response_id(), + message.get::<&str>(TEST_REQ_ID), + ) && expected_req_id.as_str() == message_req_id + { + debug!("received response for TestRequest, resetting timer"); + self.reset_peer_timer(ctx.config.heartbeat_interval, None); + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_test_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| { + // TODO: send reject? + todo!() + }); + + ctx.store.increment_target_seq_number().await?; + + ctx.send_message(&self.writer, Heartbeat::for_request(req_id.to_string())) + .await + .with_send_context("heartbeat response")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + + Ok(TransitionResult::Stay) + } + + async fn on_resend_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // ResendRequest with check_too_high=false should never get SeqTooHigh, + // but handle gracefully + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + let msg_seq_num = get_msg_seq_num(message); + let expected = ctx.store.next_target_seq_number(); + + let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { + Ok(seq_number) => seq_number, + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing begin sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing BEGIN_SEQ_NO")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + }; + + let end_seq_number: u64 = match message.get(END_SEQ_NO) { + Ok(seq_number) => { + let last_seq_number = ctx.store.next_sender_seq_number() - 1; + if seq_number == 0 { + last_seq_number + } else { + std::cmp::min(seq_number, last_seq_number) + } + } + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing end sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing END_SEQ_NO")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + }; + + // Only increment target seq if seq matches expected + if msg_seq_num == expected { + ctx.store.increment_target_seq_number().await?; + } + + message_handling::resend_messages(ctx, &self.writer, begin_seq_number, end_seq_number) + .await?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + + Ok(TransitionResult::Stay) + } + + async fn on_reject( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_sequence_reset( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + let msg_seq_num = get_msg_seq_num(message); + let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); + match message_handling::verify_and_handle( + ctx, + &self.writer, + message, + is_gap_fill, + is_gap_fill, + ) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + let end: u64 = match message.get(NEW_SEQ_NO) { + Ok(new_seq_no) => new_seq_no, + Err(err) => { + error!( + "received sequence reset message without new sequence number: {:?}", + err + ); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing NewSeqNo tag in sequence reset message"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing NEW_SEQ_NO")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + }; + + if end <= ctx.store.next_target_seq_number() { + error!( + "received sequence reset message which would move target seq number backwards: {end}", + ); + let text = + format!("attempt to lower sequence number, invalid value NewSeqNo(36)={end}"); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&text); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for invalid sequence reset")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + + ctx.store.set_target_seq_number(end - 1).await?; + Ok(TransitionResult::Stay) + } + + async fn on_logout( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &hotfix_message::message::Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, false).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // verify with check_too_high=false, so this shouldn't happen + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + // We are logged on, send logout response + let logout = Logout::with_reason("Logout acknowledged".to_string()); + match ctx.prepare_message(logout).await { + Ok(prepared) => { + self.writer.send_raw_message(prepared.raw).await; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } + Err(err) => warn!("failed to send logout acknowledgement: {err}"), + } + + app.on_logout("peer has logged us out").await; + + self.writer.disconnect().await; + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected(true, "peer has logged us out"), + )) + } + + async fn on_app_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &hotfix_message::message::Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + match app.on_inbound_message(message).await { + InboundDecision::Accept => {} + InboundDecision::Reject { reason, text } => { + let msg_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + let mut reject = + BusinessReject::new(msg_type, reason).ref_seq_num(get_msg_seq_num(message)); + if let Some(text) = text { + reject = reject.text(&text); + } + ctx.send_message(&self.writer, reject) + .await + .with_send_context("business message reject")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } + InboundDecision::TerminateSession => { + error!("failed to send inbound message to application"); + self.writer.disconnect().await; + } + } + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::Stay) + } + + async fn transition_to_awaiting_resend( + &self, + ctx: &mut SessionCtx<'_, Store>, + expected: u64, + actual: u64, + ) -> Result { + debug!("we are behind target (ours: {expected}, theirs: {actual}), requesting resend."); + let request = ResendRequest::new(expected, actual); + ctx.send_message(&self.writer, request) + .await + .with_send_context("resend request")?; + let new_state = SessionState::AwaitingResend(AwaitingResendState::new( + self.writer.clone(), + expected, + actual, + )); + Ok(TransitionResult::TransitionTo(new_state)) + } + + pub(crate) async fn send_app_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: App::Outbound, + ) -> Result { + match app.on_outbound_message(&message).await { + OutboundDecision::Send => { + let seq_num = + ctx.send_message(&self.writer, message) + .await + .map_err(|e| match e { + crate::session::error::InternalSendError::Persist(e) => { + SendError::Persist(e) + } + crate::session::error::InternalSendError::SequenceNumber(e) => { + SendError::SequenceNumber(e) + } + })?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + Ok(SendOutcome::Sent { + sequence_number: seq_num, + }) + } + OutboundDecision::Drop => { + debug!("dropped outbound message as instructed by the application"); + Ok(SendOutcome::Dropped) + } + OutboundDecision::TerminateSession => { + warn!("the application indicated we should terminate the session"); + self.writer.disconnect().await; + Err(SendError::SessionTerminated) + } + } + } +} + +#[inline] +pub(crate) fn calculate_peer_interval(heartbeat_interval: u64) -> u64 { + (heartbeat_interval as f64 * super::TEST_REQUEST_THRESHOLD).round() as u64 +} diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs new file mode 100644 index 00000000..7c5f4850 --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -0,0 +1,79 @@ +use crate::Application; +use crate::message::logon::Logon; +use crate::session::error::SessionOperationError; +use crate::session::message_handling; +use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; +use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::session_fields::MSG_TYPE; +use hotfix_store::MessageStore; +use tokio::time::Instant; +use tracing::warn; + +pub(crate) struct AwaitingLogonState { + pub(crate) writer: WriterRef, + pub(crate) logon_timeout: Instant, +} + +impl AwaitingLogonState { + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { + self.writer.disconnect().await; + SessionState::new_disconnected(true, reason) + } + + pub(crate) async fn on_peer_timeout(&self) { + warn!("peer didn't respond to our Logon, disconnecting.."); + self.writer.disconnect().await; + } + + pub(crate) async fn on_fix_message( + &self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: hotfix_message::message::Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + if message_type != Logon::MSG_TYPE { + self.writer.disconnect().await; + return Ok(TransitionResult::Stay); + } + + // process logon + match message_handling::verify_and_handle(ctx, &self.writer, &message, true, true).await? { + VerifyResult::Passed => { + // happy logon flow, the session is now active + let new_state = + SessionState::new_active(self.writer.clone(), ctx.config.heartbeat_interval); + app.on_logon().await; + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::TransitionTo(new_state)) + } + VerifyResult::SeqTooHigh { expected, actual } => { + // Unusual during logon, but handle it + use crate::message::resend_request::ResendRequest; + use crate::session::error::InternalSendResultExt; + use crate::session::state::AwaitingResendState; + use tracing::debug; + + debug!( + "we are behind target during logon (ours: {expected}, theirs: {actual}), requesting resend." + ); + let request = ResendRequest::new(expected, actual); + ctx.send_message(&self.writer, request) + .await + .with_send_context("resend request")?; + let new_state = SessionState::AwaitingResend(AwaitingResendState::new( + self.writer.clone(), + expected, + actual, + )); + Ok(TransitionResult::TransitionTo(new_state)) + } + VerifyResult::Handled(transition) => Ok(transition), + } + } +} diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs new file mode 100644 index 00000000..544d2e01 --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -0,0 +1,74 @@ +use crate::Application; +use crate::message::logout::Logout; +use crate::session::error::SessionOperationError; +use crate::session::message_handling; +use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; +use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::session_fields::MSG_TYPE; +use hotfix_store::MessageStore; +use tokio::time::Instant; +use tracing::warn; + +pub(crate) struct AwaitingLogoutState { + pub(crate) writer: WriterRef, + pub(crate) logout_timeout: Instant, + pub(crate) reconnect: bool, +} + +impl AwaitingLogoutState { + pub(crate) fn new(writer: WriterRef, logout_timeout: Instant, reconnect: bool) -> Self { + Self { + writer, + logout_timeout, + reconnect, + } + } + + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { + SessionState::new_disconnected(self.reconnect, reason) + } + + pub(crate) async fn on_peer_timeout(&self) -> SessionState { + warn!("peer didn't respond to our Logout, disconnecting.."); + self.writer.disconnect().await; + SessionState::new_disconnected(self.reconnect, "logout timeout") + } + + pub(crate) async fn on_fix_message( + &self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: hotfix_message::message::Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + if message_type == Logout::MSG_TYPE { + // Process the logout + match message_handling::verify_and_handle(ctx, &self.writer, &message, false, false) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // verify with check_too_high=false, shouldn't happen + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + app.on_logout("peer has logged us out").await; + self.writer.disconnect().await; + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected(self.reconnect, "logout completed"), + )) + } else { + // Other messages during logout: increment target seq and stay + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + } +} diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs new file mode 100644 index 00000000..9a17a100 --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -0,0 +1,489 @@ +use crate::Application; +use crate::application::InboundDecision; +use crate::message::business_reject::BusinessReject; +use crate::message::heartbeat::Heartbeat; +use crate::message::logon::Logon; +use crate::message::logout::Logout; +use crate::message::reject::Reject; +use crate::message::resend_request::ResendRequest; +use crate::message::sequence_reset::SequenceReset; +use crate::message::test_request::TestRequest; +use crate::session::error::{InternalSendResultExt, SessionOperationError}; +use crate::session::get_msg_seq_num; +use crate::session::message_handling; +use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; +use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::message::Message; +use hotfix_message::session_fields::{ + BEGIN_SEQ_NO, END_SEQ_NO, GAP_FILL_FLAG, MSG_TYPE, NEW_SEQ_NO, SessionRejectReason, TEST_REQ_ID, +}; +use hotfix_store::MessageStore; +use std::collections::VecDeque; +use tracing::{debug, error, warn}; + +const MAX_RESEND_ATTEMPTS: usize = 3; + +/// Session state we're in while processing messages we requested to be resent. +pub(crate) struct AwaitingResendState { + /// The reference to the writer loop. + pub(crate) writer: WriterRef, + /// The beginning of the gap we're waiting for the target to resend. + pub(crate) begin_seq_number: u64, + /// The end of the gap we're waiting for the target to resend. + pub(crate) end_seq_number: u64, + /// Inbound messages we receive while processing the resend. + pub(crate) inbound_queue: VecDeque, + /// The number of times we've attempted to ask the counterparty to resend the gap. + pub(crate) resend_attempts: usize, +} + +impl AwaitingResendState { + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { + self.writer.disconnect().await; + SessionState::new_disconnected(true, reason) + } + + pub(crate) fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { + Self { + writer, + begin_seq_number, + end_seq_number, + inbound_queue: Default::default(), + resend_attempts: 1, + } + } + + pub(crate) fn update( + &mut self, + begin_seq_number: u64, + end_seq_number: u64, + ) -> AwaitingResendTransitionOutcome { + let resend_attempts = if self.begin_seq_number == begin_seq_number { + if self.resend_attempts + 1 > MAX_RESEND_ATTEMPTS { + return AwaitingResendTransitionOutcome::AttemptsExceeded; + } + self.resend_attempts + 1 + } else if begin_seq_number < self.begin_seq_number { + return AwaitingResendTransitionOutcome::BeginSeqNumberTooLow; + } else { + 1 + }; + + self.resend_attempts = resend_attempts; + self.begin_seq_number = begin_seq_number; + self.end_seq_number = end_seq_number; + + AwaitingResendTransitionOutcome::Success + } + + pub(crate) async fn on_fix_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + let seq_number = get_msg_seq_num(&message); + + // If msg seq > end_seq_number AND not ResendRequest: queue it + if seq_number > self.end_seq_number && message_type != ResendRequest::MSG_TYPE { + self.inbound_queue.push_back(message); + return Ok(TransitionResult::Stay); + } + + // Dispatch by message type + let result = match message_type { + Heartbeat::MSG_TYPE => self.on_heartbeat(ctx, &message).await?, + TestRequest::MSG_TYPE => self.on_test_request(ctx, &message).await?, + ResendRequest::MSG_TYPE => self.on_resend_request(ctx, &message).await?, + Reject::MSG_TYPE => self.on_reject(ctx, &message).await?, + SequenceReset::MSG_TYPE => self.on_sequence_reset(ctx, &message).await?, + Logout::MSG_TYPE => self.on_logout(ctx, app, &message).await?, + Logon::MSG_TYPE => { + error!("received unexpected logon message"); + TransitionResult::Stay + } + _ => self.on_app_message(ctx, app, &message).await?, + }; + + // If a transition happened, return it directly + if !matches!(result, TransitionResult::Stay) { + return Ok(result); + } + + // Check if resend is done + self.check_end_of_resend(ctx) + } + + fn check_end_of_resend( + &mut self, + ctx: &SessionCtx<'_, Store>, + ) -> Result { + if ctx.store.next_target_seq_number() > self.end_seq_number { + let new_state = + SessionState::new_active(self.writer.clone(), ctx.config.heartbeat_interval); + let backlog = std::mem::take(&mut self.inbound_queue); + Ok(TransitionResult::TransitionWithBacklog { new_state, backlog }) + } else { + Ok(TransitionResult::Stay) + } + } + + async fn on_heartbeat( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_test_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| todo!()); + + ctx.store.increment_target_seq_number().await?; + + ctx.send_message(&self.writer, Heartbeat::for_request(req_id.to_string())) + .await + .with_send_context("heartbeat response")?; + + Ok(TransitionResult::Stay) + } + + async fn on_resend_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // check_too_high=false, shouldn't happen + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + let msg_seq_num = get_msg_seq_num(message); + let expected = ctx.store.next_target_seq_number(); + + // If seq is too high, queue it for seq accounting when the gap fill catches up, + // but still process the resend below. + if msg_seq_num > expected { + self.inbound_queue.push_back(message.clone()); + } + + let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { + Ok(seq_number) => seq_number, + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing begin sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing BEGIN_SEQ_NO")?; + return Ok(TransitionResult::Stay); + } + }; + + let end_seq_number: u64 = match message.get(END_SEQ_NO) { + Ok(seq_number) => { + let last_seq_number = ctx.store.next_sender_seq_number() - 1; + if seq_number == 0 { + last_seq_number + } else { + std::cmp::min(seq_number, last_seq_number) + } + } + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing end sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing END_SEQ_NO")?; + return Ok(TransitionResult::Stay); + } + }; + + if msg_seq_num == expected { + ctx.store.increment_target_seq_number().await?; + } + + message_handling::resend_messages(ctx, &self.writer, begin_seq_number, end_seq_number) + .await?; + + Ok(TransitionResult::Stay) + } + + async fn on_reject( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_sequence_reset( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + let msg_seq_num = get_msg_seq_num(message); + let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); + match message_handling::verify_and_handle( + ctx, + &self.writer, + message, + is_gap_fill, + is_gap_fill, + ) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + let end: u64 = match message.get(NEW_SEQ_NO) { + Ok(new_seq_no) => new_seq_no, + Err(err) => { + error!( + "received sequence reset message without new sequence number: {:?}", + err + ); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing NewSeqNo tag in sequence reset message"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing NEW_SEQ_NO")?; + return Ok(TransitionResult::Stay); + } + }; + + if end <= ctx.store.next_target_seq_number() { + error!( + "received sequence reset message which would move target seq number backwards: {end}", + ); + let text = + format!("attempt to lower sequence number, invalid value NewSeqNo(36)={end}"); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&text); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for invalid sequence reset")?; + return Ok(TransitionResult::Stay); + } + + ctx.store.set_target_seq_number(end - 1).await?; + Ok(TransitionResult::Stay) + } + + async fn on_logout( + &self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, false).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => {} + VerifyResult::Handled(transition) => return Ok(transition), + } + + // We are in AwaitingResend (logged on), send logout response + let logout = Logout::with_reason("Logout acknowledged".to_string()); + match ctx.prepare_message(logout).await { + Ok(prepared) => self.writer.send_raw_message(prepared.raw).await, + Err(err) => warn!("failed to send logout acknowledgement: {err}"), + } + + app.on_logout("peer has logged us out").await; + + self.writer.disconnect().await; + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected(true, "peer has logged us out"), + )) + } + + async fn on_app_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &Message, + ) -> Result { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::Handled(transition) => return Ok(transition), + } + + match app.on_inbound_message(message).await { + InboundDecision::Accept => {} + InboundDecision::Reject { reason, text } => { + let msg_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + let mut reject = + BusinessReject::new(msg_type, reason).ref_seq_num(get_msg_seq_num(message)); + if let Some(text) = text { + reject = reject.text(&text); + } + ctx.send_message(&self.writer, reject) + .await + .with_send_context("business message reject")?; + } + InboundDecision::TerminateSession => { + error!("failed to send inbound message to application"); + self.writer.disconnect().await; + } + } + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::Stay) + } + + async fn handle_seq_too_high( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + expected: u64, + actual: u64, + ) -> Result { + match self.update(expected, actual) { + AwaitingResendTransitionOutcome::Success => { + debug!( + "we are behind target (ours: {expected}, theirs: {actual}), requesting resend." + ); + let request = ResendRequest::new(expected, actual); + ctx.send_message(&self.writer, request) + .await + .with_send_context("resend request")?; + Ok(TransitionResult::Stay) + } + AwaitingResendTransitionOutcome::BeginSeqNumberTooLow => { + self.writer.disconnect().await; + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected( + false, + "awaiting resend begin seq number unexpectedly lower than the previous resend request's", + ), + )) + } + AwaitingResendTransitionOutcome::AttemptsExceeded => { + self.writer.disconnect().await; + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected( + false, + "resend request attempts exceeded, manual intervention required", + ), + )) + } + } + } +} + +pub(crate) enum AwaitingResendTransitionOutcome { + Success, + BeginSeqNumberTooLow, + AttemptsExceeded, +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::mpsc; + + #[test] + fn test_update_begin_seq_number_too_low() { + let writer = create_writer_ref(); + let mut state = AwaitingResendState::new(writer, 1, 5); + let result = state.update(0, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::BeginSeqNumberTooLow + )); + } + + #[test] + fn test_update_attempts_exceeded() { + let writer = create_writer_ref(); + let mut state = AwaitingResendState::new(writer, 1, 5); + + // we can update twice more without hitting the limit + let result = state.update(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + let result = state.update(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + + // the fourth time with the same begin seq number, we get an error + let result = state.update(1, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::AttemptsExceeded + )); + } + + #[test] + fn test_update_resets_attempts_on_new_begin_seq() { + let writer = create_writer_ref(); + let mut state = AwaitingResendState::new(writer, 1, 5); + + // Use up attempts on begin=1 + let result = state.update(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + let result = state.update(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + + // A new begin_seq resets the counter + let result = state.update(3, 10); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + assert_eq!(state.resend_attempts, 1); + } + + fn create_writer_ref() -> WriterRef { + let (sender, _) = mpsc::channel(10); + WriterRef::new(sender) + } +} diff --git a/crates/hotfix/src/session/state/disconnected.rs b/crates/hotfix/src/session/state/disconnected.rs new file mode 100644 index 00000000..04efa2fd --- /dev/null +++ b/crates/hotfix/src/session/state/disconnected.rs @@ -0,0 +1,77 @@ +use crate::session::event::AwaitingActiveSessionResponse; +use crate::session::state::AwaitingLogonState; +use crate::transport::writer::WriterRef; +use std::time::Duration; +use tokio::sync::oneshot; +use tokio::time::Instant; +use tracing::{debug, error}; + +pub(crate) struct DisconnectedState { + pub(crate) reconnect: bool, + session_awaiter: Option>, + pub(crate) reason: String, +} + +impl DisconnectedState { + pub(crate) fn new(reconnect: bool, reason: &str) -> Self { + Self { + reconnect, + session_awaiter: None, + reason: reason.to_string(), + } + } + + pub(crate) fn set_session_awaiter( + &mut self, + responder: oneshot::Sender, + ) { + self.session_awaiter = Some(responder); + } + + pub(crate) fn has_session_awaiter(&self) -> bool { + self.session_awaiter.is_some() + } + + pub(crate) fn take_session_awaiter( + &mut self, + ) -> Option> { + self.session_awaiter.take() + } + + pub(crate) fn on_connect( + &self, + writer: WriterRef, + logon_timeout: Duration, + ) -> super::SessionState { + super::SessionState::AwaitingLogon(AwaitingLogonState { + writer, + logon_timeout: Instant::now() + logon_timeout, + }) + } + + pub(crate) fn should_reconnect(&self) -> bool { + self.reconnect + } + + pub(crate) fn register_session_awaiter( + &mut self, + responder: oneshot::Sender, + ) -> Result<(), oneshot::Sender> { + if self.has_session_awaiter() { + Err(responder) + } else { + self.set_session_awaiter(responder); + Ok(()) + } + } + + pub(crate) fn notify_session_awaiter(&mut self) { + if let Some(awaiter) = self.take_session_awaiter() { + if let Err(err) = awaiter.send(AwaitingActiveSessionResponse::Active) { + error!("failed to send session awaiter response: {err:?}"); + } else { + debug!("notified session awaiter"); + } + } + } +}