|
|
|
@ -174,9 +174,16 @@ class ChatMlflow(BaseChatModel):
|
|
|
|
|
)
|
|
|
|
|
# TODO: check if `_client.predict_stream` is available.
|
|
|
|
|
chunk_iter = self._client.predict_stream(endpoint=self.endpoint, inputs=data)
|
|
|
|
|
first_chunk_role = None
|
|
|
|
|
for chunk in chunk_iter:
|
|
|
|
|
choice = chunk["choices"][0]
|
|
|
|
|
chunk = ChatMlflow._convert_delta_to_message_chunk(choice["delta"])
|
|
|
|
|
|
|
|
|
|
chunk_delta = choice["delta"]
|
|
|
|
|
if first_chunk_role is None:
|
|
|
|
|
first_chunk_role = chunk_delta["role"]
|
|
|
|
|
chunk = ChatMlflow._convert_delta_to_message_chunk(
|
|
|
|
|
chunk_delta, first_chunk_role
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
generation_info = {}
|
|
|
|
|
if finish_reason := choice.get("finish_reason"):
|
|
|
|
@ -225,8 +232,10 @@ class ChatMlflow(BaseChatModel):
|
|
|
|
|
return ChatMessage(content=content, role=role)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
|
|
|
|
|
role = _dict["role"]
|
|
|
|
|
def _convert_delta_to_message_chunk(
|
|
|
|
|
_dict: Mapping[str, Any], default_role: str
|
|
|
|
|
) -> BaseMessageChunk:
|
|
|
|
|
role = _dict.get("role", default_role)
|
|
|
|
|
content = _dict["content"]
|
|
|
|
|
if role == "user":
|
|
|
|
|
return HumanMessageChunk(content=content)
|
|
|
|
|