Surya Rath 2 weeks ago committed by GitHub
commit 0b82f48cdb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -26,7 +26,9 @@ from langchain_core.utils.function_calling import convert_to_openai_tool
if TYPE_CHECKING:
import openai
from openai._types import NotGiven, NOT_GIVEN
from openai.types.beta.threads import ThreadMessage
from openai.types.beta.assistant import ToolResources as AssistantToolResources
from openai.types.beta.threads.required_action_function_tool_call import (
RequiredActionFunctionToolCall,
)
@ -59,7 +61,7 @@ def _get_openai_client() -> openai.OpenAI:
try:
import openai
return openai.OpenAI()
return openai.OpenAI(default_headers={"OpenAI-Beta": "assistants=v2"})
except ImportError as e:
raise ImportError(
"Unable to import openai, please install with `pip install openai`."
@ -98,6 +100,22 @@ def _is_assistants_builtin_tool(
and (tool["type"] in assistants_builtin_tools)
)
def _convert_file_ids_into_attachments(file_ids: list) -> list:
"""
Convert file_ids into attachments
File search and Code interpreter will be turned on by default
"""
attachments = []
for id in file_ids:
attachments.append({
'file_id': id,
'tools': [
{'type': 'file_search'},
{'type': 'code_interpreter'}
]
})
return attachments
def _get_assistants_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
@ -225,6 +243,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
name: str,
instructions: str,
tools: Sequence[Union[BaseTool, dict]],
model: str,
*,
tool_resources: Optional[Union[AssistantToolResources, dict]] | NotGiven = NOT_GIVEN,
client: Optional[Union[openai.OpenAI, openai.AzureOpenAI]] = None,
@ -236,20 +255,22 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
name: Assistant name.
instructions: Assistant instructions.
tools: Assistant tools. Can be passed in OpenAI format or as BaseTools.
tool_resources: Assistant tool resources. Can be passed in OpenAI format or
    as a dict (v2 API).
model: Assistant model to use.
client: OpenAI or AzureOpenAI client.
Will create default OpenAI client if not specified.
Will create default OpenAI client (Assistant v2) if not specified.
Returns:
OpenAIAssistantRunnable configured to run using the created assistant.
"""
client = client or _get_openai_client()
assistant = client.beta.assistants.create(
name=name,
instructions=instructions,
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore
tool_resources=tool_resources,
model=model,
file_ids=kwargs.get("file_ids"),
)
return cls(assistant_id=assistant.id, client=client, **kwargs)
@ -264,13 +285,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
thread_id: Existing thread to use.
run_id: Existing run to use. Should only be supplied when providing
the tool output for a required action after an initial invocation.
file_ids: File ids to include in new run. Used for retrieval.
file_ids: (deprecated) File ids to include in new run. Use 'attachments' instead
attachments: Assistant files to include in new run. (v2 API).
message_metadata: Metadata to associate with new message.
thread_metadata: Metadata to associate with new thread. Only relevant
when new thread being created.
instructions: Additional run instructions.
model: Override Assistant model for this run.
tools: Override Assistant tools for this run.
tool_resources: Override Assistant tool resources for this run (v2 API).
run_metadata: Metadata to associate with new run.
config: Runnable config:
@ -290,6 +313,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
files = _convert_file_ids_into_attachments(kwargs.get("file_ids", []))
attachments = kwargs.get("attachments", []) + files
try:
# Being run within AgentExecutor and there are tool outputs to submit.
if self.as_agent and input.get("intermediate_steps"):
@ -304,7 +331,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
{
"role": "user",
"content": input["content"],
"file_ids": input.get("file_ids", []),
"attachments": attachments,
"metadata": input.get("message_metadata"),
}
],
@ -317,7 +344,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
input["thread_id"],
content=input["content"],
role="user",
file_ids=input.get("file_ids", []),
attachments=attachments,
metadata=input.get("message_metadata"),
)
run = self._create_run(input)
@ -344,6 +371,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
name: str,
instructions: str,
tools: Sequence[Union[BaseTool, dict]],
model: str,
*,
tool_resources: Optional[Union[AssistantToolResources, dict]] | NotGiven = NOT_GIVEN,
async_client: Optional[
@ -357,6 +385,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
name: Assistant name.
instructions: Assistant instructions.
tools: Assistant tools. Can be passed in OpenAI format or as BaseTools.
tool_resources: Assistant tool resources. Can be passed in OpenAI format or
    as a dict (v2 API).
model: Assistant model to use.
async_client: AsyncOpenAI client.
Will create default async_client if not specified.
@ -366,12 +395,13 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
"""
async_client = async_client or _get_openai_async_client()
openai_tools = [_get_assistants_tool(tool) for tool in tools]
assistant = await async_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=openai_tools, # type: ignore
tool_resources=tool_resources,
model=model,
file_ids=kwargs.get("file_ids"),
)
return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
@ -386,13 +416,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
thread_id: Existing thread to use.
run_id: Existing run to use. Should only be supplied when providing
the tool output for a required action after an initial invocation.
file_ids: File ids to include in new run. Used for retrieval.
file_ids: (deprecated) File ids to include in new run. Use 'attachments' instead
attachments: Assistant files to include in new run. (v2 API).
message_metadata: Metadata to associate with new message.
thread_metadata: Metadata to associate with new thread. Only relevant
when new thread being created.
instructions: Additional run instructions.
model: Override Assistant model for this run.
tools: Override Assistant tools for this run.
tool_resources: Override Assistant tool resources for this run (v2 API).
run_metadata: Metadata to associate with new run.
config: Runnable config:
@ -412,6 +444,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
files = _convert_file_ids_into_attachments(kwargs.get("file_ids", []))
attachments = kwargs.get("attachments", []) + files
try:
# Being run within AgentExecutor and there are tool outputs to submit.
if self.as_agent and input.get("intermediate_steps"):
@ -428,7 +464,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
{
"role": "user",
"content": input["content"],
"file_ids": input.get("file_ids", []),
"attachments": attachments,
"metadata": input.get("message_metadata"),
}
],
@ -441,7 +477,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
input["thread_id"],
content=input["content"],
role="user",
file_ids=input.get("file_ids", []),
attachments=attachments,
metadata=input.get("message_metadata"),
)
run = await self._acreate_run(input)
@ -488,7 +524,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "run_metadata")
if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
}
return self.client.beta.threads.runs.create(
input["thread_id"],
@ -500,7 +536,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "run_metadata")
if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
}
run = self.client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
@ -616,7 +652,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "run_metadata")
if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
}
return await self.async_client.beta.threads.runs.create(
input["thread_id"],
@ -628,7 +664,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "run_metadata")
if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
}
run = await self.async_client.beta.threads.create_and_run(
assistant_id=self.assistant_id,

Loading…
Cancel
Save