diff --git a/src/agent-client-protocol-core/src/jsonrpc.rs b/src/agent-client-protocol-core/src/jsonrpc.rs index e6ac431..346069d 100644 --- a/src/agent-client-protocol-core/src/jsonrpc.rs +++ b/src/agent-client-protocol-core/src/jsonrpc.rs @@ -35,7 +35,7 @@ use crate::jsonrpc::task_actor::{Task, TaskTx}; use crate::mcp_server::McpServer; use crate::role::HasPeer; use crate::role::Role; -use crate::util::json_cast; +use crate::util::json_cast_params; use crate::{Agent, Client, ConnectTo, RoleId}; /// Handlers process incoming JSON-RPC messages on a connection. @@ -203,6 +203,12 @@ pub trait HandleDispatchFrom: Send { /// You should avoid blocking in this callback unless you wish to block the server (e.g., for rate limiting). /// The recommended approach to manage expensive operations is to the [`ConnectionTo::spawn`] method available on the message context. /// + /// When implementing this directly, prefer [`Dispatch::match_request`], + /// [`Dispatch::match_notification`], or [`Dispatch::match_typed_dispatch`] + /// over manually calling `parse_message` and propagating the resulting error. + /// These helpers preserve per-message failures as structured rejections instead + /// of tearing down the entire connection. + /// /// # Parameters /// /// * `message` - The incoming message to handle. @@ -1786,7 +1792,10 @@ impl ConnectionTo { ) } - /// Send an error notification (no reply expected). + /// Send an out-of-band JSON-RPC error message. + /// + /// This is serialized as a JSON-RPC error response with no `id`, so it is + /// not correlated with any specific request. pub fn send_error_notification(&self, error: crate::Error) -> Result<(), crate::Error> { send_raw_message(&self.message_tx, OutgoingMessage::Error { error }) } @@ -2189,8 +2198,13 @@ pub trait JsonRpcMessage: 'static + Debug + Sized + Send + Clone { /// Parse this type from a method name and parameters. /// - /// Returns an error if the method doesn't match or deserialization fails. - /// Callers should use `matches_method` first to check if this type handles the method. + /// Return `crate::Error::method_not_found()` only when `method` is not handled by this + /// type. When the method matches, any other error is treated as terminal for that dispatch: + /// the message will be rejected rather than falling through to later handlers for the same + /// method. + /// + /// For incoming request/notification params, prefer `crate::util::json_cast_params()` so + /// malformed payloads become `crate::Error::invalid_params()`. fn parse_message(method: &str, params: &impl Serialize) -> Result; } @@ -2328,7 +2342,8 @@ impl Dispatch { /// /// If this message is a request, this error becomes the reply to the request. /// - /// If this message is a notification, the error is sent as a notification. + /// If this message is a notification, the error is sent as an out-of-band + /// JSON-RPC error message. /// /// If this message is a response, the error is forwarded to the waiting handler. pub fn respond_with_error( @@ -2343,24 +2358,64 @@ impl Dispatch { } } - /// Convert to a `Responder` that expects a JSON value - /// and which checks (dynamically) that the JSON value it receives - /// can be converted to `T`. + /// Handle a rejected typed match when no [`ConnectionTo`] is available. + /// + /// * **Requests** – sends the error back to the caller via the [`Responder`]. + /// * **Responses** – forwards the error to the waiting handler via the + /// [`ResponseRouter`]. + /// * **Notifications** – there is no request ID to reply to and no + /// connection available to send an out-of-band error message, so the + /// error is logged and swallowed. + /// + /// Returns `Ok(Handled::Yes)` in all cases so the connection loop + /// continues. + /// + /// **Prefer [`respond_with_error`](Self::respond_with_error)** when a + /// [`ConnectionTo`] is available — it can send an out-of-band error for + /// malformed notifications, which is consistent with + /// [`TypeNotification`](crate::util::TypeNotification). + pub(crate) fn handle_rejection_without_connection( + self, + error: crate::Error, + ) -> Result, crate::Error> { + match self { + Dispatch::Request(_, responder) => { + responder.respond_with_error(error).map(|()| Handled::Yes) + } + Dispatch::Notification(_) => { + tracing::warn!( + ?error, + "rejecting malformed notification without connection: no out-of-band error possible" + ); + Ok(Handled::Yes) + } + Dispatch::Response(_, router) => { + router.respond_with_error(error).map(|()| Handled::Yes) + } + } + } + + /// Erase typed payloads to a JSON/untyped `Dispatch`. /// - /// Note: Response variants cannot be erased since their payload is already - /// parsed. This returns an error for Response variants. + /// Requests and notifications are converted to [`UntypedMessage`]. + /// Responses are converted back to JSON values using the original request method. pub fn erase_to_json(self) -> Result { match self { - Dispatch::Request(response, responder) => Ok(Dispatch::Request( - response.to_untyped_message()?, + Dispatch::Request(request, responder) => Ok(Dispatch::Request( + request.to_untyped_message()?, responder.erase_to_json(), )), Dispatch::Notification(notification) => { Ok(Dispatch::Notification(notification.to_untyped_message()?)) } - Dispatch::Response(_, _) => Err(crate::util::internal_error( - "cannot erase Response variant to JSON", - )), + Dispatch::Response(result, router) => { + let method = router.method(); + let untyped_result = match result { + Ok(response) => response.into_json(method).map(Ok), + Err(err) => Ok(Err(err)), + }?; + Ok(Dispatch::Response(untyped_result, router.erase_to_json())) + } } } @@ -2378,22 +2433,16 @@ impl Dispatch { } } - /// Convert self to an untyped message context. - /// - /// Note: Response variants cannot be converted. This returns an error for Response variants. - pub fn into_untyped_dispatch(self) -> Result { - match self { - Dispatch::Request(request, responder) => Ok(Dispatch::Request( - request.to_untyped_message()?, - responder.erase_to_json(), - )), - Dispatch::Notification(notification) => { - Ok(Dispatch::Notification(notification.to_untyped_message()?)) - } - Dispatch::Response(_, _) => Err(crate::util::internal_error( - "cannot convert Response variant to untyped message context", - )), - } + /// Erase this typed dispatch back to untyped JSON and wrap it in + /// [`Handled::No`] so the connection loop can offer it to later handlers. + pub(crate) fn erase_into_unhandled( + self, + retry: bool, + ) -> Result, crate::Error> { + Ok(Handled::No { + message: self.erase_to_json()?, + retry, + }) } /// Returns the request ID if this is a request or response, None if notification. @@ -2418,86 +2467,167 @@ impl Dispatch { } } +/// Outcome of matching an untyped [`Dispatch`] against a request type. +#[derive(Debug)] +pub enum RequestMatch { + /// The dispatch is a request whose method matched and whose params parsed successfully. + Matched(Req, Responder), + /// The dispatch was not a request of the requested type. + Unhandled(Dispatch), + /// The request method matched, but parsing failed. + Rejected { + /// The original untyped dispatch. + dispatch: Dispatch, + /// The error explaining why the match was rejected. + error: crate::Error, + }, +} + +/// Outcome of matching an untyped [`Dispatch`] against a notification type. +#[derive(Debug)] +pub enum NotificationMatch { + /// The dispatch is a notification whose method matched and whose params parsed successfully. + Matched(Notif), + /// The dispatch was not a notification of the requested type. + Unhandled(Dispatch), + /// The notification method matched, but parsing failed. + Rejected { + /// The original untyped dispatch. + dispatch: Dispatch, + /// The error explaining why the match was rejected. + error: crate::Error, + }, +} + +/// Outcome of matching an untyped [`Dispatch`] against typed request and notification types. +#[derive(Debug)] +pub enum TypedDispatchMatch { + /// The dispatch matched one of the provided types and parsed successfully. + Matched(Dispatch), + /// The dispatch did not match either provided type. + Unhandled(Dispatch), + /// The dispatch method matched, but parsing failed. + Rejected { + /// The original untyped dispatch. + dispatch: Dispatch, + /// The error explaining why the match was rejected. + error: crate::Error, + }, +} + impl Dispatch { - /// Attempts to parse `self` into a typed message context. - /// - /// # Returns + /// Match this dispatch against a request type without using `Err` for parse failures. /// - /// * `Ok(Ok(typed))` if this is a request/notification of the given types - /// * `Ok(Err(self))` if not - /// * `Err` if has the correct method for the given types but parsing fails - #[tracing::instrument(skip(self), fields(Request = ?std::any::type_name::(), Notif = ?std::any::type_name::()), level = "trace", ret)] - pub(crate) fn into_typed_dispatch( - self, - ) -> Result, Dispatch>, crate::Error> { - tracing::debug!( - message = ?self, - "into_typed_dispatch" - ); + /// Once `Req::matches_method` is true, any parse error is terminal for that method. + /// The request is rejected rather than offered to later handlers for the same method. + #[must_use] + pub fn match_request(self) -> RequestMatch { match self { Dispatch::Request(message, responder) => { if Req::matches_method(&message.method) { match Req::parse_message(&message.method, &message.params) { - Ok(req) => { - tracing::trace!(?req, "parsed ok"); - Ok(Ok(Dispatch::Request(req, responder.cast()))) - } - Err(err) => { - tracing::trace!(?err, "parse error"); - Err(err) - } + Ok(req) => RequestMatch::Matched(req, responder.cast()), + Err(error) => RequestMatch::Rejected { + dispatch: Dispatch::Request(message, responder), + error, + }, } } else { - tracing::trace!("method doesn't match"); - Ok(Err(Dispatch::Request(message, responder))) + RequestMatch::Unhandled(Dispatch::Request(message, responder)) } } + dispatch @ (Dispatch::Notification(_) | Dispatch::Response(_, _)) => { + RequestMatch::Unhandled(dispatch) + } + } + } + /// Match this dispatch against a notification type without using `Err` for parse failures. + /// + /// Once `Notif::matches_method` is true, any parse error is terminal for that method. + /// The notification is rejected rather than offered to later handlers for the same method. + #[must_use] + pub fn match_notification(self) -> NotificationMatch { + match self { Dispatch::Notification(message) => { if Notif::matches_method(&message.method) { match Notif::parse_message(&message.method, &message.params) { - Ok(notif) => { - tracing::trace!(?notif, "parse ok"); - Ok(Ok(Dispatch::Notification(notif))) - } - Err(err) => { - tracing::trace!(?err, "parse error"); - Err(err) - } + Ok(notif) => NotificationMatch::Matched(notif), + Err(error) => NotificationMatch::Rejected { + dispatch: Dispatch::Notification(message), + error, + }, } } else { - tracing::trace!("method doesn't match"); - Ok(Err(Dispatch::Notification(message))) + NotificationMatch::Unhandled(Dispatch::Notification(message)) } } + dispatch @ (Dispatch::Request(_, _) | Dispatch::Response(_, _)) => { + NotificationMatch::Unhandled(dispatch) + } + } + } + /// Match this dispatch against typed request and notification types without using `Err` + /// for parse failures. + /// + /// Once the request/notification side matches by method, any parse error is terminal for + /// that method and yields [`TypedDispatchMatch::Rejected`]. Likewise, a matched response + /// whose result cannot be decoded is rejected instead of bubbling up as a fatal `Err`. + #[must_use] + pub fn match_typed_dispatch( + self, + ) -> TypedDispatchMatch { + match self { + dispatch @ Dispatch::Request(_, _) => match dispatch.match_request::() { + RequestMatch::Matched(request, responder) => { + TypedDispatchMatch::Matched(Dispatch::Request(request, responder)) + } + RequestMatch::Unhandled(dispatch) => TypedDispatchMatch::Unhandled(dispatch), + RequestMatch::Rejected { dispatch, error } => { + TypedDispatchMatch::Rejected { dispatch, error } + } + }, + dispatch @ Dispatch::Notification(_) => match dispatch.match_notification::() { + NotificationMatch::Matched(notification) => { + TypedDispatchMatch::Matched(Dispatch::Notification(notification)) + } + NotificationMatch::Unhandled(dispatch) => TypedDispatchMatch::Unhandled(dispatch), + NotificationMatch::Rejected { dispatch, error } => { + TypedDispatchMatch::Rejected { dispatch, error } + } + }, Dispatch::Response(result, cx) => { - let method = cx.method(); - if Req::matches_method(method) { - // Parse the response result - let typed_result = match result { - Ok(value) => { - match ::from_value(method, value) { - Ok(parsed) => { - tracing::trace!(?parsed, "parse ok"); - Ok(parsed) - } - Err(err) => { - tracing::trace!(?err, "parse error"); - return Err(err); - } - } - } - Err(err) => { - tracing::trace!("error, passthrough"); - Err(err) - } - }; - Ok(Ok(Dispatch::Response(typed_result, cx.cast()))) - } else { + if !Req::matches_method(cx.method()) { tracing::trace!("method doesn't match"); - Ok(Err(Dispatch::Response(result, cx))) + return TypedDispatchMatch::Unhandled(Dispatch::Response(result, cx)); } + let typed_result = match result { + Ok(value) => { + let parsed = ::from_value( + cx.method(), + value.clone(), + ); + match parsed { + Ok(parsed) => { + tracing::trace!(?parsed, "parse ok"); + Ok(parsed) + } + Err(err) => { + tracing::trace!(?err, "parse error"); + return TypedDispatchMatch::Rejected { + dispatch: Dispatch::Response(Ok(value), cx), + error: err, + }; + } + } + } + Err(err) => { + tracing::trace!("error, passthrough"); + Err(err) + } + }; + TypedDispatchMatch::Matched(Dispatch::Response(typed_result, cx.cast())) } } } @@ -2529,57 +2659,9 @@ impl Dispatch { let Some(value) = message.params().get("sessionId") else { return Ok(None); }; - let session_id = serde_json::from_value(value.clone())?; + let session_id = json_cast_params(value)?; Ok(Some(session_id)) } - - /// Try to parse this as a notification of the given type. - /// - /// # Returns - /// - /// * `Ok(Ok(typed))` if this is a request/notification of the given types - /// * `Ok(Err(self))` if not - /// * `Err` if has the correct method for the given types but parsing fails - pub fn into_notification( - self, - ) -> Result, crate::Error> { - match self { - Dispatch::Notification(msg) => { - if !N::matches_method(&msg.method) { - return Ok(Err(Dispatch::Notification(msg))); - } - match N::parse_message(&msg.method, &msg.params) { - Ok(n) => Ok(Ok(n)), - Err(err) => Err(err), - } - } - Dispatch::Request(..) | Dispatch::Response(..) => Ok(Err(self)), - } - } - - /// Try to parse this as a request of the given type. - /// - /// # Returns - /// - /// * `Ok(Ok(typed))` if this is a request/notification of the given types - /// * `Ok(Err(self))` if not - /// * `Err` if has the correct method for the given types but parsing fails - pub fn into_request( - self, - ) -> Result), Dispatch>, crate::Error> { - match self { - Dispatch::Request(msg, responder) => { - if !Req::matches_method(&msg.method) { - return Ok(Err(Dispatch::Request(msg, responder))); - } - match Req::parse_message(&msg.method, &msg.params) { - Ok(req) => Ok(Ok((req, responder.cast()))), - Err(err) => Err(err), - } - } - Dispatch::Notification(..) | Dispatch::Response(..) => Ok(Err(self)), - } - } } impl Dispatch { @@ -2653,7 +2735,11 @@ impl UntypedMessage { id: Option, ) -> Result { let Self { method, params } = self; - Ok(jsonrpcmsg::Request::new_v2(method, json_cast(params)?, id)) + Ok(jsonrpcmsg::Request::new_v2( + method, + crate::util::json_cast(params)?, + id, + )) } } diff --git a/src/agent-client-protocol-core/src/jsonrpc/handlers.rs b/src/agent-client-protocol-core/src/jsonrpc/handlers.rs index 3575ef0..a145050 100644 --- a/src/agent-client-protocol-core/src/jsonrpc/handlers.rs +++ b/src/agent-client-protocol-core/src/jsonrpc/handlers.rs @@ -1,4 +1,6 @@ -use crate::jsonrpc::{HandleDispatchFrom, Handled, IntoHandled, JsonRpcResponse}; +use crate::jsonrpc::{ + HandleDispatchFrom, Handled, IntoHandled, NotificationMatch, RequestMatch, TypedDispatchMatch, +}; use crate::role::{HasPeer, Role, handle_incoming_dispatch}; use crate::{ConnectionTo, Dispatch, JsonRpcNotification, JsonRpcRequest, UntypedMessage}; @@ -108,67 +110,44 @@ where dispatch, connection, async |dispatch, connection| { - match dispatch { - Dispatch::Request(message, responder) => { - tracing::debug!( - request_type = std::any::type_name::(), - message = ?message, - "RequestHandler::handle_request" - ); - if Req::matches_method(&message.method) { - match Req::parse_message(&message.method, &message.params) { - Ok(req) => { - tracing::trace!( - ?req, - "RequestHandler::handle_request: parse completed" - ); - let typed_responder = responder.cast(); - let result = (self.to_future_hack)( - &mut self.handler, - req, - typed_responder, - connection, - ) - .await?; - match result.into_handled() { - Handled::Yes => Ok(Handled::Yes), - Handled::No { - message: (request, responder), - retry, - } => { - // Handler returned the request back, convert to untyped - let untyped = request.to_untyped_message()?; - Ok(Handled::No { - message: Dispatch::Request( - untyped, - responder.erase_to_json(), - ), - retry, - }) - } - } - } - Err(err) => { - tracing::trace!( - ?err, - "RequestHandler::handle_request: parse errored" - ); - Err(err) - } - } - } else { - tracing::trace!("RequestHandler::handle_request: method doesn't match"); - Ok(Handled::No { - message: Dispatch::Request(message, responder), - retry: false, - }) + if let Dispatch::Request(message, _) = &dispatch { + tracing::debug!( + request_type = std::any::type_name::(), + message = ?message, + "RequestHandler::handle_request" + ); + } + match dispatch.match_request::() { + RequestMatch::Matched(req, typed_responder) => { + tracing::trace!(?req, "RequestHandler::handle_request: parse completed"); + let result = (self.to_future_hack)( + &mut self.handler, + req, + typed_responder, + connection, + ) + .await?; + match result.into_handled() { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message: (request, responder), + retry, + } => Dispatch::::Request(request, responder) + .erase_into_unhandled(retry), } } - - Dispatch::Notification(..) | Dispatch::Response(..) => Ok(Handled::No { - message: dispatch, - retry: false, - }), + RequestMatch::Unhandled(dispatch) => { + tracing::trace!("RequestHandler::handle_request: method doesn't match"); + Ok(Handled::No { + message: dispatch, + retry: false, + }) + } + RequestMatch::Rejected { dispatch, error } => { + tracing::trace!(?error, "RequestHandler::handle_request: parse errored"); + dispatch.respond_with_error(error, connection)?; + Ok(Handled::Yes) + } } }, ) @@ -236,61 +215,47 @@ where dispatch, connection, async |dispatch, connection| { - match dispatch { - Dispatch::Notification(message) => { - tracing::debug!( - request_type = std::any::type_name::(), - message = ?message, - "NotificationHandler::handle_dispatch" + if let Dispatch::Notification(message) = &dispatch { + tracing::debug!( + request_type = std::any::type_name::(), + message = ?message, + "NotificationHandler::handle_dispatch" + ); + } + match dispatch.match_notification::() { + NotificationMatch::Matched(notif) => { + tracing::trace!( + ?notif, + "NotificationHandler::handle_notification: parse completed" ); - if Notif::matches_method(&message.method) { - match Notif::parse_message(&message.method, &message.params) { - Ok(notif) => { - tracing::trace!( - ?notif, - "NotificationHandler::handle_notification: parse completed" - ); - let result = - (self.to_future_hack)(&mut self.handler, notif, connection) - .await?; - match result.into_handled() { - Handled::Yes => Ok(Handled::Yes), - Handled::No { - message: (notification, _cx), - retry, - } => { - // Handler returned the notification back, convert to untyped - let untyped = notification.to_untyped_message()?; - Ok(Handled::No { - message: Dispatch::Notification(untyped), - retry, - }) - } - } - } - Err(err) => { - tracing::trace!( - ?err, - "NotificationHandler::handle_notification: parse errored" - ); - Err(err) - } - } - } else { - tracing::trace!( - "NotificationHandler::handle_notification: method doesn't match" - ); - Ok(Handled::No { - message: Dispatch::Notification(message), - retry: false, - }) + let result = + (self.to_future_hack)(&mut self.handler, notif, connection).await?; + match result.into_handled() { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message: (notification, _cx), + retry, + } => Dispatch::::Notification(notification) + .erase_into_unhandled(retry), } } - - Dispatch::Request(..) | Dispatch::Response(..) => Ok(Handled::No { - message: dispatch, - retry: false, - }), + NotificationMatch::Unhandled(dispatch) => { + tracing::trace!( + "NotificationHandler::handle_notification: method doesn't match" + ); + Ok(Handled::No { + message: dispatch, + retry: false, + }) + } + NotificationMatch::Rejected { dispatch, error } => { + tracing::trace!( + ?error, + "NotificationHandler::handle_notification: parse errored" + ); + dispatch.respond_with_error(error, connection)?; + Ok(Handled::Yes) + } } }, ) @@ -362,57 +327,27 @@ where self.peer.clone(), dispatch, connection, - async |dispatch, connection| match dispatch.into_typed_dispatch::()? { - Ok(typed_dispatch) => { + async |dispatch, connection| match dispatch.match_typed_dispatch::() { + TypedDispatchMatch::Matched(typed_dispatch) => { let result = (self.to_future_hack)(&mut self.handler, typed_dispatch, connection) .await?; match result.into_handled() { Handled::Yes => Ok(Handled::Yes), Handled::No { - message: Dispatch::Request(request, responder), + message: typed_dispatch, retry, - } => { - let untyped = request.to_untyped_message()?; - Ok(Handled::No { - message: Dispatch::Request(untyped, responder.erase_to_json()), - retry, - }) - } - Handled::No { - message: Dispatch::Notification(notification), - retry, - } => { - let untyped = notification.to_untyped_message()?; - Ok(Handled::No { - message: Dispatch::Notification(untyped), - retry, - }) - } - Handled::No { - message: Dispatch::Response(result, responder), - retry, - } => { - let method = responder.method(); - let untyped_result = match result { - Ok(response) => response.into_json(method).map(Ok), - Err(err) => Ok(Err(err)), - }?; - Ok(Handled::No { - message: Dispatch::Response( - untyped_result, - responder.erase_to_json(), - ), - retry, - }) - } + } => typed_dispatch.erase_into_unhandled(retry), } } - - Err(dispatch) => Ok(Handled::No { + TypedDispatchMatch::Unhandled(dispatch) => Ok(Handled::No { message: dispatch, retry: false, }), + TypedDispatchMatch::Rejected { dispatch, error } => { + dispatch.respond_with_error(error, connection)?; + Ok(Handled::Yes) + } }, ) .await diff --git a/src/agent-client-protocol-core/src/lib.rs b/src/agent-client-protocol-core/src/lib.rs index 9175751..af28b46 100644 --- a/src/agent-client-protocol-core/src/lib.rs +++ b/src/agent-client-protocol-core/src/lib.rs @@ -113,7 +113,8 @@ pub mod jsonrpcmsg { pub use jsonrpc::{ Builder, ByteStreams, Channel, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, IntoHandled, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Lines, - NullHandler, Responder, ResponseRouter, SentRequest, UntypedMessage, + NotificationMatch, NullHandler, RequestMatch, Responder, ResponseRouter, SentRequest, + TypedDispatchMatch, UntypedMessage, run::{ChainRun, NullRun, RunWithConnectionTo}, }; diff --git a/src/agent-client-protocol-core/src/role.rs b/src/agent-client-protocol-core/src/role.rs index e1c28aa..2ecc74e 100644 --- a/src/agent-client-protocol-core/src/role.rs +++ b/src/agent-client-protocol-core/src/role.rs @@ -9,7 +9,6 @@ use std::{any::TypeId, fmt::Debug, future::Future, hash::Hash}; use serde::{Deserialize, Serialize}; use crate::schema::{METHOD_SUCCESSOR_MESSAGE, SuccessorMessage}; -use crate::util::json_cast; use crate::{Builder, ConnectionTo, Dispatch, Handled, JsonRpcMessage, UntypedMessage}; /// Roles for the ACP protocol. @@ -230,7 +229,14 @@ where "Response variant cannot be unwrapped as SuccessorMessage", ) })?; - let SuccessorMessage { message, meta } = json_cast(untyped_message.params())?; + let SuccessorMessage { message, meta } = + match crate::util::json_cast_params(untyped_message.params()) { + Ok(message) => message, + Err(error) => { + dispatch.respond_with_error(error, connection.clone())?; + return Ok(Handled::Yes); + } + }; let successor_dispatch = dispatch.try_map_message(|_| Ok(message))?; tracing::trace!( unwrapped_method = %successor_dispatch.method(), diff --git a/src/agent-client-protocol-core/src/role/acp.rs b/src/agent-client-protocol-core/src/role/acp.rs index 3ffb334..306c7e3 100644 --- a/src/agent-client-protocol-core/src/role/acp.rs +++ b/src/agent-client-protocol-core/src/role/acp.rs @@ -279,7 +279,14 @@ where MatchDispatchFrom::new(message, &connection) .if_message_from(Agent, async |message| { // If this is for our session-id, proxy it to the client. - if let Some(session_id) = message.get_session_id()? + let session_id = match message.get_session_id() { + Ok(session_id) => session_id, + Err(error) => { + message.respond_with_error(error, connection.clone())?; + return Ok(Handled::Yes); + } + }; + if let Some(session_id) = session_id && session_id == self.session_id { connection.send_proxied_message_to(Client, message)?; diff --git a/src/agent-client-protocol-core/src/schema/mod.rs b/src/agent-client-protocol-core/src/schema/mod.rs index 30f6d93..8367031 100644 --- a/src/agent-client-protocol-core/src/schema/mod.rs +++ b/src/agent-client-protocol-core/src/schema/mod.rs @@ -36,7 +36,7 @@ macro_rules! impl_jsonrpc_request { if method != $method { return Err($crate::Error::method_not_found()); } - $crate::util::json_cast(params) + $crate::util::json_cast_params(params) } } @@ -84,7 +84,7 @@ macro_rules! impl_jsonrpc_notification { if method != $method { return Err($crate::Error::method_not_found()); } - $crate::util::json_cast(params) + $crate::util::json_cast_params(params) } } @@ -133,10 +133,10 @@ macro_rules! impl_jsonrpc_request_enum { params: &impl serde::Serialize, ) -> Result { match method { - $( $(#[$meta])* $method => $crate::util::json_cast(params).map(Self::$variant), )* + $( $(#[$meta])* $method => $crate::util::json_cast_params(params).map(Self::$variant), )* _ => { if let Some(custom_method) = method.strip_prefix('_') { - $crate::util::json_cast(params).map( + $crate::util::json_cast_params(params).map( |ext_req: $crate::schema::ExtRequest| { Self::$ext_variant($crate::schema::ExtRequest::new( custom_method.to_string(), @@ -196,10 +196,10 @@ macro_rules! impl_jsonrpc_notification_enum { params: &impl serde::Serialize, ) -> Result { match method { - $( $(#[$meta])* $method => $crate::util::json_cast(params).map(Self::$variant), )* + $( $(#[$meta])* $method => $crate::util::json_cast_params(params).map(Self::$variant), )* _ => { if let Some(custom_method) = method.strip_prefix('_') { - $crate::util::json_cast(params).map( + $crate::util::json_cast_params(params).map( |ext_notif: $crate::schema::ExtNotification| { Self::$ext_variant($crate::schema::ExtNotification::new( custom_method.to_string(), diff --git a/src/agent-client-protocol-core/src/schema/proxy_protocol.rs b/src/agent-client-protocol-core/src/schema/proxy_protocol.rs index 38b7109..aa35dd1 100644 --- a/src/agent-client-protocol-core/src/schema/proxy_protocol.rs +++ b/src/agent-client-protocol-core/src/schema/proxy_protocol.rs @@ -50,7 +50,7 @@ impl JsonRpcMessage for SuccessorMessage { if method != METHOD_SUCCESSOR_MESSAGE { return Err(crate::Error::method_not_found()); } - let outer = crate::util::json_cast::<_, SuccessorMessage>(params)?; + let outer = crate::util::json_cast_params::<_, SuccessorMessage>(params)?; if !M::matches_method(&outer.message.method) { return Err(crate::Error::method_not_found()); } @@ -161,7 +161,7 @@ impl JsonRpcMessage for McpOverAcpMessage { if method != METHOD_MCP_MESSAGE { return Err(crate::Error::method_not_found()); } - let outer = crate::util::json_cast::<_, McpOverAcpMessage>(params)?; + let outer = crate::util::json_cast_params::<_, McpOverAcpMessage>(params)?; if !M::matches_method(&outer.message.method) { return Err(crate::Error::method_not_found()); } diff --git a/src/agent-client-protocol-core/src/session.rs b/src/agent-client-protocol-core/src/session.rs index 9908c00..21de484 100644 --- a/src/agent-client-protocol-core/src/session.rs +++ b/src/agent-client-protocol-core/src/session.rs @@ -741,7 +741,14 @@ where ); MatchDispatchFrom::new(message, &cx) .if_message_from(Agent, async |message| { - if let Some(session_id) = message.get_session_id()? { + let session_id = match message.get_session_id() { + Ok(session_id) => session_id, + Err(error) => { + message.respond_with_error(error, cx.clone())?; + return Ok(Handled::Yes); + } + }; + if let Some(session_id) = session_id { tracing::trace!( message_session_id = ?session_id, handler_session_id = ?self.session_id, diff --git a/src/agent-client-protocol-core/src/typed.rs b/src/agent-client-protocol-core/src/typed.rs deleted file mode 100644 index 6dc70a2..0000000 --- a/src/agent-client-protocol-core/src/typed.rs +++ /dev/null @@ -1,125 +0,0 @@ -// Types re-exported from crate root -use jsonrpcmsg::Params; - -use crate::{ - ConnectionTo, Responder, JsonRpcNotification, JsonRpcRequest, UntypedMessage, - util::json_cast, -}; - -/// Utility class for handling untyped requests. -#[must_use] -pub struct TypeRequest { - state: Option, -} - -enum TypeMessageState { - Unhandled(String, Option, Responder), - Handled(Result<(), crate::Error>), -} - -impl TypeRequest { - pub fn new(request: UntypedMessage, responder: Responder) -> Self { - let UntypedMessage { method, params } = request; - let params: Option = json_cast(params).expect("valid params"); - Self { - state: Some(TypeMessageState::Unhandled(method, params, responder)), - } - } - - pub async fn handle_if( - mut self, - op: impl AsyncFnOnce(R, Responder) -> Result<(), crate::Error>, - ) -> Self { - self.state = Some(match self.state.take().expect("valid state") { - TypeMessageState::Unhandled(method, params, responder) => { - match R::parse_message(&method, ¶ms) { - Some(Ok(request)) => { - TypeMessageState::Handled(op(request, responder.cast()).await) - } - - Some(Err(err)) => TypeMessageState::Handled(responder.respond_with_error(err)), - - None => TypeMessageState::Unhandled(method, params, responder), - } - } - - TypeMessageState::Handled(err) => TypeMessageState::Handled(err), - }); - self - } - - pub async fn otherwise( - mut self, - op: impl AsyncFnOnce(UntypedMessage, Responder) -> Result<(), crate::Error>, - ) -> Result<(), crate::Error> { - match self.state.take().expect("valid state") { - TypeMessageState::Unhandled(method, params, responder) => { - match UntypedMessage::new(&method, params) { - Ok(m) => op(m, responder).await, - Err(err) => responder.respond_with_error(err), - } - } - TypeMessageState::Handled(r) => r, - } - } -} - -/// Utility class for handling untyped notifications. -#[must_use] -pub struct TypeNotification { - cx: ConnectionTo, - state: Option, -} - -enum TypeNotificationState { - Unhandled(String, Option), - Handled(Result<(), crate::Error>), -} - -impl TypeNotification { - pub fn new(request: UntypedMessage, cx: &ConnectionTo) -> Self { - let UntypedMessage { method, params } = request; - let params: Option = json_cast(params).expect("valid params"); - Self { - cx: cx.clone(), - state: Some(TypeNotificationState::Unhandled(method, params)), - } - } - - pub async fn handle_if( - mut self, - op: impl AsyncFnOnce(N) -> Result<(), crate::Error>, - ) -> Self { - self.state = Some(match self.state.take().expect("valid state") { - TypeNotificationState::Unhandled(method, params) => { - match N::parse_message(&method, ¶ms) { - Some(Ok(request)) => TypeNotificationState::Handled(op(request).await), - - Some(Err(err)) => { - TypeNotificationState::Handled(self.cx.send_error_notification(err)) - } - - None => TypeNotificationState::Unhandled(method, params), - } - } - - TypeNotificationState::Handled(err) => TypeNotificationState::Handled(err), - }); - self - } - - pub async fn otherwise( - mut self, - op: impl AsyncFnOnce(UntypedMessage) -> Result<(), crate::Error>, - ) -> Result<(), crate::Error> { - match self.state.take().expect("valid state") { - TypeNotificationState::Unhandled(method, params) => { - match UntypedMessage::new(&method, params) { - Ok(m) => op(m).await, - Err(err) => self.cx.send_error_notification(err), - } - } - TypeNotificationState::Handled(r) => r, - } - } -} diff --git a/src/agent-client-protocol-core/src/util.rs b/src/agent-client-protocol-core/src/util.rs index 310b02e..684dd53 100644 --- a/src/agent-client-protocol-core/src/util.rs +++ b/src/agent-client-protocol-core/src/util.rs @@ -8,24 +8,64 @@ use futures::{ mod typed; pub use typed::{MatchDispatch, MatchDispatchFrom, TypeNotification}; -/// Cast from `N` to `M` by serializing/deserialization to/from JSON. +fn serde_conversion_error( + kind: impl FnOnce() -> crate::Error, + error: impl ToString, + json: Option, + phase: &'static str, +) -> crate::Error { + let mut data = serde_json::json!({ + "error": error.to_string(), + "phase": phase, + }); + if let Some(json) = json { + data["json"] = json; + } + kind().data(data) +} + +/// Cast between JSON and typed values for local/internal conversions. +/// +/// This is appropriate for response decoding, outbound JSON-RPC conversion, and other +/// framework-internal serde transformations where `InvalidParams` would be misleading. pub fn json_cast(params: N) -> Result where N: serde::Serialize, M: serde::de::DeserializeOwned, { let json = serde_json::to_value(params).map_err(|e| { - crate::Error::parse_error().data(serde_json::json!({ - "error": e.to_string(), - "phase": "serialization" - })) + serde_conversion_error(crate::Error::internal_error, e, None, "serialization") + })?; + let m = serde_json::from_value(json.clone()).map_err(|e| { + serde_conversion_error( + crate::Error::internal_error, + e, + Some(json), + "deserialization", + ) + })?; + Ok(m) +} + +/// Cast incoming request/notification params into a typed payload. +/// +/// Deserialization failures become `InvalidParams`, while serialization failures are +/// treated as local/internal bugs. +pub fn json_cast_params(params: N) -> Result +where + N: serde::Serialize, + M: serde::de::DeserializeOwned, +{ + let json = serde_json::to_value(params).map_err(|e| { + serde_conversion_error(crate::Error::internal_error, e, None, "serialization") })?; let m = serde_json::from_value(json.clone()).map_err(|e| { - crate::Error::parse_error().data(serde_json::json!({ - "error": e.to_string(), - "json": json, - "phase": "deserialization" - })) + serde_conversion_error( + crate::Error::invalid_params, + e, + Some(json), + "deserialization", + ) })?; Ok(m) } diff --git a/src/agent-client-protocol-core/src/util/typed.rs b/src/agent-client-protocol-core/src/util/typed.rs index c94e8dd..5d6b677 100644 --- a/src/agent-client-protocol-core/src/util/typed.rs +++ b/src/agent-client-protocol-core/src/util/typed.rs @@ -21,7 +21,8 @@ use jsonrpcmsg::Params; use crate::{ ConnectionTo, Dispatch, HandleDispatchFrom, Handled, JsonRpcNotification, JsonRpcRequest, - JsonRpcResponse, Responder, ResponseRouter, UntypedMessage, + JsonRpcResponse, NotificationMatch, RequestMatch, Responder, ResponseRouter, + TypedDispatchMatch, UntypedMessage, role::{HasPeer, Role, handle_incoming_dispatch}, util::json_cast, }; @@ -32,6 +33,9 @@ use crate::{ /// such as inside a [`MatchDispatchFrom`] callback or when processing messages /// that don't need peer transforms. /// +/// Because this helper has no [`ConnectionTo`], malformed notifications cannot emit +/// an out-of-band error message; they are logged and swallowed instead. +/// /// For connection handlers where you need proper peer-aware transforms, /// use [`MatchDispatchFrom`] instead. /// @@ -102,46 +106,27 @@ impl MatchDispatch { retry, }) = self.state { - self.state = match dispatch { - Dispatch::Request(untyped_request, untyped_responder) => { - if Req::matches_method(untyped_request.method()) { - match Req::parse_message(untyped_request.method(), untyped_request.params()) - { - Ok(typed_request) => { - let typed_responder = untyped_responder.cast(); - match op(typed_request, typed_responder).await { - Ok(result) => match result.into_handled() { - Handled::Yes => Ok(Handled::Yes), - Handled::No { - message: (request, responder), - retry: request_retry, - } => match request.to_untyped_message() { - Ok(untyped) => Ok(Handled::No { - message: Dispatch::Request( - untyped, - responder.erase_to_json(), - ), - retry: retry | request_retry, - }), - Err(err) => Err(err), - }, - }, - Err(err) => Err(err), - } - } - Err(err) => Err(err), - } - } else { - Ok(Handled::No { - message: Dispatch::Request(untyped_request, untyped_responder), - retry, - }) + self.state = match dispatch.match_request::() { + RequestMatch::Matched(typed_request, typed_responder) => { + match op(typed_request, typed_responder).await { + Ok(result) => match result.into_handled() { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message: (request, responder), + retry: request_retry, + } => Dispatch::::Request(request, responder) + .erase_into_unhandled(retry | request_retry), + }, + Err(err) => Err(err), } } - Dispatch::Notification(_) | Dispatch::Response(_, _) => Ok(Handled::No { + RequestMatch::Unhandled(dispatch) => Ok(Handled::No { message: dispatch, retry, }), + RequestMatch::Rejected { dispatch, error } => { + dispatch.handle_rejection_without_connection(error) + } }; } self @@ -163,42 +148,27 @@ impl MatchDispatch { retry, }) = self.state { - self.state = match dispatch { - Dispatch::Notification(untyped_notification) => { - if N::matches_method(untyped_notification.method()) { - match N::parse_message( - untyped_notification.method(), - untyped_notification.params(), - ) { - Ok(typed_notification) => match op(typed_notification).await { - Ok(result) => match result.into_handled() { - Handled::Yes => Ok(Handled::Yes), - Handled::No { - message: notification, - retry: notification_retry, - } => match notification.to_untyped_message() { - Ok(untyped) => Ok(Handled::No { - message: Dispatch::Notification(untyped), - retry: retry | notification_retry, - }), - Err(err) => Err(err), - }, - }, - Err(err) => Err(err), - }, - Err(err) => Err(err), - } - } else { - Ok(Handled::No { - message: Dispatch::Notification(untyped_notification), - retry, - }) + self.state = match dispatch.match_notification::() { + NotificationMatch::Matched(typed_notification) => { + match op(typed_notification).await { + Ok(result) => match result.into_handled() { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message: notification, + retry: notification_retry, + } => Dispatch::::Notification(notification) + .erase_into_unhandled(retry | notification_retry), + }, + Err(err) => Err(err), } } - Dispatch::Request(_, _) | Dispatch::Response(_, _) => Ok(Handled::No { + NotificationMatch::Unhandled(dispatch) => Ok(Handled::No { message: dispatch, retry, }), + NotificationMatch::Rejected { dispatch, error } => { + dispatch.handle_rejection_without_connection(error) + } }; } self @@ -220,54 +190,27 @@ impl MatchDispatch { retry, }) = self.state { - self.state = match dispatch.into_typed_dispatch::() { - Ok(Ok(typed_dispatch)) => match op(typed_dispatch).await { + self.state = match dispatch.match_typed_dispatch::() { + TypedDispatchMatch::Matched(typed_dispatch) => match op(typed_dispatch).await { Ok(result) => match result.into_handled() { Handled::Yes => Ok(Handled::Yes), Handled::No { message: typed_dispatch, retry: message_retry, - } => { - let untyped = match typed_dispatch { - Dispatch::Request(request, responder) => { - match request.to_untyped_message() { - Ok(untyped) => { - Dispatch::Request(untyped, responder.erase_to_json()) - } - Err(err) => return Self { state: Err(err) }, - } - } - Dispatch::Notification(notification) => { - match notification.to_untyped_message() { - Ok(untyped) => Dispatch::Notification(untyped), - Err(err) => return Self { state: Err(err) }, - } - } - Dispatch::Response(result, router) => { - let method = router.method(); - let untyped_result = match result { - Ok(response) => match response.into_json(method) { - Ok(json) => Ok(json), - Err(err) => return Self { state: Err(err) }, - }, - Err(err) => Err(err), - }; - Dispatch::Response(untyped_result, router.erase_to_json()) - } - }; - Ok(Handled::No { - message: untyped, - retry: retry | message_retry, - }) - } + } => match typed_dispatch.erase_into_unhandled(retry | message_retry) { + Ok(handled) => Ok(handled), + Err(err) => return Self { state: Err(err) }, + }, }, Err(err) => Err(err), }, - Ok(Err(dispatch)) => Ok(Handled::No { + TypedDispatchMatch::Unhandled(dispatch) => Ok(Handled::No { message: dispatch, retry, }), - Err(err) => Err(err), + TypedDispatchMatch::Rejected { dispatch, error } => { + dispatch.handle_rejection_without_connection(error) + } }; } self @@ -315,20 +258,8 @@ impl MatchDispatch { Handled::No { message: (result, router), retry: response_retry, - } => { - // Convert typed result back to untyped - let untyped_result = match result { - Ok(response) => response.into_json(router.method()), - Err(err) => Err(err), - }; - Ok(Handled::No { - message: Dispatch::Response( - untyped_result, - router.erase_to_json(), - ), - retry: retry | response_retry, - }) - } + } => Dispatch::::Response(result, router) + .erase_into_unhandled(retry | response_retry), }, Err(err) => Err(err), } @@ -580,12 +511,28 @@ impl MatchDispatchFrom { peer, message, self.connection.clone(), - async |dispatch, _connection| { - // Delegate to MatchDispatch for parsing - MatchDispatch::new(dispatch) - .if_notification(op) - .await - .done() + async |dispatch, connection| match dispatch.match_notification::() { + NotificationMatch::Matched(typed_notification) => { + match op(typed_notification).await { + Ok(result) => match result.into_handled() { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message: notification, + retry: notification_retry, + } => Dispatch::::Notification(notification) + .erase_into_unhandled(notification_retry), + }, + Err(err) => Err(err), + } + } + NotificationMatch::Unhandled(dispatch) => Ok(Handled::No { + message: dispatch, + retry: false, + }), + NotificationMatch::Rejected { dispatch, error } => { + dispatch.respond_with_error(error, connection)?; + Ok(Handled::Yes) + } }, ) .await; @@ -617,9 +564,25 @@ impl MatchDispatchFrom { peer, message, self.connection.clone(), - async |dispatch, _connection| { - // Delegate to MatchDispatch for parsing - MatchDispatch::new(dispatch).if_message(op).await.done() + async |dispatch, connection| match dispatch.match_typed_dispatch::() { + TypedDispatchMatch::Matched(typed_dispatch) => match op(typed_dispatch).await { + Ok(result) => match result.into_handled() { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message: typed_dispatch, + retry: message_retry, + } => typed_dispatch.erase_into_unhandled(message_retry), + }, + Err(err) => Err(err), + }, + TypedDispatchMatch::Unhandled(dispatch) => Ok(Handled::No { + message: dispatch, + retry: false, + }), + TypedDispatchMatch::Rejected { dispatch, error } => { + dispatch.respond_with_error(error, connection)?; + Ok(Handled::Yes) + } }, ) .await; @@ -844,7 +807,8 @@ impl MatchDispatchFrom { /// ``` /// /// Since notifications don't expect responses, handlers only receive the parsed -/// notification (not a request context). +/// notification (not a request context). If parsing fails, this helper sends an +/// out-of-band JSON-RPC error message via the provided connection. #[must_use] #[derive(Debug)] pub struct TypeNotification { diff --git a/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs b/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs index cdec8e9..28a7476 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs @@ -75,7 +75,7 @@ impl JsonRpcMessage for PingRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -132,7 +132,7 @@ impl JsonRpcMessage for SlowRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } diff --git a/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs b/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs index 8a093bf..acd71d6 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs @@ -60,7 +60,7 @@ impl JsonRpcMessage for FooRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -116,7 +116,7 @@ impl JsonRpcMessage for BarRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -269,7 +269,7 @@ impl JsonRpcMessage for TrackRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -397,7 +397,7 @@ impl JsonRpcMessage for Method1Request { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -432,7 +432,7 @@ impl JsonRpcMessage for Method2Request { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -626,7 +626,7 @@ impl JsonRpcMessage for EventNotification { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } diff --git a/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs b/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs index 785e583..c988d34 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs @@ -110,7 +110,7 @@ impl JsonRpcMessage for OptionalParamsRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } diff --git a/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs b/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs index 11ecbfe..c2aa671 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs @@ -8,14 +8,40 @@ //! - Missing/invalid parameters use agent_client_protocol_core::{ - ConnectionTo, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, Responder, SentRequest, + ConnectionTo, Dispatch, HandleDispatchFrom, Handled, JsonRpcMessage, JsonRpcNotification, + JsonRpcRequest, JsonRpcResponse, RequestMatch, Responder, SentRequest, role::UntypedRole, + util::{MatchDispatch, MatchDispatchFrom}, }; use expect_test::expect; use futures::{AsyncRead, AsyncWrite}; use serde::{Deserialize, Serialize}; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; +async fn read_jsonrpc_response_line( + reader: &mut tokio::io::BufReader, +) -> serde_json::Value { + try_read_jsonrpc_response_line(reader, tokio::time::Duration::from_secs(1)) + .await + .expect("timed out waiting for JSON-RPC response") +} + +async fn try_read_jsonrpc_response_line( + reader: &mut tokio::io::BufReader, + timeout: tokio::time::Duration, +) -> Option { + use tokio::io::AsyncBufReadExt as _; + + let mut line = String::new(); + match tokio::time::timeout(timeout, reader.read_line(&mut line)).await { + Ok(Ok(0)) | Err(_) => None, + Ok(Ok(_)) => { + Some(serde_json::from_str(line.trim()).expect("response should be valid JSON")) + } + Ok(Err(_)) => panic!("failed to read JSON-RPC response line"), + } +} + /// Test helper to block and wait for a JSON-RPC response. async fn recv( response: SentRequest, @@ -78,7 +104,7 @@ impl JsonRpcMessage for SimpleRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -107,6 +133,39 @@ impl JsonRpcResponse for SimpleResponse { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleNotification { + message: String, +} + +impl JsonRpcMessage for SimpleNotification { + fn matches_method(method: &str) -> bool { + method == "simple_notification" + } + + fn method(&self) -> &'static str { + "simple_notification" + } + + fn to_untyped_message( + &self, + ) -> Result { + agent_client_protocol_core::UntypedMessage::new(self.method(), self) + } + + fn parse_message( + method: &str, + params: &impl serde::Serialize, + ) -> Result { + if !Self::matches_method(method) { + return Err(agent_client_protocol_core::Error::method_not_found()); + } + agent_client_protocol_core::util::json_cast_params(params) + } +} + +impl JsonRpcNotification for SimpleNotification {} + // ============================================================================ // Test 1: Invalid JSON (complete line with parse error) // ============================================================================ @@ -282,7 +341,7 @@ impl JsonRpcMessage for ErrorRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -351,7 +410,7 @@ async fn test_handler_returns_error() { } // ============================================================================ -// Test 4: Request without required params +// Test 4: Handler-returned invalid params // ============================================================================ #[derive(Debug, Clone, Serialize, Deserialize)] @@ -388,7 +447,7 @@ impl JsonRpcRequest for EmptyRequest { } #[tokio::test(flavor = "current_thread")] -async fn test_missing_required_params() { +async fn test_handler_returned_invalid_params() { use tokio::task::LocalSet; let local = LocalSet::new(); @@ -397,19 +456,15 @@ async fn test_missing_required_params() { .run_until(async { let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - // Handler that validates params - since EmptyRequest has no params but we're checking - // against SimpleRequest which requires a message field, this will fail + // This test exercises a handler that explicitly returns `Invalid params`. + // It does not cover request deserialization failures; those are covered below + // by the raw-wire malformed-request regression tests. let server_transport = agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); let server = UntypedRole.builder().on_receive_request( async |_request: EmptyRequest, responder: Responder, _connection: ConnectionTo| { - // This will be called, but EmptyRequest parsing already succeeded - // The test is actually checking if EmptyRequest (no params) fails to parse as SimpleRequest - // But with the new API, EmptyRequest parses successfully since it expects no params - // We need to manually check - but actually the parse_request for EmptyRequest - // accepts anything for "strict_method", so the error must come from somewhere else responder .respond_with_error(agent_client_protocol_core::Error::invalid_params()) }, @@ -428,13 +483,12 @@ async fn test_missing_required_params() { .connect_with( client_transport, async |cx| -> Result<(), agent_client_protocol_core::Error> { - // Send request with no params (EmptyRequest has no fields) let request = EmptyRequest; let result: Result = recv(cx.send_request(request)).await; - // Should get invalid_params error + // Should get invalid_params error from the handler. assert!(result.is_err()); if let Err(err) = result { assert!(matches!( @@ -451,3 +505,569 @@ async fn test_missing_required_params() { }) .await; } + +// ============================================================================ +// Test 5: Malformed incoming responses +// ============================================================================ + +#[tokio::test(flavor = "current_thread")] +async fn test_match_dispatch_from_if_message_malformed_response_keeps_connection_alive() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + struct ClientTypedMessageHandler; + + impl HandleDispatchFrom for ClientTypedMessageHandler { + fn describe_chain(&self) -> impl std::fmt::Debug { + "ClientTypedMessageHandler" + } + + async fn handle_dispatch_from( + &mut self, + message: Dispatch, + connection: ConnectionTo, + ) -> Result, agent_client_protocol_core::Error> { + MatchDispatchFrom::new(message, &connection) + .if_message_from( + UntypedRole, + async move |dispatch: Dispatch| { + match dispatch { + Dispatch::Request(request, responder) => { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + } + Dispatch::Notification(_) => Ok(()), + Dispatch::Response(result, router) => { + router.respond_with_result(result) + } + } + }, + ) + .await + .done() + } + } + + let local = LocalSet::new(); + + local + .run_until(async { + let (client_writer, server_reader) = tokio::io::duplex(2048); + let (mut server_writer, client_reader) = tokio::io::duplex(2048); + + let client_transport = agent_client_protocol_core::ByteStreams::new( + client_writer.compat_write(), + client_reader.compat(), + ); + let client = UntypedRole + .builder() + .with_handler(ClientTypedMessageHandler); + + let server_task = tokio::task::spawn_local(async move { + let mut server_reader = BufReader::new(server_reader); + + let first_request = read_jsonrpc_response_line(&mut server_reader).await; + assert_eq!(first_request["jsonrpc"], "2.0"); + assert_eq!(first_request["method"], "simple_method"); + assert_eq!(first_request["params"]["message"], "first"); + let first_id = first_request["id"].clone(); + assert_ne!(first_id, serde_json::Value::Null); + + let malformed_response = serde_json::json!({ + "jsonrpc": "2.0", + "id": first_id, + "result": { + "wrong_field": "oops" + } + }); + let malformed_line = + format!("{}\n", serde_json::to_string(&malformed_response).unwrap()); + server_writer + .write_all(malformed_line.as_bytes()) + .await + .unwrap(); + server_writer.flush().await.unwrap(); + + let second_request = read_jsonrpc_response_line(&mut server_reader).await; + assert_eq!(second_request["jsonrpc"], "2.0"); + assert_eq!(second_request["method"], "simple_method"); + assert_eq!(second_request["params"]["message"], "second"); + let second_id = second_request["id"].clone(); + assert_ne!(second_id, serde_json::Value::Null); + + let good_response = serde_json::json!({ + "jsonrpc": "2.0", + "id": second_id, + "result": { + "result": "echo: second" + } + }); + let good_line = format!("{}\n", serde_json::to_string(&good_response).unwrap()); + server_writer.write_all(good_line.as_bytes()).await.unwrap(); + server_writer.flush().await.unwrap(); + }); + + let client_result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + let bad_result = tokio::time::timeout( + tokio::time::Duration::from_secs(1), + recv(cx.send_request(SimpleRequest { + message: "first".to_string(), + })), + ) + .await + .expect("malformed response should complete with an error, not hang"); + + let err = bad_result.expect_err( + "malformed response payload should be reported as an error", + ); + assert!(matches!( + err.code, + agent_client_protocol_core::ErrorCode::InternalError + )); + let err_data = serde_json::to_value(&err.data) + .expect("error data should serialize to JSON"); + assert_eq!(err_data["phase"], "deserialization"); + assert_eq!( + err_data["json"], + serde_json::json!({ + "wrong_field": "oops" + }) + ); + + let good_result = tokio::time::timeout( + tokio::time::Duration::from_secs(1), + recv(cx.send_request(SimpleRequest { + message: "second".to_string(), + })), + ) + .await + .expect("connection should remain alive after malformed response")?; + + assert_eq!(good_result.result, "echo: second"); + Ok(()) + }, + ) + .await; + + server_task.await.unwrap(); + assert!( + client_result.is_ok(), + "client should stay alive: {client_result:?}" + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn test_bad_request_params_return_invalid_params_and_connection_stays_alive() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(2048); + let (server_writer, client_reader) = tokio::io::duplex(2048); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(err) = server.connect_to(server_transport).await { + panic!("server should stay alive: {err:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":3,"method":"simple_method","params":{"content":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let invalid_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "error": { + "code": -32602, + "data": { + "error": "missing field `message`", + "json": { + "content": "hello" + }, + "phase": "deserialization" + }, + "message": "Invalid params" + }, + "id": 3, + "jsonrpc": "2.0" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&invalid_response).unwrap()); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":4,"method":"simple_method","params":{"message":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let ok_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 4, + "jsonrpc": "2.0", + "result": { + "result": "echo: hello" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn test_bad_notification_params_send_error_notification_and_connection_stays_alive() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(2048); + let (server_writer, client_reader) = tokio::io::duplex(2048); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_notification( + async |_notif: SimpleNotification, + _connection: ConnectionTo| { + // If we get here, the notification parsed successfully. + Ok(()) + }, + agent_client_protocol_core::on_receive_notification!(), + ) + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(err) = server.connect_to(server_transport).await { + panic!("server should stay alive: {err:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + // Send a notification with bad params (wrong field name). + // Notifications have no "id", so the server sends an error + // notification (id: null) and keeps the connection alive. + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"simple_notification","params":{"wrong_field":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + // The server sends an error notification (id: null) for the + // malformed notification. + let error_notification = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "error": { + "code": -32602, + "data": { + "error": "missing field `message`", + "json": { + "wrong_field": "hello" + }, + "phase": "deserialization" + }, + "message": "Invalid params" + }, + "jsonrpc": "2.0" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&error_notification).unwrap()); + + // Now send a valid request to prove the connection is still alive. + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":10,"method":"simple_method","params":{"message":"after bad notification"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let ok_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 10, + "jsonrpc": "2.0", + "result": { + "result": "echo: after bad notification" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn test_match_dispatch_connectionless_bad_notification_params_emit_no_error_and_connection_stays_alive() + { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + struct ConnectionlessMatchDispatchHandler; + + impl HandleDispatchFrom for ConnectionlessMatchDispatchHandler { + fn describe_chain(&self) -> impl std::fmt::Debug { + "ConnectionlessMatchDispatchHandler" + } + + async fn handle_dispatch_from( + &mut self, + message: Dispatch, + connection: ConnectionTo, + ) -> Result, agent_client_protocol_core::Error> { + match MatchDispatch::new(message) + .if_notification(async move |_notif: SimpleNotification| Ok(())) + .await + .done()? + { + Handled::Yes => Ok(Handled::Yes), + Handled::No { message, retry } => match message.match_request::() { + RequestMatch::Matched(request, responder) => { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + })?; + Ok(Handled::Yes) + } + RequestMatch::Rejected { dispatch, error } => { + dispatch.respond_with_error(error, connection)?; + Ok(Handled::Yes) + } + RequestMatch::Unhandled(message) => Ok(Handled::No { message, retry }), + }, + } + } + } + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(2048); + let (server_writer, client_reader) = tokio::io::duplex(2048); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .with_handler(ConnectionlessMatchDispatchHandler); + + tokio::task::spawn_local(async move { + if let Err(err) = server.connect_to(server_transport).await { + panic!("server should stay alive: {err:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"simple_notification","params":{"wrong_field":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let unexpected = try_read_jsonrpc_response_line( + &mut client_reader, + tokio::time::Duration::from_millis(100), + ) + .await; + assert!( + unexpected.is_none(), + "connectionless MatchDispatch should not emit an out-of-band error for malformed notifications" + ); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":11,"method":"simple_method","params":{"message":"after connectionless bad notification"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let ok_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 11, + "jsonrpc": "2.0", + "result": { + "result": "echo: after connectionless bad notification" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn test_match_dispatch_from_if_message_invalid_params_keeps_connection_alive() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + struct MatchDispatchFromMessageHandler; + + impl HandleDispatchFrom for MatchDispatchFromMessageHandler { + fn describe_chain(&self) -> impl std::fmt::Debug { + "MatchDispatchFromMessageHandler" + } + + async fn handle_dispatch_from( + &mut self, + message: Dispatch, + connection: ConnectionTo, + ) -> Result, agent_client_protocol_core::Error> { + MatchDispatchFrom::new(message, &connection) + .if_message_from( + UntypedRole, + async move |dispatch: Dispatch| { + match dispatch { + Dispatch::Request(request, responder) => { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + } + Dispatch::Notification(_) => Ok(()), + Dispatch::Response(result, router) => { + router.respond_with_result(result) + } + } + }, + ) + .await + .done() + } + } + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(2048); + let (server_writer, client_reader) = tokio::io::duplex(2048); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .with_handler(MatchDispatchFromMessageHandler); + + tokio::task::spawn_local(async move { + if let Err(err) = server.connect_to(server_transport).await { + panic!("server should stay alive: {err:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":5,"method":"simple_method","params":{"content":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let invalid_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "error": { + "code": -32602, + "data": { + "error": "missing field `message`", + "json": { + "content": "hello" + }, + "phase": "deserialization" + }, + "message": "Invalid params" + }, + "id": 5, + "jsonrpc": "2.0" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&invalid_response).unwrap()); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":6,"method":"simple_method","params":{"message":"hello"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let ok_response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 6, + "jsonrpc": "2.0", + "result": { + "result": "echo: hello" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); + }) + .await; +} diff --git a/src/agent-client-protocol-core/tests/jsonrpc_hello.rs b/src/agent-client-protocol-core/tests/jsonrpc_hello.rs index aa62eb6..2a42063 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_hello.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_hello.rs @@ -74,7 +74,7 @@ impl JsonRpcMessage for PingRequest { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } @@ -194,7 +194,7 @@ impl JsonRpcMessage for LogNotification { if !Self::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } diff --git a/src/agent-client-protocol-core/tests/match_dispatch.rs b/src/agent-client-protocol-core/tests/match_dispatch.rs index b11b73c..72504f8 100644 --- a/src/agent-client-protocol-core/tests/match_dispatch.rs +++ b/src/agent-client-protocol-core/tests/match_dispatch.rs @@ -37,7 +37,7 @@ impl JsonRpcMessage for EchoRequestResponse { if !::matches_method(method) { return Err(agent_client_protocol_core::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } diff --git a/src/agent-client-protocol-derive/src/lib.rs b/src/agent-client-protocol-derive/src/lib.rs index 58814df..9275858 100644 --- a/src/agent-client-protocol-derive/src/lib.rs +++ b/src/agent-client-protocol-derive/src/lib.rs @@ -89,7 +89,7 @@ pub fn derive_json_rpc_request(input: TokenStream) -> TokenStream { if method != #method { return Err(#krate::Error::method_not_found()); } - #krate::util::json_cast(params) + #krate::util::json_cast_params(params) } } @@ -149,7 +149,7 @@ pub fn derive_json_rpc_notification(input: TokenStream) -> TokenStream { if method != #method { return Err(#krate::Error::method_not_found()); } - #krate::util::json_cast(params) + #krate::util::json_cast_params(params) } } diff --git a/src/agent-client-protocol-test/src/lib.rs b/src/agent-client-protocol-test/src/lib.rs index 843db92..8833d68 100644 --- a/src/agent-client-protocol-test/src/lib.rs +++ b/src/agent-client-protocol-test/src/lib.rs @@ -129,7 +129,7 @@ macro_rules! impl_jr_message { if !Self::matches_method(method) { return Err(crate::Error::method_not_found()); } - agent_client_protocol_core::util::json_cast(params) + agent_client_protocol_core::util::json_cast_params(params) } } };