pull/20080/head
Eugene Yurtsev 2 months ago
parent 11b0d5a157
commit feea422d7c

@ -3,7 +3,6 @@ from __future__ import annotations
import contextlib
import enum
import json
import logging
import uuid
from typing import (
@ -90,9 +89,8 @@ SUPPORTED_OPERATORS = (
)
def _get_embedding_collection_store(
vector_dimension: Optional[int] = None, *, use_jsonb: bool = True
) -> Any:
def _get_embedding_collection_store(vector_dimension: int) -> Any:
"""Get the Embedding and Collection store classes."""
global _classes
if _classes is not None:
return _classes
@ -141,60 +139,35 @@ def _get_embedding_collection_store(
created = True
return collection, created
if use_jsonb:
# TODO(PRIOR TO LANDING): Create a gin index on the cmetadata field
class EmbeddingStore(BaseModel):
"""Embedding store."""
class EmbeddingStore(BaseModel):
"""Embedding store."""
__tablename__ = "langchain_pg_embedding"
__tablename__ = "langchain_pg_embedding"
collection_id = sqlalchemy.Column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship(CollectionStore, back_populates="embeddings")
embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSONB, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
__table_args__ = (
sqlalchemy.Index(
"ix_cmetadata_gin",
"cmetadata",
postgresql_using="gin",
postgresql_ops={"cmetadata": "jsonb_path_ops"},
),
)
else:
# For backwards compatibility with older versions of pgvector
# This should be removed in the future (remove during migration)
class EmbeddingStore(BaseModel): # type: ignore[no-redef]
"""Embedding store."""
__tablename__ = "langchain_pg_embedding"
collection_id = sqlalchemy.Column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship(CollectionStore, back_populates="embeddings")
embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
collection_id = sqlalchemy.Column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship(CollectionStore, back_populates="embeddings")
embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSONB, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
__table_args__ = (
sqlalchemy.Index(
"ix_cmetadata_gin",
"cmetadata",
postgresql_using="gin",
postgresql_ops={"cmetadata": "jsonb_path_ops"},
),
)
_classes = (EmbeddingStore, CollectionStore)
@ -227,11 +200,6 @@ class PGVector(VectorStore):
pre_delete_collection: If True, will delete the collection if it exists.
(default: False). Useful for testing.
engine_args: SQLAlchemy's create engine arguments.
use_jsonb: Use JSONB instead of JSON for metadata. (default: True)
Using JSON is strongly discouraged, as it is less efficient
for querying.
It's provided here for backwards compatibility with older versions,
and will be removed in the future.
create_extension: If True, will create the vector extension if it doesn't exist.
disabling creation is useful when using ReadOnly Databases.
@ -249,7 +217,6 @@ class PGVector(VectorStore):
documents=docs,
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
use_jsonb=True,
)
"""
@ -267,7 +234,6 @@ class PGVector(VectorStore):
*,
connection: Optional[sqlalchemy.engine.Connection] = None,
engine_args: Optional[dict[str, Any]] = None,
use_jsonb: bool = False,
create_extension: bool = True,
) -> None:
"""Initialize the PGVector store."""
@ -282,30 +248,8 @@ class PGVector(VectorStore):
self.override_relevance_score_fn = relevance_score_fn
self.engine_args = engine_args or {}
self._bind = connection if connection else self._create_engine()
self.use_jsonb = use_jsonb
self.create_extension = create_extension
if not use_jsonb:
# Replace with a deprecation warning.
warn_deprecated(
"0.0.29",
pending=True,
message=(
"Please use JSONB instead of JSON for metadata. "
"This change will allow for more efficient querying that "
"involves filtering based on metadata."
"Please note that filtering operators have been changed "
"when using JSOB metadata to be prefixed with a $ sign "
"to avoid name collisions with columns. "
"If you're using an existing database, you will need to create a"
"db migration for your metadata column to be JSONB and update your "
"queries to use the new operators. "
),
alternative=(
"Instantiate with use_jsonb=True to use JSONB instead "
"of JSON for metadata."
),
)
self.__post_init__()
def __post_init__(
@ -316,7 +260,7 @@ class PGVector(VectorStore):
self.create_vector_extension()
EmbeddingStore, CollectionStore = _get_embedding_collection_store(
self._embedding_length, use_jsonb=self.use_jsonb
self._embedding_length
)
self.CollectionStore = CollectionStore
self.EmbeddingStore = EmbeddingStore
@ -434,8 +378,6 @@ class PGVector(VectorStore):
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
connection_string: Optional[str] = None,
pre_delete_collection: bool = False,
*,
use_jsonb: bool = False,
**kwargs: Any,
) -> PGVector:
if ids is None:
@ -452,7 +394,6 @@ class PGVector(VectorStore):
embedding_function=embedding,
distance_strategy=distance_strategy,
pre_delete_collection=pre_delete_collection,
use_jsonb=use_jsonb,
**kwargs,
)
@ -711,99 +652,6 @@ class PGVector(VectorStore):
else:
raise NotImplementedError()
def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def]
"""Deprecated functionality.
This is for backwards compatibility with the JSON based schema for metadata.
It uses incorrect operator syntax (operators are not prefixed with $).
This implementation is not efficient, and has bugs associated with
the way that it handles numeric filter clauses.
"""
IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
value_case_insensitive = {k.lower(): v for k, v in value.items()}
if IN in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_(
value_case_insensitive[IN]
)
elif NIN in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in(
value_case_insensitive[NIN]
)
elif BETWEEN in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.between(
str(value_case_insensitive[BETWEEN][0]),
str(value_case_insensitive[BETWEEN][1]),
)
elif GT in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str(
value_case_insensitive[GT]
)
elif LT in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str(
value_case_insensitive[LT]
)
elif NE in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str(
value_case_insensitive[NE]
)
elif EQ in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
value_case_insensitive[EQ]
)
elif LIKE in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like(
value_case_insensitive[LIKE]
)
elif CONTAINS in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains(
value_case_insensitive[CONTAINS]
)
elif OR in map(str.lower, value):
or_clauses = [
self._create_filter_clause(key, sub_value)
for sub_value in value_case_insensitive[OR]
]
filter_by_metadata = sqlalchemy.or_(*or_clauses)
elif AND in map(str.lower, value):
and_clauses = [
self._create_filter_clause(key, sub_value)
for sub_value in value_case_insensitive[AND]
]
filter_by_metadata = sqlalchemy.and_(*and_clauses)
else:
filter_by_metadata = None
return filter_by_metadata
def _create_filter_clause_json_deprecated(
self, filter: Any
) -> List[SQLColumnExpression]:
"""Convert filters from IR to SQL clauses.
**DEPRECATED** This functionality will be deprecated in the future.
It implements translation of filters for a schema that uses JSON
for metadata rather than the JSONB field which is more efficient
for querying.
"""
filter_clauses = []
for key, value in filter.items():
if isinstance(value, dict):
filter_by_metadata = self._create_filter_clause_deprecated(key, value)
if filter_by_metadata is not None:
filter_clauses.append(filter_by_metadata)
else:
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
value
)
filter_clauses.append(filter_by_metadata)
return filter_clauses
def _create_filter_clause(self, filters: Any) -> Any:
"""Convert LangChain IR filter representation to matching SQLAlchemy clauses.
@ -904,14 +752,9 @@ class PGVector(VectorStore):
filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
if filter:
if self.use_jsonb:
filter_clauses = self._create_filter_clause(filter)
if filter_clauses is not None:
filter_by.append(filter_clauses)
else:
# Old way of doing things
filter_clauses = self._create_filter_clause_json_deprecated(filter)
filter_by.extend(filter_clauses)
filter_clauses = self._create_filter_clause(filter)
if filter_clauses is not None:
filter_by.append(filter_clauses)
_type = self.EmbeddingStore
@ -964,8 +807,6 @@ class PGVector(VectorStore):
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
*,
use_jsonb: bool = False,
**kwargs: Any,
) -> PGVector:
"""
@ -985,7 +826,6 @@ class PGVector(VectorStore):
collection_name=collection_name,
distance_strategy=distance_strategy,
pre_delete_collection=pre_delete_collection,
use_jsonb=use_jsonb,
**kwargs,
)
@ -1087,8 +927,6 @@ class PGVector(VectorStore):
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
*,
use_jsonb: bool = False,
**kwargs: Any,
) -> PGVector:
"""
@ -1112,7 +950,6 @@ class PGVector(VectorStore):
metadatas=metadatas,
ids=ids,
collection_name=collection_name,
use_jsonb=use_jsonb,
**kwargs,
)

Loading…
Cancel
Save