WeichenXu 2 weeks ago committed by GitHub
commit 281c3a96e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -174,9 +174,16 @@ class ChatMlflow(BaseChatModel):
)
# TODO: check if `_client.predict_stream` is available.
chunk_iter = self._client.predict_stream(endpoint=self.endpoint, inputs=data)
first_chunk_role = None
for chunk in chunk_iter:
choice = chunk["choices"][0]
chunk = ChatMlflow._convert_delta_to_message_chunk(choice["delta"])
chunk_delta = choice["delta"]
if first_chunk_role is None:
first_chunk_role = chunk_delta["role"]
chunk = ChatMlflow._convert_delta_to_message_chunk(
chunk_delta, first_chunk_role
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
@ -225,8 +232,10 @@ class ChatMlflow(BaseChatModel):
return ChatMessage(content=content, role=role)
@staticmethod
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
role = _dict["role"]
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_role: str
) -> BaseMessageChunk:
role = _dict.get("role", default_role)
content = _dict["content"]
if role == "user":
return HumanMessageChunk(content=content)

Loading…
Cancel
Save