ccurme 2 weeks ago committed by GitHub
commit e40d8eac5a

@@ -174,7 +174,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
                 self.additional_kwargs, other.additional_kwargs
             )
             response_metadata = merge_dicts(
-                self.response_metadata, other.response_metadata
+                self.response_metadata,
+                other.response_metadata,
+                add_ints=True,
             )
 
             # Merge tool call chunks
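
With this change, concatenating AIMessageChunks merges integer values in response_metadata by addition instead of raising on mismatched ints. For illustration, a minimal sketch (assumes a langchain-core build that includes this change):

from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hello ",
    response_metadata={"token_usage": {"total_tokens": 3}},
)
right = AIMessageChunk(
    content="world",
    response_metadata={"token_usage": {"total_tokens": 9}},
)

merged = left + right
assert merged.content == "Hello world"
# Integer values under response_metadata are summed via add_ints=True.
assert merged.response_metadata == {"token_usage": {"total_tokens": 12}}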

@@ -3,7 +3,9 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional
 
 
-def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
+def merge_dicts(
+    left: Dict[str, Any], right: Dict[str, Any], add_ints: bool = False
+) -> Dict[str, Any]:
     """Merge two dicts, handling specific scenarios where a key exists in both
     dictionaries but has a value of None in 'left'. In such cases, the method uses the
     value from 'right' for that key in the merged dictionary.
@@ -31,11 +33,13 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
         elif isinstance(merged[right_k], str):
             merged[right_k] += right_v
         elif isinstance(merged[right_k], dict):
-            merged[right_k] = merge_dicts(merged[right_k], right_v)
+            merged[right_k] = merge_dicts(merged[right_k], right_v, add_ints=add_ints)
         elif isinstance(merged[right_k], list):
             merged[right_k] = merge_lists(merged[right_k], right_v)
         elif merged[right_k] == right_v:
             continue
+        elif isinstance(merged[right_k], int) and add_ints:
+            merged[right_k] += right_v
         else:
             raise TypeError(
                 f"Additional kwargs key {right_k} already exists in left dict and "
@@ -120,6 +120,38 @@ def test_message_chunks() -> None:
     assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
     assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
+
+    # Response metadata
+    chunk_1 = AIMessageChunk(
+        content="",
+        response_metadata={
+            "token_usage": {
+                "completion_tokens": 1,
+                "prompt_tokens": 2,
+                "total_tokens": 3,
+            },
+        },
+    )
+    chunk_2 = AIMessageChunk(
+        content="",
+        response_metadata={
+            "token_usage": {
+                "completion_tokens": 4,
+                "prompt_tokens": 5,
+                "total_tokens": 9,
+            },
+        },
+    )
+    assert chunk_1 + chunk_2 == AIMessageChunk(
+        content="",
+        response_metadata={
+            "token_usage": {
+                "completion_tokens": 5,
+                "prompt_tokens": 7,
+                "total_tokens": 12,
+            },
+        },
+    )
 
 
 def test_chat_message_chunks() -> None:
     assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(

@@ -483,23 +483,31 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    generation_info = {"token_usage": token_usage}
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(content=""),
+                        generation_info=generation_info,
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text, chunk=chunk, logprobs=logprobs
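
The new branch handles the extra chunk that the OpenAI streaming API returns when stream_options={"include_usage": True} is set: a final chunk whose choices list is empty and that carries a usage payload. A rough sketch of the shape being handled (field values are illustrative, not taken from this diff):

usage_chunk = {
    "id": "chatcmpl-...",   # illustrative
    "choices": [],          # empty on the usage-only chunk
    "usage": {
        "completion_tokens": 4,
        "prompt_tokens": 5,
        "total_tokens": 9,
    },
}

# The code above converts such a chunk into a generation chunk with an empty
# message (default_chunk_class, typically AIMessageChunk) and the usage stored
# under generation_info["token_usage"].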
@@ -583,23 +591,31 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    generation_info = {"token_usage": token_usage}
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(content=""),
+                        generation_info=generation_info,
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 await run_manager.on_llm_new_token(
                     token=chunk.text, chunk=chunk, logprobs=logprobs
@@ -1129,6 +1145,29 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        return super()._stream(*args, **kwargs)
+
+    async def _astream(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        async for chunk in super()._astream(*args, **kwargs):
+            yield chunk
+
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
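
Because ChatOpenAI._stream and _astream now default include_usage to True, callers get token counts on the aggregated message without changing their code, and can opt out per call. A usage sketch (assumes a valid OPENAI_API_KEY in the environment):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

full = None
for chunk in llm.stream("I'm Pickle Rick"):
    full = chunk if full is None else full + chunk

# The aggregated chunk now carries token counts in response_metadata.
print(full.response_metadata.get("token_usage"))
# e.g. {"completion_tokens": ..., "prompt_tokens": ..., "total_tokens": ...}

# Opting out restores the previous behaviour:
for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
    pass  # no usage-only chunk is emitted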

@@ -1268,4 +1268,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "8c22541c0d451c02c7f8d309a488cdcc4709573b706f1c67719a23735b72e745"
+content-hash = "0ce2119c6e0db4ba0b0a7a27b61fd1c58c60dd13fbc1483cde74390e159077ae"

@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = ">=0.1.46,<0.3"
-openai = "^1.24.0"
+openai = "^1.26.0"
 tiktoken = ">=0.7,<1"
 
 [tool.poetry.group.test]

@@ -346,9 +346,32 @@ def test_stream() -> None:
     llm = ChatOpenAI()
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk
+        if "token_usage" in chunk.response_metadata:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected only one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert "token_usage" in full.response_metadata
+    for key in ["completion_tokens", "prompt_tokens", "total_tokens"]:
+        assert isinstance(full.response_metadata["token_usage"][key], int)
+    # check not populated
+    aggregate: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
+    assert isinstance(aggregate, AIMessageChunk)
+    assert "token_usage" not in aggregate.response_metadata
 
 
 async def test_astream() -> None:
@@ -356,9 +379,32 @@ async def test_astream() -> None:
     llm = ChatOpenAI()
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk
+        if "token_usage" in chunk.response_metadata:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected only one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert "token_usage" in full.response_metadata
+    for key in ["completion_tokens", "prompt_tokens", "total_tokens"]:
+        assert isinstance(full.response_metadata["token_usage"][key], int)
+    # check not populated
+    aggregate: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
+    assert isinstance(aggregate, AIMessageChunk)
+    assert "token_usage" not in aggregate.response_metadata
 
 
 async def test_abatch() -> None:
