1515"""Utility functions for agent engines."""
1616
1717import abc
18+ import asyncio
1819from importlib import metadata as importlib_metadata
1920import inspect
2021import io
108109_AGENT_FRAMEWORK_ATTR = "agent_framework"
109110_ASYNC_API_MODE = "async"
110111_ASYNC_STREAM_API_MODE = "async_stream"
112+ _BIDI_STREAM_API_MODE = "bidi_stream"
111113_BASE_MODULES = set (_BUILTIN_MODULE_NAMES + tuple (_STDLIB_MODULE_NAMES ))
112114_BLOB_FILENAME = "agent_engine.pkl"
113115_DEFAULT_AGENT_FRAMEWORK = "custom"
132134_DEFAULT_STREAM_METHOD_RETURN_TYPE = "Iterable[Any]"
133135_DEFAULT_REQUIRED_PACKAGES = frozenset (["cloudpickle" , "pydantic" ])
134136_DEFAULT_STREAM_METHOD_NAME = "stream_query"
137+ _DEFAULT_BIDI_STREAM_METHOD_NAME = "bidi_stream_query"
135138_EXTRA_PACKAGES_FILE = "dependencies.tar.gz"
136139_FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE = (
137140 "Failed to register API methods. Please follow the guide to "
@@ -202,6 +205,15 @@ def stream_query(self, **kwargs) -> Iterator[Any]: # type: ignore[no-untyped-de
202205 """Stream responses to serve the user query."""
203206
204207
208+ @typing .runtime_checkable
209+ class BidiStreamQueryable (Protocol ):
210+ """Protocol for Agent Engines that can stream requests and responses."""
211+
212+ @abc .abstractmethod
213+ async def bidi_stream_query (self , input_queue : asyncio .Queue ) -> AsyncIterator [Any ]:
214+ """Stream requests and responses to serve the user queries."""
215+
216+
205217@typing .runtime_checkable
206218class Cloneable (Protocol ):
207219 """Protocol for Agent Engines that can be cloned."""
@@ -234,6 +246,7 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]:
234246 OperationRegistrable ,
235247 Queryable ,
236248 StreamQueryable ,
249+ BidiStreamQueryable ,
237250]
238251
239252
@@ -557,6 +570,9 @@ def _generate_schema(
557570 inspect .Parameter .KEYWORD_ONLY ,
558571 inspect .Parameter .POSITIONAL_ONLY ,
559572 )
573+ # For a bidi endpoint, it requires an asyncio.Queue as the input, but
574+ # it is not JSON serializable. We hence exclude it from the schema.
575+ and param .annotation != asyncio .Queue
560576 }
561577 parameters = pydantic .create_model (f .__name__ , ** fields_dict ).schema ()
562578 # Postprocessing
@@ -656,6 +672,8 @@ def _get_registered_operations(
656672 operations [_STREAM_API_MODE ] = [_DEFAULT_STREAM_METHOD_NAME ]
657673 if isinstance (agent , AsyncStreamQueryable ):
658674 operations [_ASYNC_STREAM_API_MODE ] = [_DEFAULT_ASYNC_STREAM_METHOD_NAME ]
675+ if isinstance (agent , BidiStreamQueryable ):
676+ operations [_BIDI_STREAM_API_MODE ] = [_DEFAULT_BIDI_STREAM_METHOD_NAME ]
659677 return operations
660678
661679
@@ -839,6 +857,10 @@ def _register_api_methods_or_raise(
839857 f" contain an `{ _MODE_KEY_IN_SCHEMA } ` field."
840858 )
841859 api_mode = operation_schema .get (_MODE_KEY_IN_SCHEMA )
860+ # For bidi stream api mode, we don't need to wrap the operation.
861+ if api_mode == _BIDI_STREAM_API_MODE :
862+ continue
863+
842864 if _METHOD_NAME_KEY_IN_SCHEMA not in operation_schema :
843865 raise ValueError (
844866 f"Operation schema { operation_schema } does not"
@@ -1212,6 +1234,7 @@ def _validate_agent_or_raise(
12121234 * a callable method named `query`
12131235 * a callable method named `stream_query`
12141236 * a callable method named `async_stream_query`
1237+ * a callable method named `bidi_stream_query`
12151238 * a callable method named `register_operations`
12161239
12171240 Args:
@@ -1246,6 +1269,9 @@ def _validate_agent_or_raise(
12461269 is_async_stream_queryable = isinstance (agent , AsyncStreamQueryable ) and callable (
12471270 agent .async_stream_query
12481271 )
1272+ is_bidi_stream_queryable = isinstance (agent , BidiStreamQueryable ) and callable (
1273+ agent .bidi_stream_query
1274+ )
12491275 is_operation_registrable = isinstance (agent , OperationRegistrable ) and callable (
12501276 agent .register_operations
12511277 )
@@ -1255,12 +1281,13 @@ def _validate_agent_or_raise(
12551281 or is_async_queryable
12561282 or is_stream_queryable
12571283 or is_operation_registrable
1284+ or is_bidi_stream_queryable
12581285 or is_async_stream_queryable
12591286 ):
12601287 raise TypeError (
12611288 "agent_engine has none of the following callable methods: "
1262- "`query`, `async_query`, `stream_query`, `async_stream_query` or "
1263- "`register_operations`."
1289+ "`query`, `async_query`, `stream_query`, `async_stream_query`, "
1290+ "`bidi_stream_query`, or ` register_operations`."
12641291 )
12651292
12661293 if is_queryable :
@@ -1299,6 +1326,15 @@ def _validate_agent_or_raise(
12991326 " missing `self` argument in the agent.async_stream_query method."
13001327 ) from err
13011328
1329+ if is_bidi_stream_queryable :
1330+ try :
1331+ inspect .signature (getattr (agent , "bidi_stream_query" ))
1332+ except ValueError as err :
1333+ raise ValueError (
1334+ "Invalid bidi_stream_query signature. This might be due to a "
1335+ " missing `self` argument in the agent.bidi_stream_query method."
1336+ ) from err
1337+
13021338 if is_operation_registrable :
13031339 try :
13041340 inspect .signature (getattr (agent , "register_operations" ))
0 commit comments