{
"cells": [
{
"cell_type": "markdown",
"id": "9e9b7651",
"metadata": {},
"source": [
"# How to create a custom LLM class\n",
"\n",
"This notebook goes over how to create a custom LLM wrapper, in case you want to use your own LLM or a different wrapper than the ones LangChain supports.\n",
"\n",
"Wrapping your LLM with the standard `LLM` interface allows you to use your LLM in existing LangChain programs with minimal code modifications!\n",
"\n",
"As a bonus, your LLM automatically becomes a LangChain `Runnable` and benefits from some optimizations out of the box, async support, the `astream_events` API, etc.\n",
"\n",
"## Implementation\n",
"\n",
"A custom LLM only needs to implement two required methods:\n",
"\n",
"\n",
"| Method | Description |\n",
"|---------------|---------------------------------------------------------------------------|\n",
"| `_call` | Takes in a string and some optional stop words, and returns a string. Used by `invoke`. |\n",
"| `_llm_type` | A property that returns a string, used for logging purposes only. |\n",
"\n",
"\n",
"\n",
"Optional implementations: \n",
"\n",
"\n",
"| Method | Description |\n",
"|----------------------|-----------------------------------------------------------------------------------------------------------|\n",
"| `_identifying_params` | Used to help with identifying the model and printing the LLM; should return a dictionary. This is a **@property**. |\n",
"| `_acall` | Provides an async native implementation of `_call`, used by `ainvoke`. |\n",
"| `_stream` | Method to stream the output token by token. |\n",
"| `_astream` | Provides an async native implementation of `_stream`; in newer LangChain versions, defaults to `_stream`. |\n",
"\n",
"\n",
"\n",
"Let's implement a simple custom LLM that just returns the first `n` characters of the input."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2e9bb32f-6fd1-46ac-b32f-d175663710c0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from typing import Any, Dict, Iterator, List, Mapping, Optional\n",
"\n",
"from langchain_core.callbacks.manager import CallbackManagerForLLMRun\n",
"from langchain_core.language_models.llms import LLM\n",
"from langchain_core.outputs import GenerationChunk\n",
"\n",
"\n",
"class CustomLLM(LLM):\n",
"    \"\"\"A custom LLM that echoes the first `n` characters of the input.\n",
"\n",
"    When contributing an implementation to LangChain, carefully document\n",
"    the model including the initialization parameters, include\n",
"    an example of how to initialize the model and include any relevant\n",
"    links to the underlying model's documentation or API.\n",
"\n",
"    Example:\n",
"\n",
"        .. code-block:: python\n",
"\n",
"            model = CustomLLM(n=2)\n",
"            result = model.invoke(\"hello\")\n",
"            result = model.batch([\"hello\", \"world\"])\n",
"    \"\"\"\n",
"\n",
"    n: int\n",
"    \"\"\"The number of characters from the start of the prompt to echo.\"\"\"\n",
"\n",
"    def _call(\n",
"        self,\n",
"        prompt: str,\n",
"        stop: Optional[List[str]] = None,\n",
"        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
"        **kwargs: Any,\n",
"    ) -> str:\n",
"        \"\"\"Run the LLM on the given input.\n",
"\n",
"        Override this method to implement the LLM logic.\n",
"\n",
"        Args:\n",
"            prompt: The prompt to generate from.\n",
"            stop: Stop words to use when generating. Model output is cut off at the\n",
"                first occurrence of any of the stop substrings.\n",
"                If stop tokens are not supported, consider raising NotImplementedError.\n",
"            run_manager: Callback manager for the run.\n",
"            **kwargs: Arbitrary additional keyword arguments. These are usually passed\n",
"                to the model provider API call.\n",
"\n",
"        Returns:\n",
"            The model output as a string. Actual completions SHOULD NOT include the prompt.\n",
"        \"\"\"\n",
"        if stop is not None:\n",
"            raise ValueError(\"stop kwargs are not permitted.\")\n",
"        return prompt[: self.n]\n",
"\n",
"    def _stream(\n",
"        self,\n",
"        prompt: str,\n",
"        stop: Optional[List[str]] = None,\n",
"        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
"        **kwargs: Any,\n",
"    ) -> Iterator[GenerationChunk]:\n",
"        \"\"\"Stream the LLM on the given prompt.\n",
"\n",
"        This method should be overridden by subclasses that support streaming.\n",
"\n",
"        If not implemented, the default behavior of calls to stream will be to\n",
"        fall back to the non-streaming version of the model and return\n",
"        the output as a single chunk.\n",
"\n",
"        Args:\n",
"            prompt: The prompt to generate from.\n",
"            stop: Stop words to use when generating. Model output is cut off at the\n",
"                first occurrence of any of these substrings.\n",
"            run_manager: Callback manager for the run.\n",
"            **kwargs: Arbitrary additional keyword arguments. These are usually passed\n",
"                to the model provider API call.\n",
"\n",
"        Returns:\n",
"            An iterator of GenerationChunks.\n",
"        \"\"\"\n",
"        for char in prompt[: self.n]:\n",
"            chunk = GenerationChunk(text=char)\n",
"            if run_manager:\n",
"                run_manager.on_llm_new_token(chunk.text, chunk=chunk)\n",
"\n",
"            yield chunk\n",
"\n",
"    @property\n",
"    def _identifying_params(self) -> Dict[str, Any]:\n",
"        \"\"\"Return a dictionary of identifying parameters.\"\"\"\n",
"        return {\n",
"            # The model name allows users to specify custom token counting\n",
"            # rules in LLM monitoring applications (e.g., in LangSmith users\n",
"            # can provide per token pricing for their model and monitor\n",
"            # costs for the given LLM.)\n",
"            \"model_name\": \"CustomChatModel\",\n",
"        }\n",
"\n",
"    @property\n",
"    def _llm_type(self) -> str:\n",
"        \"\"\"Get the type of language model used by this LLM. Used for logging purposes only.\"\"\"\n",
"        return \"custom\""
]
},
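{
"cell_type": "markdown",
"id": "acall-sketch-markdown",
"metadata": {},
"source": [
"The optional `_acall` method is not required for `ainvoke` to work: by default, `LLM` falls back to running `_call` in an executor. If your provider ships an async client, you can override `_acall` instead. The following is a minimal sketch under that assumption: the subclass name `AsyncCustomLLM` is hypothetical, and `asyncio.sleep(0)` stands in for awaiting a real async API call."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acall-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"from typing import Any, List, Optional\n",
"\n",
"from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun\n",
"\n",
"\n",
"class AsyncCustomLLM(CustomLLM):\n",
"    \"\"\"Hypothetical CustomLLM variant with a native async `_acall`.\"\"\"\n",
"\n",
"    async def _acall(\n",
"        self,\n",
"        prompt: str,\n",
"        stop: Optional[List[str]] = None,\n",
"        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,\n",
"        **kwargs: Any,\n",
"    ) -> str:\n",
"        if stop is not None:\n",
"            raise ValueError(\"stop kwargs are not permitted.\")\n",
"        # Placeholder for awaiting a real async provider client.\n",
"        await asyncio.sleep(0)\n",
"        return prompt[: self.n]\n",
"\n",
"\n",
"async_llm = AsyncCustomLLM(n=5)\n",
"await async_llm.ainvoke(\"world\")"
]
},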
{
"cell_type": "markdown",
"id": "f614fb7b-e476-4d81-821b-57a2ebebe21c",
"metadata": {
"tags": []
},
"source": [
"### Let's test it 🧪"
]
},
{
"cell_type": "markdown",
"id": "e3feae15-4afc-49f4-8542-93867d4ea769",
"metadata": {
"tags": []
},
"source": [
"This LLM implements LangChain's standard `Runnable` interface, which many LangChain abstractions support!"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dfff4a95-99b2-4dba-b80d-9c3855046ef1",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mCustomLLM\u001b[0m\n",
"Params: {'model_name': 'CustomChatModel'}\n"
]
}
],
"source": [
"llm = CustomLLM(n=5)\n",
"print(llm)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8cd49199",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'This '"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.invoke(\"This is a foobar thing\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "511b3cb1-9c6f-49b6-9002-a2ec490632b0",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'world'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await llm.ainvoke(\"world\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d9d5bec2-d60a-4ebd-a97d-ac32c98ab02f",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['woof ', 'meow ']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.batch([\"woof woof woof\", \"meow meow meow\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fe246b29-7a93-4bef-8861-389445598c25",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['woof ', 'meow ']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await llm.abatch([\"woof woof woof\", \"meow meow meow\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3a67c38f-b83b-4eb9-a231-441c55ee8c82",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"h|e|l|l|o|"
]
}
],
"source": [
"async for token in llm.astream(\"hello\"):\n",
"    print(token, end=\"|\", flush=True)"
]
},
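{
"cell_type": "markdown",
"id": "sync-stream-markdown",
"metadata": {},
"source": [
"The synchronous `stream` method works the same way. Because `CustomLLM` overrides `_stream`, each yielded token is a single character of the truncated prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sync-stream-code",
"metadata": {},
"outputs": [],
"source": [
"for token in llm.stream(\"hello\"):\n",
"    print(token, end=\"|\", flush=True)"
]
},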
{
"cell_type": "markdown",
"id": "b62c282b-3a35-4529-aac4-2c2f0916790e",
"metadata": {},
"source": [
"Let's confirm that it integrates nicely with other `LangChain` APIs."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d5578e74-7fa8-4673-afee-7a59d442aaff",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "672ff664-8673-4832-9f4f-335253880141",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", \"you are a bot\"), (\"human\", \"{input}\")]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c400538a-9146-4c93-9fac-293d8f9ca6bf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = CustomLLM(n=7)\n",
"chain = prompt | llm"
]
},
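{
"cell_type": "markdown",
"id": "chain-invoke-markdown",
"metadata": {},
"source": [
"Since `prompt | llm` is itself a `Runnable`, the chain can also be invoked directly. The chat prompt is formatted into a single string before it reaches the LLM, so with `n=7` the chain echoes the first seven characters of that string, here `System:`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "chain-invoke-code",
"metadata": {},
"outputs": [],
"source": [
"chain.invoke({\"input\": \"hello there!\"})"
]
},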
{
"cell_type": "code",
"execution_count": 18,
"id": "080964af-3e2d-4573-85cb-0d7cc58a6f42",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n",
"{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n",
"{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}\n",
"{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\\nHuman: hello there!']}}}\n",
"{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}\n",
"{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}\n",
"{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}\n",
"{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}\n"
]
}
],
"source": [
"idx = 0\n",
"async for event in chain.astream_events({\"input\": \"hello there!\"}, version=\"v1\"):\n",
"    print(event)\n",
"    idx += 1\n",
"    if idx > 7:\n",
"        # Truncate\n",
"        break"
]
},
{
"cell_type": "markdown",
"id": "a85e848a-5316-4318-b770-3f8fd34f4231",
"metadata": {},
"source": [
"## Contributing\n",
"\n",
"We appreciate all LLM integration contributions.\n",
"\n",
"Here's a checklist to help make sure your contribution gets added to LangChain:\n",
"\n",
"Documentation:\n",
"\n",
"* [ ] The model contains docstrings for all initialization arguments, as these will be surfaced in the [API Reference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).\n",
"* [ ] The class docstring for the model contains a link to the model API if the model is powered by a service.\n",
"\n",
"Tests:\n",
"\n",
"* [ ] Add unit or integration tests for the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, and `stream` work if you've overridden the corresponding code.\n",
"\n",
"Streaming (if you're implementing it):\n",
"\n",
"* [ ] Make sure to invoke the `on_llm_new_token` callback\n",
"* [ ] `on_llm_new_token` is invoked BEFORE yielding the chunk\n",
"\n",
"Stop Token Behavior:\n",
"\n",
"* [ ] Stop token should be respected\n",
"* [ ] Stop token should be INCLUDED as part of the response\n",
"\n",
"Secret API Keys:\n",
"\n",
"* [ ] If your model connects to an API, it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets so they aren't accidentally printed when someone prints the model."
]
}
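,
{
"cell_type": "markdown",
"id": "stop-tokens-markdown",
"metadata": {},
"source": [
"The stop-token items in the checklist can be sketched as follows. This is a minimal sketch, not LangChain's own implementation: the helper `_truncate_at_stop` and the subclass `StopAwareCustomLLM` are hypothetical names. Following the checklist, the output is cut at the earliest stop string, and that stop string is kept at the end of the response, so the call below should return `'woof '`. Only `_call` is overridden here, so `stream` still ignores `stop` until `_stream` gets the same treatment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stop-tokens-code",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, List, Optional\n",
"\n",
"from langchain_core.callbacks.manager import CallbackManagerForLLMRun\n",
"\n",
"\n",
"def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:\n",
"    \"\"\"Cut `text` right after the earliest occurrence of any stop string.\"\"\"\n",
"    if not stop:\n",
"        return text\n",
"    # (position, stop string) for every non-empty stop string found in the text.\n",
"    hits = [(text.find(s), s) for s in stop if s and s in text]\n",
"    if not hits:\n",
"        return text\n",
"    position, token = min(hits)\n",
"    # Keep the stop string itself, as the checklist asks.\n",
"    return text[: position + len(token)]\n",
"\n",
"\n",
"class StopAwareCustomLLM(CustomLLM):\n",
"    \"\"\"Hypothetical CustomLLM variant whose `_call` respects stop tokens.\"\"\"\n",
"\n",
"    def _call(\n",
"        self,\n",
"        prompt: str,\n",
"        stop: Optional[List[str]] = None,\n",
"        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
"        **kwargs: Any,\n",
"    ) -> str:\n",
"        return _truncate_at_stop(prompt[: self.n], stop)\n",
"\n",
"\n",
"stop_llm = StopAwareCustomLLM(n=20)\n",
"stop_llm.invoke(\"woof woof meow\", stop=[\"meow\", \" \"])"
]
}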
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}