Allowing additional params for OpenAIEmbeddings. (#7752)

(#7654)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/7910/head
Hanit 11 months ago committed by GitHub
parent 862268175e
commit 0d23c0c82a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -36,7 +36,7 @@ from langchain.schema import (
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__)
@ -155,7 +155,7 @@ class JinaChat(BaseChatModel):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = cls._all_required_field_names()
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:

@ -41,7 +41,7 @@ from langchain.schema.messages import (
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
if TYPE_CHECKING:
import tiktoken
@ -205,7 +205,7 @@ class ChatOpenAI(BaseChatModel):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = cls._all_required_field_names()
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:

@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import warnings
from typing import (
Any,
Callable,
@ -16,7 +17,7 @@ from typing import (
)
import numpy as np
from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
AsyncRetrying,
before_sleep_log,
@ -27,7 +28,7 @@ from tenacity import (
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__)
@ -193,12 +194,40 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
when tiktoken is called, you can specify a model name to use here."""
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
class Config:
    """Configuration for this pydantic object."""

    # NOTE(review): unknown constructor keywords appear to be moved into
    # `model_kwargs` by the pre root validator `build_extra` before this
    # setting is applied, so they would not trip it -- confirm.
    extra = Extra.forbid
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Build extra kwargs from additional params that were passed in.

    Every key in ``values`` that is not a declared field name or alias of
    the class is removed from ``values`` and stored under ``model_kwargs``
    instead, and a warning is emitted for it.

    Args:
        values: Raw keyword arguments given to the model constructor.

    Returns:
        ``values`` with unknown keys moved into ``values["model_kwargs"]``.

    Raises:
        ValueError: If a key is supplied both at the top level and inside
            ``model_kwargs``, or if ``model_kwargs`` contains a name that is
            a declared field (or alias) of the class.
    """
    all_required_field_names = get_pydantic_field_names(cls)
    # Work on a copy: the dict under "model_kwargs" is the caller's own
    # object, and it must not be mutated when unknown kwargs are moved in.
    extra = dict(values.get("model_kwargs", {}))
    # Iterate over a snapshot of the keys because `values` is popped from
    # inside the loop.
    for field_name in list(values):
        if field_name in extra:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(
                f"WARNING! {field_name} is not default parameter.\n"
                f"{field_name} was transferred to model_kwargs.\n"
                f"Please confirm that {field_name} is what you intended."
            )
            extra[field_name] = values.pop(field_name)

    # Declared fields must be passed explicitly, never via model_kwargs.
    invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            "Instead they were passed in as part of `model_kwargs` parameter."
        )

    values["model_kwargs"] = extra
    return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -261,6 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"api_base": self.openai_api_base,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
**self.model_kwargs,
}
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
openai_args["engine"] = self.deployment

@ -28,7 +28,7 @@ from langchain.callbacks.manager import (
)
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__)
@ -186,13 +186,13 @@ class BaseOpenAI(BaseLLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = cls._all_required_field_names()
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""

@ -7,6 +7,7 @@ from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.utils import get_pydantic_field_names
if TYPE_CHECKING:
from langchain.callbacks.manager import Callbacks
@ -246,9 +247,8 @@ class BaseLanguageModel(Serializable, ABC):
@classmethod
def _all_required_field_names(cls) -> Set:
    """DEPRECATED: Kept for backwards compatibility.

    Use get_pydantic_field_names.

    Returns:
        The result of ``get_pydantic_field_names(cls)``.
    """
    # Delegate to the shared helper instead of duplicating the field loop.
    return get_pydantic_field_names(cls)

@ -4,7 +4,7 @@ import datetime
import importlib
import os
from importlib.metadata import version
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from packaging.version import parse
from requests import HTTPError, Response
@ -183,3 +183,16 @@ def check_package_version(
f"Expected {package} version to be >= {gte_version}. Received "
f"{imported_version}."
)
def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
    """Get field names, including aliases, for a pydantic class.

    Args:
        pydantic_cls: Pydantic class. Its ``__fields__`` mapping is read;
            each value must expose ``name``, ``has_alias`` and ``alias``.

    Returns:
        A set with every field's ``name`` plus the ``alias`` of each field
        whose ``has_alias`` is truthy.
    """
    # Materialize once so the fields can be walked twice safely.
    fields = list(pydantic_cls.__fields__.values())
    names: Set[str] = {field.name for field in fields}
    names.update(field.alias for field in fields if field.has_alias)
    return names

4
poetry.lock generated

@ -12847,7 +12847,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"]
javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers"]
openai = ["openai", "tiktoken"]
@ -12857,4 +12857,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "cae082b5f45fe5564de8320fd1f39370f5e59389bf3aaa72291be531bce2e705"
content-hash = "f322b36103013bd59c34dddadf84209292ea61ed73bd26fbfa355d372011238b"

@ -362,6 +362,7 @@ extended_testing = [
"openai",
"sympy",
"rapidfuzz",
"openai",
"rank_bm25",
]

@ -0,0 +1,20 @@
import os
import pytest
from langchain.embeddings.openai import OpenAIEmbeddings
os.environ["OPENAI_API_KEY"] = "foo"
@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    """Constructing with `model` inside `model_kwargs` must raise ValueError."""
    conflicting_kwargs = {"model": "foo"}
    with pytest.raises(ValueError):
        OpenAIEmbeddings(model_kwargs=conflicting_kwargs)
@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    """An unknown keyword triggers a warning and ends up in `model_kwargs`."""
    with pytest.warns(match="not default parameter"):
        embeddings = OpenAIEmbeddings(foo="bar")
    expected_kwargs = {"foo": "bar"}
    assert embeddings.model_kwargs == expected_kwargs

@ -0,0 +1,28 @@
import os
import pytest
from langchain.llms.openai import OpenAI
os.environ["OPENAI_API_KEY"] = "foo"
@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
    """Both the `model` and `model_name` keywords must set `model_name`."""
    for init_kwargs in ({"model": "foo"}, {"model_name": "foo"}):
        llm = OpenAI(**init_kwargs)
        assert llm.model_name == "foo"
@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    """Constructing with `model_name` inside `model_kwargs` must raise ValueError."""
    conflicting_kwargs = {"model_name": "foo"}
    with pytest.raises(ValueError):
        OpenAI(model_kwargs=conflicting_kwargs)
@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    """An unknown keyword triggers a warning and ends up in `model_kwargs`."""
    with pytest.warns(match="not default parameter"):
        model = OpenAI(foo="bar")
    expected_kwargs = {"foo": "bar"}
    assert model.model_kwargs == expected_kwargs
Loading…
Cancel
Save