diff --git a/libs/community/langchain_community/embeddings/gpt4all.py b/libs/community/langchain_community/embeddings/gpt4all.py index f7983a5968..87c0bdca4b 100644 --- a/libs/community/langchain_community/embeddings/gpt4all.py +++ b/libs/community/langchain_community/embeddings/gpt4all.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, root_validator @@ -14,9 +14,18 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): from langchain_community.embeddings import GPT4AllEmbeddings - embeddings = GPT4AllEmbeddings() + model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf" + gpt4all_kwargs = {'allow_download': 'True'} + embeddings = GPT4AllEmbeddings( + model_name=model_name, + gpt4all_kwargs=gpt4all_kwargs + ) """ + model_name: str + n_threads: Optional[int] = None + device: Optional[str] = "cpu" + gpt4all_kwargs: Optional[dict] = {} client: Any #: :meta private: @root_validator() @@ -26,7 +35,12 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): try: from gpt4all import Embed4All - values["client"] = Embed4All() + values["client"] = Embed4All( + model_name=values["model_name"], + n_threads=values.get("n_threads"), + device=values.get("device"), + **values.get("gpt4all_kwargs"), + ) except ImportError: raise ImportError( "Could not import gpt4all library. "