{
"cells": [
{
"cell_type": "markdown",
"id": "fc0db1bc",
"metadata": {},
"source": [
"# Cross Encoder Reranker\n",
"\n",
"This notebook shows how to implement reranker in a retriever with your own cross encoder from [Hugging Face cross encoder models](https://huggingface.co/cross-encoder) or Hugging Face models that implements cross encoder function ([example: BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)). `SagemakerEndpointCrossEncoder` enables you to use these HuggingFace models loaded on Sagemaker.\n",
"\n",
"This builds on top of ideas in the [ContextualCompressionRetriever](/docs/modules/data_connection/retrievers/contextual_compression/). Overall structure of this document came from [Cohere Reranker documentation](/docs/integrations/retrievers/cohere-reranker).\n",
"\n",
"For more about why cross encoder can be used as reranking mechanism in conjunction with embeddings for better retrieval, refer to [Hugging Face Cross-Encoders documentation](https://www.sbert.net/examples/applications/cross-encoder/README.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37bd138-4f3c-4d2c-bc4b-be705ce27a09",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#!pip install faiss sentence_transformers\n",
"\n",
"# OR (depending on Python version)\n",
"\n",
"#!pip install faiss-cpu sentence_transformers"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "28e8dc12",
"metadata": {},
"outputs": [],
"source": [
"# Helper function for printing docs\n",
"\n",
"\n",
"def pretty_print_docs(docs):\n",
" print(\n",
" f\"\\n{'-' * 100}\\n\".join(\n",
" [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n",
" )\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "6fa3d916",
"metadata": {
"tags": []
},
"source": [
"## Set up the base vector store retriever\n",
"Let's start by initializing a simple vector store retriever and storing the 2023 State of the Union speech (in chunks). We can set up the retriever to retrieve a high number (20) of docs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fbcc58f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"from langchain_community.embeddings import HuggingFaceEmbeddings\n",
"from langchain_community.vectorstores import FAISS\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"\n",
"documents = TextLoader(\"../../modules/state_of_the_union.txt\").load()\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
"texts = text_splitter.split_documents(documents)\n",
"embeddingsModel = HuggingFaceEmbeddings(\n",
" model_name=\"sentence-transformers/msmarco-distilbert-dot-v5\"\n",
")\n",
"retriever = FAISS.from_documents(texts, embeddingsModel).as_retriever(\n",
" search_kwargs={\"k\": 20}\n",
")\n",
"\n",
"query = \"What is the plan for the economy?\"\n",
"docs = retriever.invoke(query)\n",
"pretty_print_docs(docs)"
]
},
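{
"cell_type": "markdown",
"id": "a3f1c2d4",
"metadata": {},
"source": [
"The retriever above works as a *bi-encoder*: the query and each chunk are embedded independently, and the closest vectors are returned. The cell below is a minimal, self-contained sketch of that first stage, using toy hand-written vectors instead of real embeddings. The helper `dot_product_top_k` is hypothetical and is not part of LangChain or FAISS. It scores with a dot product for simplicity (the `msmarco-distilbert-dot-v5` model is tuned for dot-product similarity), whereas LangChain's `FAISS` vector store uses Euclidean distance by default."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4e2d3f5",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Sequence, Tuple\n",
"\n",
"\n",
"def dot_product_top_k(\n",
"    query_vec: Sequence[float], doc_vecs: Sequence[Sequence[float]], k: int\n",
") -> List[Tuple[int, float]]:\n",
"    # Hypothetical helper: score every document vector against the query\n",
"    # with a dot product and return the k best (index, score) pairs.\n",
"    scores = [\n",
"        (i, sum(q * d for q, d in zip(query_vec, vec)))\n",
"        for i, vec in enumerate(doc_vecs)\n",
"    ]\n",
"    # Highest score first, as in a similarity search.\n",
"    scores.sort(key=lambda pair: pair[1], reverse=True)\n",
"    return scores[:k]\n",
"\n",
"\n",
"# Toy 3-dimensional vectors standing in for real embedding-model outputs.\n",
"query_vec = [1.0, 0.0, 0.5]\n",
"doc_vecs = [[0.9, 0.1, 0.4], [0.0, 1.0, 0.0], [0.5, 0.0, 1.0]]\n",
"print(dot_product_top_k(query_vec, doc_vecs, k=2))"
]
},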
{
"cell_type": "markdown",
"id": "b7648612",
"metadata": {},
"source": [
"## Doing reranking with CrossEncoderReranker\n",
"Now let's wrap our base retriever with a `ContextualCompressionRetriever`. `CrossEncoderReranker` uses `HuggingFaceCrossEncoder` to rerank the returned results."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "9a658023",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Document 1:\n",
"\n",
"More infrastructure and innovation in America. \n",
"\n",
"More goods moving faster and cheaper in America. \n",
"\n",
"More jobs where you can earn a good living in America. \n",
"\n",
"And instead of relying on foreign supply chains, lets make it in America. \n",
"\n",
"Economists call it “increasing the productive capacity of our economy.” \n",
"\n",
"I call it building a better America. \n",
"\n",
"My plan to fight inflation will lower your costs and lower the deficit.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 2:\n",
"\n",
"Second cut energy costs for families an average of $500 a year by combatting climate change. \n",
"\n",
"Lets provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double Americas clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because youll never have to pay at the gas pump again.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 3:\n",
"\n",
"Look at cars. \n",
"\n",
"Last year, there werent enough semiconductors to make all the cars that people wanted to buy. \n",
"\n",
"And guess what, prices of automobiles went up. \n",
"\n",
"So—we have a choice. \n",
"\n",
"One way to fight inflation is to drive down wages and make Americans poorer. \n",
"\n",
"I have a better plan to fight inflation. \n",
"\n",
"Lower your costs, not your wages. \n",
"\n",
"Make more cars and semiconductors in America. \n",
"\n",
"More infrastructure and innovation in America. \n",
"\n",
"More goods moving faster and cheaper in America.\n"
]
}
],
"source": [
"from langchain.retrievers import ContextualCompressionRetriever\n",
"from langchain.retrievers.document_compressors import CrossEncoderReranker\n",
"from langchain_community.cross_encoders import HuggingFaceCrossEncoder\n",
"\n",
"model = HuggingFaceCrossEncoder(model_name=\"BAAI/bge-reranker-base\")\n",
"compressor = CrossEncoderReranker(model=model, top_n=3)\n",
"compression_retriever = ContextualCompressionRetriever(\n",
" base_compressor=compressor, base_retriever=retriever\n",
")\n",
"\n",
"compressed_docs = compression_retriever.invoke(\"What is the plan for the economy?\")\n",
"pretty_print_docs(compressed_docs)"
]
},
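{
"cell_type": "markdown",
"id": "c5a3e4b6",
"metadata": {},
"source": [
"Conceptually, the reranking step scores every (query, document) pair with the cross encoder, sorts the documents by score, and keeps the `top_n` best. The cell below is a minimal plain-Python sketch of that idea, not LangChain's implementation. `rerank` and `toy_overlap_scores` are hypothetical helpers; the word-overlap scorer only stands in for a real cross encoder so that the example runs without downloading a model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6b4f5c7",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable, List, Sequence, Tuple\n",
"\n",
"# A scorer maps (query, document) pairs to one relevance score per pair.\n",
"ScoreFn = Callable[[List[Tuple[str, str]]], List[float]]\n",
"\n",
"\n",
"def rerank(\n",
"    query: str, docs: Sequence[str], score_pairs: ScoreFn, top_n: int\n",
") -> List[str]:\n",
"    # Score each (query, document) pair jointly, as a cross encoder does.\n",
"    scores = score_pairs([(query, doc) for doc in docs])\n",
"    # Sort by descending score and keep only the top_n documents.\n",
"    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)\n",
"    return [doc for doc, _ in ranked[:top_n]]\n",
"\n",
"\n",
"def toy_overlap_scores(pairs: List[Tuple[str, str]]) -> List[float]:\n",
"    # Hypothetical stand-in for a cross encoder: the fraction of query\n",
"    # words that also appear in the document.\n",
"    results = []\n",
"    for query, doc in pairs:\n",
"        query_words = set(query.lower().split())\n",
"        doc_words = set(doc.lower().split())\n",
"        results.append(len(query_words & doc_words) / len(query_words))\n",
"    return results\n",
"\n",
"\n",
"candidates = [\n",
"    'we will lower energy costs for families',\n",
"    'my plan for the economy is to fight inflation',\n",
"    'look at cars',\n",
"]\n",
"print(rerank('plan for the economy', candidates, toy_overlap_scores, top_n=2))"
]
},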
{
"cell_type": "markdown",
"id": "419a2bf3-de4b-4c4d-9a40-4336552f604c",
"metadata": {},
"source": [
"## Uploading Hugging Face model to SageMaker endpoint\n",
"\n",
"Here is a sample `inference.py` for creating an endpoint that works with `SagemakerEndpointCrossEncoder`. For more details with step-by-step guidance, refer to [this article](https://huggingface.co/blog/kchoe/deploy-any-huggingface-model-to-sagemaker). \n",
"\n",
"It downloads Hugging Face model on the fly, so you do not need to keep the model artifacts such as `pytorch_model.bin` in your `model.tar.gz`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e579c743-40c3-432f-9483-0982e2808f9a",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import logging\n",
"from typing import List\n",
"\n",
"import torch\n",
"from sagemaker_inference import encoder\n",
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
"\n",
"PAIRS = \"pairs\"\n",
"SCORES = \"scores\"\n",
"\n",
"\n",
"class CrossEncoder:\n",
" def __init__(self) -> None:\n",
" self.device = (\n",
" torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
" )\n",
" logging.info(f\"Using device: {self.device}\")\n",
" model_name = \"BAAI/bge-reranker-base\"\n",
" self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" self.model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
" self.model = self.model.to(self.device)\n",
"\n",
" def __call__(self, pairs: List[List[str]]) -> List[float]:\n",
" with torch.inference_mode():\n",
" inputs = self.tokenizer(\n",
" pairs,\n",
" padding=True,\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" max_length=512,\n",
" )\n",
" inputs = inputs.to(self.device)\n",
" scores = (\n",
" self.model(**inputs, return_dict=True)\n",
" .logits.view(\n",
" -1,\n",
" )\n",
" .float()\n",
" )\n",
"\n",
" return scores.detach().cpu().tolist()\n",
"\n",
"\n",
"def model_fn(model_dir: str) -> CrossEncoder:\n",
" try:\n",
" return CrossEncoder()\n",
" except Exception:\n",
" logging.exception(f\"Failed to load model from: {model_dir}\")\n",
" raise\n",
"\n",
"\n",
"def transform_fn(\n",
" cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str\n",
") -> bytes:\n",
" payload = json.loads(input_data)\n",
" model_output = cross_encoder(**payload)\n",
" output = {SCORES: model_output}\n",
" return encoder.encode(output, accept)"
]
}
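,
{
"cell_type": "markdown",
"id": "e7c5a6d8",
"metadata": {},
"source": [
"The handler above defines a small JSON contract: `transform_fn` reads query-document pairs from the `pairs` key of the request body and returns one score per pair under the `scores` key. Assuming that contract, the cell below sketches client-side helpers that build a request body and rank documents from a response. `build_request` and `parse_response` are hypothetical and are not part of `SagemakerEndpointCrossEncoder`, and the endpoint response is simulated locally. Against a deployed endpoint, you would send the request body with the `invoke_endpoint` method of the `boto3` SageMaker runtime client."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8d6b7e9",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from typing import List, Sequence, Tuple\n",
"\n",
"\n",
"def build_request(query: str, docs: Sequence[str]) -> bytes:\n",
"    # One [query, document] pair per candidate, under the 'pairs' key\n",
"    # that transform_fn in inference.py reads.\n",
"    payload = {'pairs': [[query, doc] for doc in docs]}\n",
"    return json.dumps(payload).encode('utf-8')\n",
"\n",
"\n",
"def parse_response(body: bytes, docs: Sequence[str]) -> List[Tuple[str, float]]:\n",
"    # The handler returns {'scores': [...]}, one score per pair, in request order.\n",
"    scores = json.loads(body)['scores']\n",
"    if len(scores) != len(docs):\n",
"        raise ValueError('expected one score per document')\n",
"    # Pair each document with its score, most relevant first.\n",
"    return sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)\n",
"\n",
"\n",
"passages = ['a passage about the economy', 'a passage about cars']\n",
"request_body = build_request('What is the plan for the economy?', passages)\n",
"# Simulated endpoint response, for illustration only.\n",
"fake_response = json.dumps({'scores': [2.7, -1.3]}).encode('utf-8')\n",
"print(parse_response(fake_response, passages))"
]
}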
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}