diff --git a/cookbook/rag_upstage_layout_analysis_groundedness_check.ipynb b/cookbook/rag_upstage_layout_analysis_groundedness_check.ipynb index 189f421330..6adc441142 100644 --- a/cookbook/rag_upstage_layout_analysis_groundedness_check.ipynb +++ b/cookbook/rag_upstage_layout_analysis_groundedness_check.ipynb @@ -17,15 +17,14 @@ "from typing import List\n", "\n", "from langchain_community.vectorstores import DocArrayInMemorySearch\n", - "from langchain_core.documents.base import Document\n", "from langchain_core.output_parsers import StrOutputParser\n", "from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_core.runnables import RunnablePassthrough\n", "from langchain_core.runnables.base import RunnableSerializable\n", "from langchain_upstage import (\n", " ChatUpstage,\n", - " GroundednessCheck,\n", " UpstageEmbeddings,\n", + " UpstageGroundednessCheck,\n", " UpstageLayoutAnalysisLoader,\n", ")\n", "\n", @@ -50,7 +49,7 @@ "\n", "retrieved_docs = retriever.get_relevant_documents(\"How many parameters in SOLAR model?\")\n", "\n", - "groundedness_check = GroundednessCheck()\n", + "groundedness_check = UpstageGroundednessCheck()\n", "groundedness = \"\"\n", "while groundedness != \"grounded\":\n", " chain: RunnableSerializable = RunnablePassthrough() | prompt | model | output_parser\n", @@ -62,14 +61,10 @@ " }\n", " )\n", "\n", - " # convert all Documents to string\n", - " def formatDocumentsAsString(docs: List[Document]) -> str:\n", - " return \"\\n\".join([doc.page_content for doc in docs])\n", - "\n", - " groundedness = groundedness_check.run(\n", + " groundedness = groundedness_check.invoke(\n", " {\n", - " \"context\": formatDocumentsAsString(retrieved_docs),\n", - " \"query\": result,\n", + " \"context\": retrieved_docs,\n", + " \"answer\": result,\n", " }\n", " )" ] diff --git a/docs/docs/integrations/providers/upstage.ipynb b/docs/docs/integrations/providers/upstage.ipynb index c6885d532f..7f15d55e99 100644 --- a/docs/docs/integrations/providers/upstage.ipynb +++ b/docs/docs/integrations/providers/upstage.ipynb @@ -52,7 +52,7 @@ "| --- | --- | --- | --- |\n", "| Chat | Build assistants using Solar Mini Chat | `from langchain_upstage import ChatUpstage` | [Go](../../chat/upstage) |\n", "| Text Embedding | Embed strings to vectors | `from langchain_upstage import UpstageEmbeddings` | [Go](../../text_embedding/upstage) |\n", - "| Groundedness Check | Verify groundedness of assistant's response | `from langchain_upstage import GroundednessCheck` | [Go](../../tools/upstage_groundedness_check) |\n", + "| Groundedness Check | Verify groundedness of assistant's response | `from langchain_upstage import UpstageGroundednessCheck` | [Go](../../tools/upstage_groundedness_check) |\n", "| Layout Analysis | Serialize documents with tables and figures | `from langchain_upstage import UpstageLayoutAnalysisLoader` | [Go](../../document_loaders/upstage) |\n", "\n", "See [documentations](https://developers.upstage.ai/) for more details about the features." @@ -145,15 +145,15 @@ }, "outputs": [], "source": [ - "from langchain_upstage import GroundednessCheck\n", + "from langchain_upstage import UpstageGroundednessCheck\n", "\n", - "groundedness_check = GroundednessCheck()\n", + "groundedness_check = UpstageGroundednessCheck()\n", "\n", "request_input = {\n", " \"context\": \"Mauna Kea is an inactive volcano on the island of Hawaii. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n", - " \"query\": \"Mauna Kea is 5,207.3 meters tall.\",\n", + " \"answer\": \"Mauna Kea is 5,207.3 meters tall.\",\n", "}\n", - "response = groundedness_check.run(request_input)\n", + "response = groundedness_check.invoke(request_input)\n", "print(response)" ] }, diff --git a/docs/docs/integrations/tools/upstage_groundedness_check.ipynb b/docs/docs/integrations/tools/upstage_groundedness_check.ipynb index 5a35e0e257..e6cf2bcab2 100644 --- a/docs/docs/integrations/tools/upstage_groundedness_check.ipynb +++ b/docs/docs/integrations/tools/upstage_groundedness_check.ipynb @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "a83d4da0", "metadata": {}, "outputs": [], @@ -65,21 +65,21 @@ "source": [ "## Usage\n", "\n", - "Initialize `GroundednessCheck` class." + "Initialize `UpstageGroundednessCheck` class." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "b7373380c01cefbe", "metadata": { "collapsed": false }, "outputs": [], "source": [ - "from langchain_upstage import GroundednessCheck\n", + "from langchain_upstage import UpstageGroundednessCheck\n", "\n", - "groundedness_check = GroundednessCheck()" + "groundedness_check = UpstageGroundednessCheck()" ] }, { @@ -92,38 +92,22 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "1e0115e3b511f57", "metadata": { "collapsed": false, "is_executing": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "content='notGrounded' response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 198, 'total_tokens': 204}, 'model_name': 'solar-1-mini-answer-verification', 'system_fingerprint': '', 'finish_reason': 'stop', 'logprobs': None} id='run-ce7b5787-2ed0-4a68-9de4-c0e91a824147-0'\n" - ] - } - ], + "outputs": [], "source": [ "request_input = {\n", " \"context\": \"Mauna Kea is an inactive volcano on the island of Hawai'i. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n", - " \"query\": \"Mauna Kea is 5,207.3 meters tall.\",\n", + " \"answer\": \"Mauna Kea is 5,207.3 meters tall.\",\n", "}\n", "\n", - "response = groundedness_check.run(request_input)\n", + "response = groundedness_check.invoke(request_input)\n", "print(response)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "054b5031", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/libs/partners/upstage/langchain_upstage/__init__.py b/libs/partners/upstage/langchain_upstage/__init__.py index 9ecef2810d..77c1b91724 100644 --- a/libs/partners/upstage/langchain_upstage/__init__.py +++ b/libs/partners/upstage/langchain_upstage/__init__.py @@ -2,12 +2,16 @@ from langchain_upstage.chat_models import ChatUpstage from langchain_upstage.embeddings import UpstageEmbeddings from langchain_upstage.layout_analysis import UpstageLayoutAnalysisLoader from langchain_upstage.layout_analysis_parsers import UpstageLayoutAnalysisParser -from langchain_upstage.tools.groundedness_check import GroundednessCheck +from langchain_upstage.tools.groundedness_check import ( + GroundednessCheck, + UpstageGroundednessCheck, +) __all__ = [ "ChatUpstage", "UpstageEmbeddings", "UpstageLayoutAnalysisLoader", "UpstageLayoutAnalysisParser", + "UpstageGroundednessCheck", "GroundednessCheck", ] diff --git a/libs/partners/upstage/langchain_upstage/tools/groundedness_check.py b/libs/partners/upstage/langchain_upstage/tools/groundedness_check.py index 68d9453b40..eac1eb9e27 100644 --- a/libs/partners/upstage/langchain_upstage/tools/groundedness_check.py +++ b/libs/partners/upstage/langchain_upstage/tools/groundedness_check.py @@ -1,10 +1,12 @@ import os -from typing import Any, Literal, Optional, Type, Union +from typing import Any, List, Literal, Optional, Type, Union +from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) +from langchain_core.documents import Document from langchain_core.messages import AIMessage, HumanMessage from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr from langchain_core.tools import BaseTool @@ -13,16 +15,18 @@ from langchain_core.utils import convert_to_secret_str from langchain_upstage import ChatUpstage -class GroundednessCheckInput(BaseModel): +class UpstageGroundednessCheckInput(BaseModel): """Input for the Groundedness Check tool.""" - context: str = Field(description="context in which the answer should be verified") - query: str = Field( + context: Union[str, List[Document]] = Field( + description="context in which the answer should be verified" + ) + answer: str = Field( description="assistant's reply or a text that is subject to groundedness check" ) -class GroundednessCheck(BaseTool): +class UpstageGroundednessCheck(BaseTool): """Tool that checks the groundedness of a context and an assistant message. To use, you should have the environment variable `UPSTAGE_API_KEY` @@ -31,15 +35,15 @@ class GroundednessCheck(BaseTool): Example: .. code-block:: python - from langchain_upstage import GroundednessCheck + from langchain_upstage import UpstageGroundednessCheck - tool = GroundednessCheck() + tool = UpstageGroundednessCheck() """ name: str = "groundedness_check" description: str = ( "A tool that checks the groundedness of an assistant response " - "to user-provided context. GroundednessCheck ensures that " + "to user-provided context. UpstageGroundednessCheck ensures that " "the assistant’s response is not only relevant but also " "precisely aligned with the user's initial context, " "promoting a more reliable and context-aware interaction. " @@ -50,7 +54,7 @@ class GroundednessCheck(BaseTool): upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") api_wrapper: ChatUpstage - args_schema: Type[BaseModel] = GroundednessCheckInput + args_schema: Type[BaseModel] = UpstageGroundednessCheckInput def __init__(self, **kwargs: Any) -> None: upstage_api_key = kwargs.get("upstage_api_key", None) @@ -73,25 +77,41 @@ class GroundednessCheck(BaseTool): ) super().__init__(upstage_api_key=upstage_api_key, api_wrapper=api_wrapper) + def formatDocumentsAsString(self, docs: List[Document]) -> str: + return "\n".join([doc.page_content for doc in docs]) + def _run( self, - context: str, - query: str, + context: Union[str, List[Document]], + answer: str, run_manager: Optional[CallbackManagerForToolRun] = None, ) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]: """Use the tool.""" + if isinstance(context, List): + context = self.formatDocumentsAsString(context) response = self.api_wrapper.invoke( - [HumanMessage(context), AIMessage(query)], stream=False + [HumanMessage(context), AIMessage(answer)], stream=False ) return str(response.content) async def _arun( self, - context: str, - query: str, + context: Union[str, List[Document]], + answer: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]: + if isinstance(context, List): + context = self.formatDocumentsAsString(context) response = await self.api_wrapper.ainvoke( - [HumanMessage(context), AIMessage(query)], stream=False + [HumanMessage(context), AIMessage(answer)], stream=False ) return str(response.content) + + +@deprecated( + since="0.1.3", + removal="0.2.0", + alternative_import="langchain_upstage.UpstageGroundednessCheck", +) +class GroundednessCheck(UpstageGroundednessCheck): + pass diff --git a/libs/partners/upstage/tests/integration_tests/test_groundedness_check.py b/libs/partners/upstage/tests/integration_tests/test_groundedness_check.py index 29059f47e1..d4b56d795d 100644 --- a/libs/partners/upstage/tests/integration_tests/test_groundedness_check.py +++ b/libs/partners/upstage/tests/integration_tests/test_groundedness_check.py @@ -2,34 +2,62 @@ import os import openai import pytest +from langchain_core.documents import Document -from langchain_upstage import GroundednessCheck +from langchain_upstage import GroundednessCheck, UpstageGroundednessCheck -def test_langchain_upstage_groundedness_check() -> None: +def test_langchain_upstage_groundedness_check_deprecated() -> None: """Test Upstage Groundedness Check.""" tool = GroundednessCheck() - output = tool.run({"context": "foo bar", "query": "bar foo"}) + output = tool.invoke({"context": "foo bar", "answer": "bar foo"}) assert output in ["grounded", "notGrounded", "notSure"] api_key = os.environ.get("UPSTAGE_API_KEY", None) tool = GroundednessCheck(upstage_api_key=api_key) - output = tool.run({"context": "foo bar", "query": "bar foo"}) + output = tool.invoke({"context": "foo bar", "answer": "bar foo"}) + + assert output in ["grounded", "notGrounded", "notSure"] + + +def test_langchain_upstage_groundedness_check() -> None: + """Test Upstage Groundedness Check.""" + tool = UpstageGroundednessCheck() + output = tool.invoke({"context": "foo bar", "answer": "bar foo"}) + + assert output in ["grounded", "notGrounded", "notSure"] + + api_key = os.environ.get("UPSTAGE_API_KEY", None) + + tool = UpstageGroundednessCheck(upstage_api_key=api_key) + output = tool.invoke({"context": "foo bar", "answer": "bar foo"}) + + assert output in ["grounded", "notGrounded", "notSure"] + + +def test_langchain_upstage_groundedness_check_with_documents_input() -> None: + """Test Upstage Groundedness Check.""" + tool = UpstageGroundednessCheck() + docs = [ + Document(page_content="foo bar"), + Document(page_content="bar foo"), + ] + output = tool.invoke({"context": docs, "answer": "bar foo"}) assert output in ["grounded", "notGrounded", "notSure"] def test_langchain_upstage_groundedness_check_fail_with_wrong_api_key() -> None: - tool = GroundednessCheck(api_key="wrong-key") + tool = UpstageGroundednessCheck(api_key="wrong-key") with pytest.raises(openai.AuthenticationError): - tool.run({"context": "foo bar", "query": "bar foo"}) + tool.invoke({"context": "foo bar", "answer": "bar foo"}) async def test_langchain_upstage_groundedness_check_async() -> None: """Test Upstage Groundedness Check asynchronous.""" - tool = GroundednessCheck() - output = await tool.arun({"context": "foo bar", "query": "bar foo"}) + tool = UpstageGroundednessCheck() + output = await tool.ainvoke({"context": "foo bar", "answer": "bar foo"}) assert output in ["grounded", "notGrounded", "notSure"] diff --git a/libs/partners/upstage/tests/unit_tests/test_groundedness_check.py b/libs/partners/upstage/tests/unit_tests/test_groundedness_check.py index 4891dec01b..fdb0502574 100644 --- a/libs/partners/upstage/tests/unit_tests/test_groundedness_check.py +++ b/libs/partners/upstage/tests/unit_tests/test_groundedness_check.py @@ -1,12 +1,12 @@ import os -from langchain_upstage import GroundednessCheck +from langchain_upstage import UpstageGroundednessCheck os.environ["UPSTAGE_API_KEY"] = "foo" def test_initialization() -> None: """Test embedding model initialization.""" - GroundednessCheck() - GroundednessCheck(upstage_api_key="key") - GroundednessCheck(api_key="key") + UpstageGroundednessCheck() + UpstageGroundednessCheck(upstage_api_key="key") + UpstageGroundednessCheck(api_key="key") diff --git a/libs/partners/upstage/tests/unit_tests/test_imports.py b/libs/partners/upstage/tests/unit_tests/test_imports.py index 7fe6498700..900826a074 100644 --- a/libs/partners/upstage/tests/unit_tests/test_imports.py +++ b/libs/partners/upstage/tests/unit_tests/test_imports.py @@ -2,10 +2,11 @@ from langchain_upstage import __all__ EXPECTED_ALL = [ "ChatUpstage", + "GroundednessCheck", "UpstageEmbeddings", "UpstageLayoutAnalysisLoader", "UpstageLayoutAnalysisParser", - "GroundednessCheck", + "UpstageGroundednessCheck", ]