|
|
|
@ -142,3 +142,32 @@ def test_streaming_tool_call_no_tool_calls() -> None:
|
|
|
|
|
acc = chunk if acc is None else acc + chunk
|
|
|
|
|
assert acc.content != ""
|
|
|
|
|
assert "tool_calls" not in acc.additional_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_invoke_multiple_tools() -> None:
|
|
|
|
|
llm = ChatCohere(temperature=0)
|
|
|
|
|
|
|
|
|
|
class Calculator(BaseModel):
|
|
|
|
|
a: int
|
|
|
|
|
b: int
|
|
|
|
|
|
|
|
|
|
class ConversationResponseTool(BaseModel):
|
|
|
|
|
response: str
|
|
|
|
|
|
|
|
|
|
tool_llm = llm.bind_tools([Calculator, ConversationResponseTool])
|
|
|
|
|
|
|
|
|
|
result = tool_llm.invoke("What is the capital of paris")
|
|
|
|
|
print(result)
|
|
|
|
|
|
|
|
|
|
assert isinstance(result, AIMessage)
|
|
|
|
|
additional_kwargs = result.additional_kwargs
|
|
|
|
|
assert "tool_calls" in additional_kwargs
|
|
|
|
|
assert len(additional_kwargs["tool_calls"]) == 1
|
|
|
|
|
assert (
|
|
|
|
|
additional_kwargs["tool_calls"][0]["function"]["name"]
|
|
|
|
|
== "ConversationResponseTool"
|
|
|
|
|
)
|
|
|
|
|
args = json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"])
|
|
|
|
|
|
|
|
|
|
assert "response" in args
|
|
|
|
|
assert "Paris" in args["response"]
|
|
|
|
|